author     Brian Picciano <mediocregopher@gmail.com>   2022-09-13 12:56:08 +0200
committer  Brian Picciano <mediocregopher@gmail.com>   2022-09-13 12:56:08 +0200
commit     4f01edb9230f58ff84b0dd892c931ec8ac9aad55 (patch)
tree       9c1598a3f98203913ac2548883c02a81deb33dc7 /src/post
parent     5485984e05aebde22819adebfbd5ad51475a6c21 (diff)
move src out of srv, clean up default.nix and Makefile
Diffstat (limited to 'src/post')
-rw-r--r--  src/post/asset.go            114
-rw-r--r--  src/post/asset_test.go        91
-rw-r--r--  src/post/draft_post.go       187
-rw-r--r--  src/post/draft_post_test.go  130
-rw-r--r--  src/post/post.go             368
-rw-r--r--  src/post/post_test.go        268
-rw-r--r--  src/post/sql.go              126
7 files changed, 1284 insertions, 0 deletions
diff --git a/src/post/asset.go b/src/post/asset.go
new file mode 100644
index 0000000..a7b605b
--- /dev/null
+++ b/src/post/asset.go
@@ -0,0 +1,114 @@
+package post
+
+import (
+ "bytes"
+ "database/sql"
+ "errors"
+ "fmt"
+ "io"
+)
+
+var (
+ // ErrAssetNotFound is used to indicate an Asset could not be found in the
+ // AssetStore.
+ ErrAssetNotFound = errors.New("asset not found")
+)
+
+// AssetStore implements the storage and retrieval of binary assets, which are
+// intended to be used by posts (e.g. images).
+type AssetStore interface {
+
+ // Set stores the contents of the given io.Reader under the given id.
+ Set(id string, from io.Reader) error
+
+ // Get writes the id's body to the given io.Writer, or returns
+ // ErrAssetNotFound.
+ Get(id string, into io.Writer) error
+
+ // Delete deletes the body stored for the id, if any.
+ Delete(id string) error
+
+ // List returns all ids which are currently stored.
+ List() ([]string, error)
+}
+
+type assetStore struct {
+ db *sql.DB
+}
+
+// NewAssetStore initializes a new AssetStore using an existing SQLDB.
+func NewAssetStore(db *SQLDB) AssetStore {
+ return &assetStore{
+ db: db.db,
+ }
+}
+
+func (s *assetStore) Set(id string, from io.Reader) error {
+
+ body, err := io.ReadAll(from)
+ if err != nil {
+ return fmt.Errorf("reading body fully into memory: %w", err)
+ }
+
+ _, err = s.db.Exec(
+ `INSERT INTO assets (id, body)
+ VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET body=excluded.body`,
+ id, body,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into assets: %w", err)
+ }
+
+ return nil
+}
+
+func (s *assetStore) Get(id string, into io.Writer) error {
+
+ var body []byte
+
+ err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body)
+
+ if errors.Is(err, sql.ErrNoRows) {
+ return ErrAssetNotFound
+ } else if err != nil {
+ return fmt.Errorf("selecting from assets: %w", err)
+ }
+
+ if _, err := io.Copy(into, bytes.NewReader(body)); err != nil {
+ return fmt.Errorf("writing body to io.Writer: %w", err)
+ }
+
+ return nil
+}
+
+func (s *assetStore) Delete(id string) error {
+ _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id)
+ return err
+}
+
+func (s *assetStore) List() ([]string, error) {
+
+ rows, err := s.db.Query(`SELECT id FROM assets ORDER BY id ASC`)
+
+ if err != nil {
+ return nil, fmt.Errorf("querying: %w", err)
+ }
+
+ defer rows.Close()
+
+ var ids []string
+
+ for rows.Next() {
+
+ var id string
+ if err := rows.Scan(&id); err != nil {
+ return nil, fmt.Errorf("scanning row: %w", err)
+ }
+
+ ids = append(ids, id)
+ }
+
+ return ids, nil
+}
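
The AssetStore above is a thin keyed-blob layer over the assets table: Set upserts, Get streams the body back out, and List returns the stored ids. A minimal usage sketch follows; only NewInMemSQLDB, NewAssetStore, and the AssetStore methods come from this diff, and the package import path is an assumption.

package main

import (
	"bytes"
	"fmt"
	"os"

	// import path assumed; adjust to wherever src/post lives in the module
	"github.com/mediocregopher/blog.mediocregopher.com/src/post"
)

func main() {
	db := post.NewInMemSQLDB()
	defer db.Close()

	assets := post.NewAssetStore(db)

	// Set upserts the body under the id
	if err := assets.Set("diagram.png", bytes.NewBufferString("pretend this is a PNG")); err != nil {
		panic(err)
	}

	// Get streams the body back out to any io.Writer
	if err := assets.Get("diagram.png", os.Stdout); err != nil {
		panic(err)
	}

	ids, err := assets.List()
	if err != nil {
		panic(err)
	}

	fmt.Println(ids) // [diagram.png]
}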
diff --git a/src/post/asset_test.go b/src/post/asset_test.go
new file mode 100644
index 0000000..4d62d46
--- /dev/null
+++ b/src/post/asset_test.go
@@ -0,0 +1,91 @@
+package post
+
+import (
+ "bytes"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type assetTestHarness struct {
+ store AssetStore
+}
+
+func newAssetTestHarness(t *testing.T) *assetTestHarness {
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewAssetStore(db)
+
+ return &assetTestHarness{
+ store: store,
+ }
+}
+
+func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) {
+ t.Helper()
+ buf := new(bytes.Buffer)
+ err := h.store.Get(id, buf)
+ assert.NoError(t, err)
+ assert.Equal(t, exp, buf.String())
+}
+
+func (h *assetTestHarness) assertNotFound(t *testing.T, id string) {
+ t.Helper()
+ err := h.store.Get(id, io.Discard)
+ assert.ErrorIs(t, err, ErrAssetNotFound)
+}
+
+func TestAssetStore(t *testing.T) {
+
+ testAssetStore := func(t *testing.T, h *assetTestHarness) {
+ t.Helper()
+
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ err := h.store.Set("foo", bytes.NewBufferString("FOO"))
+ assert.NoError(t, err)
+
+ h.assertGet(t, "FOO", "foo")
+ h.assertNotFound(t, "bar")
+
+ err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+
+ h.assertGet(t, "FOOFOO", "foo")
+ h.assertNotFound(t, "bar")
+
+ assert.NoError(t, h.store.Delete("foo"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ assert.NoError(t, h.store.Delete("bar"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ // test list
+
+ ids, err := h.store.List()
+ assert.NoError(t, err)
+ assert.Empty(t, ids)
+
+ err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+ err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+ err = h.store.Set("bar", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+
+ ids, err = h.store.List()
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"bar", "foo"}, ids)
+ }
+
+ t.Run("sql", func(t *testing.T) {
+ h := newAssetTestHarness(t)
+ testAssetStore(t, h)
+ })
+}
diff --git a/src/post/draft_post.go b/src/post/draft_post.go
new file mode 100644
index 0000000..61283c3
--- /dev/null
+++ b/src/post/draft_post.go
@@ -0,0 +1,187 @@
+package post
+
+import (
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+)
+
+// DraftStore implements the storage and retrieval of draft Posts, i.e. Posts
+// which have not yet been published.
+type DraftStore interface {
+
+ // Set sets the draft Post's data into the storage, keyed by the draft
+ // Post's ID.
+ Set(post Post) error
+
+ // Get returns count draft Posts, sorted by id ascending, offset by the
+ // given page number. The returned boolean indicates if there are more
+ // pages or not.
+ Get(page, count int) ([]Post, bool, error)
+
+ // GetByID will return the draft Post with the given ID, or ErrPostNotFound.
+ GetByID(id string) (Post, error)
+
+ // Delete will delete the draft Post with the given ID.
+ Delete(id string) error
+}
+
+type draftStore struct {
+ db *SQLDB
+}
+
+// NewDraftStore initializes a new DraftStore using an existing SQLDB.
+func NewDraftStore(db *SQLDB) DraftStore {
+ return &draftStore{
+ db: db,
+ }
+}
+
+func (s *draftStore) Set(post Post) error {
+
+ if post.ID == "" {
+ return errors.New("post ID can't be empty")
+ }
+
+ tagsJSON, err := json.Marshal(post.Tags)
+ if err != nil {
+ return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
+ }
+
+ _, err = s.db.db.Exec(
+ `INSERT INTO post_drafts (
+ id, title, description, tags, series, body
+ )
+ VALUES
+ (?, ?, ?, ?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET
+ title=excluded.title,
+ description=excluded.description,
+ tags=excluded.tags,
+ series=excluded.series,
+ body=excluded.body`,
+ post.ID,
+ post.Title,
+ post.Description,
+ &sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0},
+ &sql.NullString{String: post.Series, Valid: post.Series != ""},
+ post.Body,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into post_drafts: %w", err)
+ }
+
+ return nil
+}
+
+func (s *draftStore) get(
+ querier interface {
+ Query(string, ...interface{}) (*sql.Rows, error)
+ },
+ limit, offset int,
+ where string, whereArgs ...interface{},
+) (
+ []Post, error,
+) {
+
+ query := `
+ SELECT
+ p.id, p.title, p.description, p.tags, p.series, p.body
+ FROM post_drafts p
+ ` + where + `
+ ORDER BY p.id ASC`
+
+ if limit > 0 {
+ query += fmt.Sprintf(" LIMIT %d", limit)
+ }
+
+ if offset > 0 {
+ query += fmt.Sprintf(" OFFSET %d", offset)
+ }
+
+ rows, err := querier.Query(query, whereArgs...)
+
+ if err != nil {
+ return nil, fmt.Errorf("selecting: %w", err)
+ }
+
+ defer rows.Close()
+
+ var posts []Post
+
+ for rows.Next() {
+
+ var (
+ post Post
+ tags, series sql.NullString
+ )
+
+ err := rows.Scan(
+ &post.ID, &post.Title, &post.Description, &tags, &series,
+ &post.Body,
+ )
+
+ if err != nil {
+ return nil, fmt.Errorf("scanning row: %w", err)
+ }
+
+ post.Series = series.String
+
+ if tags.String != "" {
+
+ if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil {
+ return nil, fmt.Errorf("json parsing %q: %w", tags.String, err)
+ }
+ }
+
+ posts = append(posts, post)
+ }
+
+ return posts, nil
+}
+
+func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
+
+ posts, err := s.get(s.db.db, count+1, page*count, ``)
+
+ if err != nil {
+ return nil, false, fmt.Errorf("querying post_drafts: %w", err)
+ }
+
+ var hasMore bool
+
+ if len(posts) > count {
+ hasMore = true
+ posts = posts[:count]
+ }
+
+ return posts, hasMore, nil
+}
+
+func (s *draftStore) GetByID(id string) (Post, error) {
+
+ posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
+
+ if err != nil {
+ return Post{}, fmt.Errorf("querying post_drafts: %w", err)
+ }
+
+ if len(posts) == 0 {
+ return Post{}, ErrPostNotFound
+ }
+
+ if len(posts) > 1 {
+ panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts))
+ }
+
+ return posts[0], nil
+}
+
+func (s *draftStore) Delete(id string) error {
+
+ if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_drafts: %w", err)
+ }
+
+ return nil
+}
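
DraftStore.Get paginates with a page/count pair and a "has more" flag rather than a total count. The sketch below walks every page; the helper name, page size, and import path are assumptions for illustration.

package main

import (
	"fmt"

	// import path assumed, as in the earlier sketch
	"github.com/mediocregopher/blog.mediocregopher.com/src/post"
)

// listAllDrafts pages through a DraftStore until Get reports no more pages.
func listAllDrafts(drafts post.DraftStore) ([]post.Post, error) {
	const perPage = 20 // arbitrary page size for the example

	var all []post.Post

	for page := 0; ; page++ {
		posts, hasMore, err := drafts.Get(page, perPage)
		if err != nil {
			return nil, fmt.Errorf("getting page %d of drafts: %w", page, err)
		}

		all = append(all, posts...)

		if !hasMore {
			return all, nil
		}
	}
}

func main() {
	db := post.NewInMemSQLDB()
	defer db.Close()

	drafts := post.NewDraftStore(db)

	if err := drafts.Set(post.Post{ID: "an-idea", Title: "An Idea", Body: "..."}); err != nil {
		panic(err)
	}

	all, err := listAllDrafts(drafts)
	if err != nil {
		panic(err)
	}

	fmt.Println(len(all), "draft(s)")
}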
diff --git a/src/post/draft_post_test.go b/src/post/draft_post_test.go
new file mode 100644
index 0000000..f404bb0
--- /dev/null
+++ b/src/post/draft_post_test.go
@@ -0,0 +1,130 @@
+package post
+
+import (
+ "sort"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type draftStoreTestHarness struct {
+ store DraftStore
+}
+
+func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness {
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewDraftStore(db)
+
+ return draftStoreTestHarness{
+ store: store,
+ }
+}
+
+func TestDraftStore(t *testing.T) {
+
+ assertPostEqual := func(t *testing.T, exp, got Post) {
+ t.Helper()
+ sort.Strings(exp.Tags)
+ sort.Strings(got.Tags)
+ assert.Equal(t, exp, got)
+ }
+
+ assertPostsEqual := func(t *testing.T, exp, got []Post) {
+ t.Helper()
+
+ if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
+ return
+ }
+
+ for i := range exp {
+ assertPostEqual(t, exp[i], got[i])
+ }
+ }
+
+ t.Run("not_found", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ _, err := h.store.GetByID("foo")
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("set_get_delete", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ post := testPost(0)
+ post.Tags = []string{"foo", "bar"}
+
+ err := h.store.Set(post)
+ assert.NoError(t, err)
+
+ gotPost, err := h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, post, gotPost)
+
+ // we will now try updating the post, and ensure it updates properly
+
+ post.Title = "something else"
+ post.Series = "whatever"
+ post.Body = "anything"
+ post.Tags = []string{"bar", "baz"}
+
+ err = h.store.Set(post)
+ assert.NoError(t, err)
+
+ gotPost, err = h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, post, gotPost)
+
+ // delete the post, it should go away
+ assert.NoError(t, h.store.Delete(post.ID))
+
+ _, err = h.store.GetByID(post.ID)
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("get", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ posts := []Post{
+ testPost(0),
+ testPost(1),
+ testPost(2),
+ testPost(3),
+ }
+
+ for _, post := range posts {
+ err := h.store.Set(post)
+ assert.NoError(t, err)
+ }
+
+ gotPosts, hasMore, err := h.store.Get(0, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[:2], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ posts = append(posts, testPost(4))
+ err = h.store.Set(posts[4])
+ assert.NoError(t, err)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(2, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[4:], gotPosts)
+ })
+
+}
diff --git a/src/post/post.go b/src/post/post.go
new file mode 100644
index 0000000..03bce6c
--- /dev/null
+++ b/src/post/post.go
@@ -0,0 +1,368 @@
+// Package post deals with the storage and rendering of blog posts.
+package post
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+ "time"
+)
+
+var (
+ // ErrPostNotFound is used to indicate a Post could not be found in the
+ // Store.
+ ErrPostNotFound = errors.New("post not found")
+)
+
+var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`)
+
+// NewID generates a (hopefully) unique ID based on the given title.
+func NewID(title string) string {
+ title = strings.ToLower(title)
+ title = titleCleanRegexp.ReplaceAllString(title, "")
+ title = strings.ReplaceAll(title, " ", "-")
+ return title
+}
+
+// Post contains all information having to do with a blog post.
+type Post struct {
+ ID string
+ Title string
+ Description string
+ Tags []string // only alphanumeric supported
+ Series string
+ Body string
+}
+
+// StoredPost is a Post which has been stored in a Store, and has been given
+// some extra fields as a result.
+type StoredPost struct {
+ Post
+
+ PublishedAt time.Time
+ LastUpdatedAt time.Time
+}
+
+// Store is used for storing posts to a persistent storage.
+type Store interface {
+
+ // Set stores the Post, keyed by its ID, overwriting any previously
+ // existing Post with the same ID. The returned boolean is true if no Post
+ // with that ID previously existed. now becomes the Post's PublishedAt on
+ // first publish, and its LastUpdatedAt on subsequent calls.
+ Set(post Post, now time.Time) (bool, error)
+
+ // Get returns count StoredPosts, sorted time descending, offset by the
+ // given page number. The returned boolean indicates if there are more pages
+ // or not.
+ Get(page, count int) ([]StoredPost, bool, error)
+
+ // GetByID will return the StoredPost with the given ID, or ErrPostNotFound.
+ GetByID(id string) (StoredPost, error)
+
+ // GetBySeries returns all StoredPosts with the given series, sorted time
+ // descending, or empty slice.
+ GetBySeries(series string) ([]StoredPost, error)
+
+ // GetByTag returns all StoredPosts with the given tag, sorted time
+ // descending, or empty slice.
+ GetByTag(tag string) ([]StoredPost, error)
+
+ // GetTags returns all tags which have at least one Post using them.
+ GetTags() ([]string, error)
+
+ // Delete will delete the StoredPost with the given ID.
+ Delete(id string) error
+}
+
+type store struct {
+ db *SQLDB
+}
+
+// NewStore initializes a new Store using an existing SQLDB.
+func NewStore(db *SQLDB) Store {
+ return &store{
+ db: db,
+ }
+}
+
+func (s *store) Set(post Post, now time.Time) (bool, error) {
+
+ if post.ID == "" {
+ return false, errors.New("post ID can't be empty")
+ }
+
+ var first bool
+
+ err := s.db.withTx(func(tx *sql.Tx) error {
+
+ nowTS := now.Unix()
+
+ nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()}
+
+ _, err := tx.Exec(
+ `INSERT INTO posts (
+ id, title, description, series, published_at, body
+ )
+ VALUES
+ (?, ?, ?, ?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET
+ title=excluded.title,
+ description=excluded.description,
+ series=excluded.series,
+ last_updated_at=?,
+ body=excluded.body`,
+ post.ID,
+ post.Title,
+ post.Description,
+ &sql.NullString{String: post.Series, Valid: post.Series != ""},
+ nowSQL,
+ post.Body,
+ nowSQL,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into posts: %w", err)
+ }
+
+ // this is a bit of a hack, but it allows us to update the tagset without
+ // doing a diff.
+ _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
+
+ if err != nil {
+ return fmt.Errorf("clearning post tags: %w", err)
+ }
+
+ for _, tag := range post.Tags {
+
+ _, err = tx.Exec(
+ `INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
+ ON CONFLICT DO NOTHING`,
+ post.ID,
+ tag,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting tag %q: %w", tag, err)
+ }
+ }
+
+ err = tx.QueryRow(
+ `SELECT 1 FROM posts WHERE id=? AND last_updated_at IS NULL`,
+ post.ID,
+ ).Scan(new(int))
+
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("checking if post was newly published: %w", err)
+ }
+
+ // a NULL last_updated_at means the row was only just inserted
+ first = !errors.Is(err, sql.ErrNoRows)
+
+ return nil
+ })
+
+ return first, err
+}
+
+func (s *store) get(
+ querier interface {
+ Query(string, ...interface{}) (*sql.Rows, error)
+ },
+ limit, offset int,
+ where string, whereArgs ...interface{},
+) (
+ []StoredPost, error,
+) {
+
+ query := `
+ SELECT
+ p.id, p.title, p.description, p.series, GROUP_CONCAT(pt.tag),
+ p.published_at, p.last_updated_at,
+ p.body
+ FROM posts p
+ LEFT JOIN post_tags pt ON (p.id = pt.post_id)
+ ` + where + `
+ GROUP BY (p.id)
+ ORDER BY p.published_at DESC, p.title DESC`
+
+ if limit > 0 {
+ query += fmt.Sprintf(" LIMIT %d", limit)
+ }
+
+ if offset > 0 {
+ query += fmt.Sprintf(" OFFSET %d", offset)
+ }
+
+ rows, err := querier.Query(query, whereArgs...)
+
+ if err != nil {
+ return nil, fmt.Errorf("selecting: %w", err)
+ }
+
+ var posts []StoredPost
+
+ for rows.Next() {
+
+ var (
+ post StoredPost
+ series, tag sql.NullString
+ publishedAt, lastUpdatedAt sql.NullInt64
+ )
+
+ err := rows.Scan(
+ &post.ID, &post.Title, &post.Description, &series, &tag,
+ &publishedAt, &lastUpdatedAt,
+ &post.Body,
+ )
+
+ if err != nil {
+ return nil, fmt.Errorf("scanning row: %w", err)
+ }
+
+ post.Series = series.String
+
+ if tag.String != "" {
+ post.Tags = strings.Split(tag.String, ",")
+ }
+
+ if publishedAt.Valid {
+ post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC()
+ }
+
+ if lastUpdatedAt.Valid {
+ post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC()
+ }
+
+ posts = append(posts, post)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, fmt.Errorf("closing row iterator: %w", err)
+ }
+
+ return posts, nil
+}
+
+func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
+
+ posts, err := s.get(s.db.db, count+1, page*count, ``)
+
+ if err != nil {
+ return nil, false, fmt.Errorf("querying posts: %w", err)
+ }
+
+ var hasMore bool
+
+ if len(posts) > count {
+ hasMore = true
+ posts = posts[:count]
+ }
+
+ return posts, hasMore, nil
+}
+
+func (s *store) GetByID(id string) (StoredPost, error) {
+
+ posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
+
+ if err != nil {
+ return StoredPost{}, fmt.Errorf("querying posts: %w", err)
+ }
+
+ if len(posts) == 0 {
+ return StoredPost{}, ErrPostNotFound
+ }
+
+ if len(posts) > 1 {
+ panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
+ }
+
+ return posts[0], nil
+}
+
+func (s *store) GetBySeries(series string) ([]StoredPost, error) {
+ return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
+}
+
+func (s *store) GetByTag(tag string) ([]StoredPost, error) {
+
+ var posts []StoredPost
+
+ err := s.db.withTx(func(tx *sql.Tx) error {
+
+ rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
+
+ if err != nil {
+ return fmt.Errorf("querying post_tags by tag: %w", err)
+ }
+
+ var (
+ placeholders []string
+ whereArgs []interface{}
+ )
+
+ for rows.Next() {
+
+ var id string
+
+ if err := rows.Scan(&id); err != nil {
+ rows.Close()
+ return fmt.Errorf("scanning id: %w", err)
+ }
+
+ whereArgs = append(whereArgs, id)
+ placeholders = append(placeholders, "?")
+ }
+
+ if err := rows.Close(); err != nil {
+ return fmt.Errorf("closing row iterator: %w", err)
+ }
+
+ where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
+
+ if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
+ return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
+ }
+
+ return nil
+ })
+
+ return posts, err
+}
+
+func (s *store) GetTags() ([]string, error) {
+
+ rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
+ if err != nil {
+ return nil, fmt.Errorf("querying all tags: %w", err)
+ }
+ defer rows.Close()
+
+ var tags []string
+
+ for rows.Next() {
+
+ var tag string
+
+ if err := rows.Scan(&tag); err != nil {
+ return nil, fmt.Errorf("scanning tag: %w", err)
+ }
+
+ tags = append(tags, tag)
+ }
+
+ return tags, nil
+}
+
+func (s *store) Delete(id string) error {
+
+ return s.db.withTx(func(tx *sql.Tx) error {
+
+ if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_tags: %w", err)
+ }
+
+ if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from posts: %w", err)
+ }
+
+ return nil
+ })
+}
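
Store ties the pieces together: NewID derives the key from the title, Set stamps PublishedAt from the supplied now on first publish (and LastUpdatedAt on later ones), and GetByTag resolves ids through the post_tags table. A sketch under the same assumed import path:

package main

import (
	"fmt"
	"time"

	// import path assumed, as in the earlier sketches
	"github.com/mediocregopher/blog.mediocregopher.com/src/post"
)

func main() {
	db := post.NewInMemSQLDB()
	defer db.Close()

	store := post.NewStore(db)

	title := "Why Do We Have WiFi Passwords?"

	p := post.Post{
		ID:          post.NewID(title), // "why-do-we-have-wifi-passwords"
		Title:       title,
		Description: "An example post.",
		Tags:        []string{"networking"},
		Body:        "...",
	}

	// first is true because no Post with this ID existed yet; PublishedAt is
	// taken from the time passed in here.
	first, err := store.Set(p, time.Now())
	if err != nil {
		panic(err)
	}
	fmt.Println("first publish:", first)

	tagged, err := store.GetByTag("networking")
	if err != nil {
		panic(err)
	}

	for _, sp := range tagged {
		fmt.Println(sp.ID, sp.PublishedAt)
	}
}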
diff --git a/src/post/post_test.go b/src/post/post_test.go
new file mode 100644
index 0000000..c7f9cdc
--- /dev/null
+++ b/src/post/post_test.go
@@ -0,0 +1,268 @@
+package post
+
+import (
+ "sort"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tilinna/clock"
+)
+
+func TestNewID(t *testing.T) {
+
+ tests := [][2]string{
+ {
+ "Why Do We Have WiFi Passwords?",
+ "why-do-we-have-wifi-passwords",
+ },
+ {
+ "Ginger: A Small VM Update",
+ "ginger-a-small-vm-update",
+ },
+ {
+ "Something-Weird.... woah!",
+ "somethingweird-woah",
+ },
+ }
+
+ for i, test := range tests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ assert.Equal(t, test[1], NewID(test[0]))
+ })
+ }
+}
+
+func testPost(i int) Post {
+ istr := strconv.Itoa(i)
+ return Post{
+ ID: istr,
+ Title: istr,
+ Description: istr,
+ Body: istr,
+ }
+}
+
+type storeTestHarness struct {
+ clock *clock.Mock
+ store Store
+}
+
+func newStoreTestHarness(t *testing.T) storeTestHarness {
+
+ clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour))
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewStore(db)
+
+ return storeTestHarness{
+ clock: clock,
+ store: store,
+ }
+}
+
+func (h *storeTestHarness) testStoredPost(i int) StoredPost {
+ post := testPost(i)
+ return StoredPost{
+ Post: post,
+ PublishedAt: h.clock.Now(),
+ }
+}
+
+func TestStore(t *testing.T) {
+
+ assertPostEqual := func(t *testing.T, exp, got StoredPost) {
+ t.Helper()
+ sort.Strings(exp.Tags)
+ sort.Strings(got.Tags)
+ assert.Equal(t, exp, got)
+ }
+
+ assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
+ t.Helper()
+
+ if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
+ return
+ }
+
+ for i := range exp {
+ assertPostEqual(t, exp[i], got[i])
+ }
+ }
+
+ t.Run("not_found", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ _, err := h.store.GetByID("foo")
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("set_get_delete", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ post := testPost(0)
+ post.Tags = []string{"foo", "bar"}
+
+ first, err := h.store.Set(post, now)
+ assert.NoError(t, err)
+ assert.True(t, first)
+
+ gotPost, err := h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, StoredPost{
+ Post: post,
+ PublishedAt: now,
+ }, gotPost)
+
+ // we will now try updating the post on a different day, and ensure it
+ // updates properly
+
+ h.clock.Add(24 * time.Hour)
+ newNow := h.clock.Now().UTC()
+
+ post.Title = "something else"
+ post.Series = "whatever"
+ post.Body = "anything"
+ post.Tags = []string{"bar", "baz"}
+
+ first, err = h.store.Set(post, newNow)
+ assert.NoError(t, err)
+ assert.False(t, first)
+
+ gotPost, err = h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, StoredPost{
+ Post: post,
+ PublishedAt: now,
+ LastUpdatedAt: newNow,
+ }, gotPost)
+
+ // delete the post, it should go away
+ assert.NoError(t, h.store.Delete(post.ID))
+
+ _, err = h.store.GetByID(post.ID)
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("get", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(3),
+ h.testStoredPost(2),
+ h.testStoredPost(1),
+ h.testStoredPost(0),
+ }
+
+ for _, post := range posts {
+ _, err := h.store.Set(post.Post, now)
+ assert.NoError(t, err)
+ }
+
+ gotPosts, hasMore, err := h.store.Get(0, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[:2], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ posts = append([]StoredPost{h.testStoredPost(4)}, posts...)
+ _, err = h.store.Set(posts[0].Post, now)
+ assert.NoError(t, err)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(2, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[4:], gotPosts)
+ })
+
+ t.Run("get_by_series", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(3),
+ h.testStoredPost(2),
+ h.testStoredPost(1),
+ h.testStoredPost(0),
+ }
+
+ posts[0].Series = "foo"
+ posts[1].Series = "bar"
+ posts[2].Series = "bar"
+
+ for _, post := range posts {
+ _, err := h.store.Set(post.Post, now)
+ assert.NoError(t, err)
+ }
+
+ fooPosts, err := h.store.GetBySeries("foo")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[:1], fooPosts)
+
+ barPosts, err := h.store.GetBySeries("bar")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[1:3], barPosts)
+
+ bazPosts, err := h.store.GetBySeries("baz")
+ assert.NoError(t, err)
+ assert.Empty(t, bazPosts)
+ })
+
+ t.Run("get_by_tag", func(t *testing.T) {
+
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(3),
+ h.testStoredPost(2),
+ h.testStoredPost(1),
+ h.testStoredPost(0),
+ }
+
+ posts[0].Tags = []string{"foo"}
+ posts[1].Tags = []string{"foo", "bar"}
+ posts[2].Tags = []string{"bar"}
+
+ for _, post := range posts {
+ _, err := h.store.Set(post.Post, now)
+ assert.NoError(t, err)
+ }
+
+ fooPosts, err := h.store.GetByTag("foo")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[:2], fooPosts)
+
+ barPosts, err := h.store.GetByTag("bar")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[1:3], barPosts)
+
+ bazPosts, err := h.store.GetByTag("baz")
+ assert.NoError(t, err)
+ assert.Empty(t, bazPosts)
+
+ tags, err := h.store.GetTags()
+ assert.NoError(t, err)
+ assert.ElementsMatch(t, []string{"foo", "bar"}, tags)
+ })
+}
diff --git a/src/post/sql.go b/src/post/sql.go
new file mode 100644
index 0000000..c768c9a
--- /dev/null
+++ b/src/post/sql.go
@@ -0,0 +1,126 @@
+package post
+
+import (
+ "database/sql"
+ "fmt"
+ "path"
+
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
+ migrate "github.com/rubenv/sql-migrate"
+
+ _ "github.com/mattn/go-sqlite3" // we need dis
+)
+
+var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{
+ {
+ Id: "1",
+ Up: []string{
+ `CREATE TABLE posts (
+ id TEXT NOT NULL PRIMARY KEY,
+ title TEXT NOT NULL,
+ description TEXT NOT NULL,
+ series TEXT,
+
+ published_at INTEGER NOT NULL,
+ last_updated_at INTEGER,
+
+ body TEXT NOT NULL
+ )`,
+
+ `CREATE TABLE post_tags (
+ post_id TEXT NOT NULL,
+ tag TEXT NOT NULL,
+ UNIQUE(post_id, tag)
+ )`,
+
+ `CREATE TABLE assets (
+ id TEXT NOT NULL PRIMARY KEY,
+ body BLOB NOT NULL
+ )`,
+ },
+ },
+ {
+ Id: "2",
+ Up: []string{
+ `CREATE TABLE post_drafts (
+ id TEXT NOT NULL PRIMARY KEY,
+ title TEXT NOT NULL,
+ description TEXT NOT NULL,
+ tags TEXT,
+ series TEXT,
+ body TEXT NOT NULL
+ )`,
+ },
+ },
+}}
+
+// SQLDB is a sqlite3 database which can be used by storage interfaces within
+// this package.
+type SQLDB struct {
+ db *sql.DB
+}
+
+// NewSQLDB initializes and returns a new sqlite3 database for storage
+// interfaces. The db will be created within the given data directory.
+func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
+
+ path := path.Join(dataDir.Path, "post.sqlite3")
+
+ db, err := sql.Open("sqlite3", path)
+ if err != nil {
+ return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err)
+ }
+
+ if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
+ return nil, fmt.Errorf("running migrations: %w", err)
+ }
+
+ return &SQLDB{db}, nil
+}
+
+// NewInMemSQLDB is like NewSQLDB, but the database will be initialized in memory.
+func NewInMemSQLDB() *SQLDB {
+
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ panic(fmt.Errorf("opening sqlite in memory: %w", err))
+ }
+
+ if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
+ panic(fmt.Errorf("running migrations: %w", err))
+ }
+
+ return &SQLDB{db}
+}
+
+// Close cleans up loose resources being held by the db.
+func (db *SQLDB) Close() error {
+ return db.db.Close()
+}
+
+func (db *SQLDB) withTx(cb func(*sql.Tx) error) error {
+
+ tx, err := db.db.Begin()
+
+ if err != nil {
+ return fmt.Errorf("starting transaction: %w", err)
+ }
+
+ if err := cb(tx); err != nil {
+
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return fmt.Errorf(
+ "rolling back transaction: %w (original error: %v)",
+ rollbackErr, err,
+ )
+ }
+
+ return fmt.Errorf("performing transaction: %w (rolled back)", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("committing transaction: %w", err)
+ }
+
+ return nil
+}
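
For on-disk use the only difference is constructing the SQLDB via NewSQLDB, which joins post.sqlite3 onto the data directory's Path and runs the migrations. The sketch below assumes a cfg.DataDir literal with only Path set is sufficient, which this diff does not confirm.

package main

import (
	"log"

	"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"

	// import path assumed, as in the earlier sketches
	"github.com/mediocregopher/blog.mediocregopher.com/src/post"
)

func main() {
	// NOTE: constructing cfg.DataDir directly is an assumption; only its Path
	// field is visible in this diff.
	db, err := post.NewSQLDB(cfg.DataDir{Path: "/tmp/blog-data"})
	if err != nil {
		log.Fatalf("opening post db: %v", err)
	}
	defer db.Close()

	// a single SQLDB backs all three stores
	posts := post.NewStore(db)
	drafts := post.NewDraftStore(db)
	assets := post.NewAssetStore(db)

	_, _, _ = posts, drafts, assets
}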