diff options
Diffstat (limited to 'srv/src/post')
-rw-r--r-- | srv/src/post/asset.go | 114 | ||||
-rw-r--r-- | srv/src/post/asset_test.go | 91 | ||||
-rw-r--r-- | srv/src/post/draft_post.go | 187 | ||||
-rw-r--r-- | srv/src/post/draft_post_test.go | 130 | ||||
-rw-r--r-- | srv/src/post/post.go | 368 | ||||
-rw-r--r-- | srv/src/post/post_test.go | 268 | ||||
-rw-r--r-- | srv/src/post/sql.go | 126 |
7 files changed, 0 insertions, 1284 deletions
diff --git a/srv/src/post/asset.go b/srv/src/post/asset.go deleted file mode 100644 index a7b605b..0000000 --- a/srv/src/post/asset.go +++ /dev/null @@ -1,114 +0,0 @@ -package post - -import ( - "bytes" - "database/sql" - "errors" - "fmt" - "io" -) - -var ( - // ErrAssetNotFound is used to indicate an Asset could not be found in the - // AssetStore. - ErrAssetNotFound = errors.New("asset not found") -) - -// AssetStore implements the storage and retrieval of binary assets, which are -// intended to be used by posts (e.g. images). -type AssetStore interface { - - // Set sets the id to the contents of the given io.Reader. - Set(id string, from io.Reader) error - - // Get writes the id's body to the given io.Writer, or returns - // ErrAssetNotFound. - Get(id string, into io.Writer) error - - // Delete's the body stored for the id, if any. - Delete(id string) error - - // List returns all ids which are currently stored. - List() ([]string, error) -} - -type assetStore struct { - db *sql.DB -} - -// NewAssetStore initializes a new AssetStore using an existing SQLDB. -func NewAssetStore(db *SQLDB) AssetStore { - return &assetStore{ - db: db.db, - } -} - -func (s *assetStore) Set(id string, from io.Reader) error { - - body, err := io.ReadAll(from) - if err != nil { - return fmt.Errorf("reading body fully into memory: %w", err) - } - - _, err = s.db.Exec( - `INSERT INTO assets (id, body) - VALUES (?, ?) - ON CONFLICT (id) DO UPDATE SET body=excluded.body`, - id, body, - ) - - if err != nil { - return fmt.Errorf("inserting into assets: %w", err) - } - - return nil -} - -func (s *assetStore) Get(id string, into io.Writer) error { - - var body []byte - - err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body) - - if errors.Is(err, sql.ErrNoRows) { - return ErrAssetNotFound - } else if err != nil { - return fmt.Errorf("selecting from assets: %w", err) - } - - if _, err := io.Copy(into, bytes.NewReader(body)); err != nil { - return fmt.Errorf("writing body to io.Writer: %w", err) - } - - return nil -} - -func (s *assetStore) Delete(id string) error { - _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id) - return err -} - -func (s *assetStore) List() ([]string, error) { - - rows, err := s.db.Query(`SELECT id FROM assets ORDER BY id ASC`) - - if err != nil { - return nil, fmt.Errorf("querying: %w", err) - } - - defer rows.Close() - - var ids []string - - for rows.Next() { - - var id string - if err := rows.Scan(&id); err != nil { - return nil, fmt.Errorf("scanning row: %w", err) - } - - ids = append(ids, id) - } - - return ids, nil -} diff --git a/srv/src/post/asset_test.go b/srv/src/post/asset_test.go deleted file mode 100644 index 4d62d46..0000000 --- a/srv/src/post/asset_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package post - -import ( - "bytes" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -type assetTestHarness struct { - store AssetStore -} - -func newAssetTestHarness(t *testing.T) *assetTestHarness { - - db := NewInMemSQLDB() - t.Cleanup(func() { db.Close() }) - - store := NewAssetStore(db) - - return &assetTestHarness{ - store: store, - } -} - -func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) { - t.Helper() - buf := new(bytes.Buffer) - err := h.store.Get(id, buf) - assert.NoError(t, err) - assert.Equal(t, exp, buf.String()) -} - -func (h *assetTestHarness) assertNotFound(t *testing.T, id string) { - t.Helper() - err := h.store.Get(id, io.Discard) - assert.ErrorIs(t, ErrAssetNotFound, err) -} - -func TestAssetStore(t *testing.T) { - - testAssetStore := func(t *testing.T, h *assetTestHarness) { - t.Helper() - - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") - - err := h.store.Set("foo", bytes.NewBufferString("FOO")) - assert.NoError(t, err) - - h.assertGet(t, "FOO", "foo") - h.assertNotFound(t, "bar") - - err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) - assert.NoError(t, err) - - h.assertGet(t, "FOOFOO", "foo") - h.assertNotFound(t, "bar") - - assert.NoError(t, h.store.Delete("foo")) - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") - - assert.NoError(t, h.store.Delete("bar")) - h.assertNotFound(t, "foo") - h.assertNotFound(t, "bar") - - // test list - - ids, err := h.store.List() - assert.NoError(t, err) - assert.Empty(t, ids) - - err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) - assert.NoError(t, err) - err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) - assert.NoError(t, err) - err = h.store.Set("bar", bytes.NewBufferString("FOOFOO")) - assert.NoError(t, err) - - ids, err = h.store.List() - assert.NoError(t, err) - assert.Equal(t, []string{"bar", "foo"}, ids) - } - - t.Run("sql", func(t *testing.T) { - h := newAssetTestHarness(t) - testAssetStore(t, h) - }) -} diff --git a/srv/src/post/draft_post.go b/srv/src/post/draft_post.go deleted file mode 100644 index 61283c3..0000000 --- a/srv/src/post/draft_post.go +++ /dev/null @@ -1,187 +0,0 @@ -package post - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" -) - -type DraftStore interface { - - // Set sets the draft Post's data into the storage, keyed by the draft - // Post's ID. - Set(post Post) error - - // Get returns count draft Posts, sorted id descending, offset by the - // given page number. The returned boolean indicates if there are more pages - // or not. - Get(page, count int) ([]Post, bool, error) - - // GetByID will return the draft Post with the given ID, or ErrPostNotFound. - GetByID(id string) (Post, error) - - // Delete will delete the draft Post with the given ID. - Delete(id string) error -} - -type draftStore struct { - db *SQLDB -} - -// NewDraftStore initializes a new DraftStore using an existing SQLDB. -func NewDraftStore(db *SQLDB) DraftStore { - return &draftStore{ - db: db, - } -} - -func (s *draftStore) Set(post Post) error { - - if post.ID == "" { - return errors.New("post ID can't be empty") - } - - tagsJSON, err := json.Marshal(post.Tags) - if err != nil { - return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err) - } - - _, err = s.db.db.Exec( - `INSERT INTO post_drafts ( - id, title, description, tags, series, body - ) - VALUES - (?, ?, ?, ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - title=excluded.title, - description=excluded.description, - tags=excluded.tags, - series=excluded.series, - body=excluded.body`, - post.ID, - post.Title, - post.Description, - &sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0}, - &sql.NullString{String: post.Series, Valid: post.Series != ""}, - post.Body, - ) - - if err != nil { - return fmt.Errorf("inserting into post_drafts: %w", err) - } - - return nil -} - -func (s *draftStore) get( - querier interface { - Query(string, ...interface{}) (*sql.Rows, error) - }, - limit, offset int, - where string, whereArgs ...interface{}, -) ( - []Post, error, -) { - - query := ` - SELECT - p.id, p.title, p.description, p.tags, p.series, p.body - FROM post_drafts p - ` + where + ` - ORDER BY p.id ASC` - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - - rows, err := querier.Query(query, whereArgs...) - - if err != nil { - return nil, fmt.Errorf("selecting: %w", err) - } - - defer rows.Close() - - var posts []Post - - for rows.Next() { - - var ( - post Post - tags, series sql.NullString - ) - - err := rows.Scan( - &post.ID, &post.Title, &post.Description, &tags, &series, - &post.Body, - ) - - if err != nil { - return nil, fmt.Errorf("scanning row: %w", err) - } - - post.Series = series.String - - if tags.String != "" { - - if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil { - return nil, fmt.Errorf("json parsing %q: %w", tags.String, err) - } - } - - posts = append(posts, post) - } - - return posts, nil -} - -func (s *draftStore) Get(page, count int) ([]Post, bool, error) { - - posts, err := s.get(s.db.db, count+1, page*count, ``) - - if err != nil { - return nil, false, fmt.Errorf("querying post_drafts: %w", err) - } - - var hasMore bool - - if len(posts) > count { - hasMore = true - posts = posts[:count] - } - - return posts, hasMore, nil -} - -func (s *draftStore) GetByID(id string) (Post, error) { - - posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) - - if err != nil { - return Post{}, fmt.Errorf("querying post_drafts: %w", err) - } - - if len(posts) == 0 { - return Post{}, ErrPostNotFound - } - - if len(posts) > 1 { - panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts)) - } - - return posts[0], nil -} - -func (s *draftStore) Delete(id string) error { - - if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_drafts: %w", err) - } - - return nil -} diff --git a/srv/src/post/draft_post_test.go b/srv/src/post/draft_post_test.go deleted file mode 100644 index f404bb0..0000000 --- a/srv/src/post/draft_post_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package post - -import ( - "sort" - "testing" - - "github.com/stretchr/testify/assert" -) - -type draftStoreTestHarness struct { - store DraftStore -} - -func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness { - - db := NewInMemSQLDB() - t.Cleanup(func() { db.Close() }) - - store := NewDraftStore(db) - - return draftStoreTestHarness{ - store: store, - } -} - -func TestDraftStore(t *testing.T) { - - assertPostEqual := func(t *testing.T, exp, got Post) { - t.Helper() - sort.Strings(exp.Tags) - sort.Strings(got.Tags) - assert.Equal(t, exp, got) - } - - assertPostsEqual := func(t *testing.T, exp, got []Post) { - t.Helper() - - if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { - return - } - - for i := range exp { - assertPostEqual(t, exp[i], got[i]) - } - } - - t.Run("not_found", func(t *testing.T) { - h := newDraftStoreTestHarness(t) - - _, err := h.store.GetByID("foo") - assert.ErrorIs(t, err, ErrPostNotFound) - }) - - t.Run("set_get_delete", func(t *testing.T) { - h := newDraftStoreTestHarness(t) - - post := testPost(0) - post.Tags = []string{"foo", "bar"} - - err := h.store.Set(post) - assert.NoError(t, err) - - gotPost, err := h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, post, gotPost) - - // we will now try updating the post, and ensure it updates properly - - post.Title = "something else" - post.Series = "whatever" - post.Body = "anything" - post.Tags = []string{"bar", "baz"} - - err = h.store.Set(post) - assert.NoError(t, err) - - gotPost, err = h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, post, gotPost) - - // delete the post, it should go away - assert.NoError(t, h.store.Delete(post.ID)) - - _, err = h.store.GetByID(post.ID) - assert.ErrorIs(t, err, ErrPostNotFound) - }) - - t.Run("get", func(t *testing.T) { - h := newDraftStoreTestHarness(t) - - posts := []Post{ - testPost(0), - testPost(1), - testPost(2), - testPost(3), - } - - for _, post := range posts { - err := h.store.Set(post) - assert.NoError(t, err) - } - - gotPosts, hasMore, err := h.store.Get(0, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[:2], gotPosts) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - posts = append(posts, testPost(4)) - err = h.store.Set(posts[4]) - assert.NoError(t, err) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - gotPosts, hasMore, err = h.store.Get(2, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[4:], gotPosts) - }) - -} diff --git a/srv/src/post/post.go b/srv/src/post/post.go deleted file mode 100644 index 03bce6c..0000000 --- a/srv/src/post/post.go +++ /dev/null @@ -1,368 +0,0 @@ -// Package post deals with the storage and rendering of blog posts. -package post - -import ( - "database/sql" - "errors" - "fmt" - "regexp" - "strings" - "time" -) - -var ( - // ErrPostNotFound is used to indicate a Post could not be found in the - // Store. - ErrPostNotFound = errors.New("post not found") -) - -var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`) - -// NewID generates a (hopefully) unique ID based on the given title. -func NewID(title string) string { - title = strings.ToLower(title) - title = titleCleanRegexp.ReplaceAllString(title, "") - title = strings.ReplaceAll(title, " ", "-") - return title -} - -// Post contains all information having to do with a blog post. -type Post struct { - ID string - Title string - Description string - Tags []string // only alphanumeric supported - Series string - Body string -} - -// StoredPost is a Post which has been stored in a Store, and has been given -// some extra fields as a result. -type StoredPost struct { - Post - - PublishedAt time.Time - LastUpdatedAt time.Time -} - -// Store is used for storing posts to a persistent storage. -type Store interface { - - // Set sets the Post data into the storage, keyed by the Post's ID. If there - // was not a previously existing Post with the same ID then Set returns - // true. It overwrites the previous Post with the same ID otherwise. - Set(post Post, now time.Time) (bool, error) - - // Get returns count StoredPosts, sorted time descending, offset by the - // given page number. The returned boolean indicates if there are more pages - // or not. - Get(page, count int) ([]StoredPost, bool, error) - - // GetByID will return the StoredPost with the given ID, or ErrPostNotFound. - GetByID(id string) (StoredPost, error) - - // GetBySeries returns all StoredPosts with the given series, sorted time - // descending, or empty slice. - GetBySeries(series string) ([]StoredPost, error) - - // GetByTag returns all StoredPosts with the given tag, sorted time - // descending, or empty slice. - GetByTag(tag string) ([]StoredPost, error) - - // GetTags returns all tags which have at least one Post using them. - GetTags() ([]string, error) - - // Delete will delete the StoredPost with the given ID. - Delete(id string) error -} - -type store struct { - db *SQLDB -} - -// NewStore initializes a new Store using an existing SQLDB. -func NewStore(db *SQLDB) Store { - return &store{ - db: db, - } -} - -func (s *store) Set(post Post, now time.Time) (bool, error) { - - if post.ID == "" { - return false, errors.New("post ID can't be empty") - } - - var first bool - - err := s.db.withTx(func(tx *sql.Tx) error { - - nowTS := now.Unix() - - nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()} - - _, err := tx.Exec( - `INSERT INTO posts ( - id, title, description, series, published_at, body - ) - VALUES - (?, ?, ?, ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - title=excluded.title, - description=excluded.description, - series=excluded.series, - last_updated_at=?, - body=excluded.body`, - post.ID, - post.Title, - post.Description, - &sql.NullString{String: post.Series, Valid: post.Series != ""}, - nowSQL, - post.Body, - nowSQL, - ) - - if err != nil { - return fmt.Errorf("inserting into posts: %w", err) - } - - // this is a bit of a hack, but it allows us to update the tagset without - // doing a diff. - _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID) - - if err != nil { - return fmt.Errorf("clearning post tags: %w", err) - } - - for _, tag := range post.Tags { - - _, err = tx.Exec( - `INSERT INTO post_tags (post_id, tag) VALUES (?, ?) - ON CONFLICT DO NOTHING`, - post.ID, - tag, - ) - - if err != nil { - return fmt.Errorf("inserting tag %q: %w", tag, err) - } - } - - err = tx.QueryRow( - `SELECT 1 FROM posts WHERE id=? AND last_updated_at IS NULL`, - post.ID, - ).Scan(new(int)) - - first = !errors.Is(err, sql.ErrNoRows) - - return nil - }) - - return first, err -} - -func (s *store) get( - querier interface { - Query(string, ...interface{}) (*sql.Rows, error) - }, - limit, offset int, - where string, whereArgs ...interface{}, -) ( - []StoredPost, error, -) { - - query := ` - SELECT - p.id, p.title, p.description, p.series, GROUP_CONCAT(pt.tag), - p.published_at, p.last_updated_at, - p.body - FROM posts p - LEFT JOIN post_tags pt ON (p.id = pt.post_id) - ` + where + ` - GROUP BY (p.id) - ORDER BY p.published_at DESC, p.title DESC` - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - - rows, err := querier.Query(query, whereArgs...) - - if err != nil { - return nil, fmt.Errorf("selecting: %w", err) - } - - var posts []StoredPost - - for rows.Next() { - - var ( - post StoredPost - series, tag sql.NullString - publishedAt, lastUpdatedAt sql.NullInt64 - ) - - err := rows.Scan( - &post.ID, &post.Title, &post.Description, &series, &tag, - &publishedAt, &lastUpdatedAt, - &post.Body, - ) - - if err != nil { - return nil, fmt.Errorf("scanning row: %w", err) - } - - post.Series = series.String - - if tag.String != "" { - post.Tags = strings.Split(tag.String, ",") - } - - if publishedAt.Valid { - post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC() - } - - if lastUpdatedAt.Valid { - post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC() - } - - posts = append(posts, post) - } - - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("closing row iterator: %w", err) - } - - return posts, nil -} - -func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - - posts, err := s.get(s.db.db, count+1, page*count, ``) - - if err != nil { - return nil, false, fmt.Errorf("querying posts: %w", err) - } - - var hasMore bool - - if len(posts) > count { - hasMore = true - posts = posts[:count] - } - - return posts, hasMore, nil -} - -func (s *store) GetByID(id string) (StoredPost, error) { - - posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) - - if err != nil { - return StoredPost{}, fmt.Errorf("querying posts: %w", err) - } - - if len(posts) == 0 { - return StoredPost{}, ErrPostNotFound - } - - if len(posts) > 1 { - panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts)) - } - - return posts[0], nil -} - -func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) -} - -func (s *store) GetByTag(tag string) ([]StoredPost, error) { - - var posts []StoredPost - - err := s.db.withTx(func(tx *sql.Tx) error { - - rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) - - if err != nil { - return fmt.Errorf("querying post_tags by tag: %w", err) - } - - var ( - placeholders []string - whereArgs []interface{} - ) - - for rows.Next() { - - var id string - - if err := rows.Scan(&id); err != nil { - rows.Close() - return fmt.Errorf("scanning id: %w", err) - } - - whereArgs = append(whereArgs, id) - placeholders = append(placeholders, "?") - } - - if err := rows.Close(); err != nil { - return fmt.Errorf("closing row iterator: %w", err) - } - - where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ",")) - - if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil { - return fmt.Errorf("querying for ids %+v: %w", whereArgs, err) - } - - return nil - }) - - return posts, err -} - -func (s *store) GetTags() ([]string, error) { - - rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) - if err != nil { - return nil, fmt.Errorf("querying all tags: %w", err) - } - defer rows.Close() - - var tags []string - - for rows.Next() { - - var tag string - - if err := rows.Scan(&tag); err != nil { - return nil, fmt.Errorf("scanning tag: %w", err) - } - - tags = append(tags, tag) - } - - return tags, nil -} - -func (s *store) Delete(id string) error { - - return s.db.withTx(func(tx *sql.Tx) error { - - if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_tags: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from posts: %w", err) - } - - return nil - }) -} diff --git a/srv/src/post/post_test.go b/srv/src/post/post_test.go deleted file mode 100644 index c7f9cdc..0000000 --- a/srv/src/post/post_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package post - -import ( - "sort" - "strconv" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/tilinna/clock" -) - -func TestNewID(t *testing.T) { - - tests := [][2]string{ - { - "Why Do We Have WiFi Passwords?", - "why-do-we-have-wifi-passwords", - }, - { - "Ginger: A Small VM Update", - "ginger-a-small-vm-update", - }, - { - "Something-Weird.... woah!", - "somethingweird-woah", - }, - } - - for i, test := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - assert.Equal(t, test[1], NewID(test[0])) - }) - } -} - -func testPost(i int) Post { - istr := strconv.Itoa(i) - return Post{ - ID: istr, - Title: istr, - Description: istr, - Body: istr, - } -} - -type storeTestHarness struct { - clock *clock.Mock - store Store -} - -func newStoreTestHarness(t *testing.T) storeTestHarness { - - clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour)) - - db := NewInMemSQLDB() - t.Cleanup(func() { db.Close() }) - - store := NewStore(db) - - return storeTestHarness{ - clock: clock, - store: store, - } -} - -func (h *storeTestHarness) testStoredPost(i int) StoredPost { - post := testPost(i) - return StoredPost{ - Post: post, - PublishedAt: h.clock.Now(), - } -} - -func TestStore(t *testing.T) { - - assertPostEqual := func(t *testing.T, exp, got StoredPost) { - t.Helper() - sort.Strings(exp.Tags) - sort.Strings(got.Tags) - assert.Equal(t, exp, got) - } - - assertPostsEqual := func(t *testing.T, exp, got []StoredPost) { - t.Helper() - - if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { - return - } - - for i := range exp { - assertPostEqual(t, exp[i], got[i]) - } - } - - t.Run("not_found", func(t *testing.T) { - h := newStoreTestHarness(t) - - _, err := h.store.GetByID("foo") - assert.ErrorIs(t, err, ErrPostNotFound) - }) - - t.Run("set_get_delete", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - post := testPost(0) - post.Tags = []string{"foo", "bar"} - - first, err := h.store.Set(post, now) - assert.NoError(t, err) - assert.True(t, first) - - gotPost, err := h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, StoredPost{ - Post: post, - PublishedAt: now, - }, gotPost) - - // we will now try updating the post on a different day, and ensure it - // updates properly - - h.clock.Add(24 * time.Hour) - newNow := h.clock.Now().UTC() - - post.Title = "something else" - post.Series = "whatever" - post.Body = "anything" - post.Tags = []string{"bar", "baz"} - - first, err = h.store.Set(post, newNow) - assert.NoError(t, err) - assert.False(t, first) - - gotPost, err = h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, StoredPost{ - Post: post, - PublishedAt: now, - LastUpdatedAt: newNow, - }, gotPost) - - // delete the post, it should go away - assert.NoError(t, h.store.Delete(post.ID)) - - _, err = h.store.GetByID(post.ID) - assert.ErrorIs(t, err, ErrPostNotFound) - }) - - t.Run("get", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(3), - h.testStoredPost(2), - h.testStoredPost(1), - h.testStoredPost(0), - } - - for _, post := range posts { - _, err := h.store.Set(post.Post, now) - assert.NoError(t, err) - } - - gotPosts, hasMore, err := h.store.Get(0, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[:2], gotPosts) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - posts = append([]StoredPost{h.testStoredPost(4)}, posts...) - _, err = h.store.Set(posts[0].Post, now) - assert.NoError(t, err) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - gotPosts, hasMore, err = h.store.Get(2, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[4:], gotPosts) - }) - - t.Run("get_by_series", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(3), - h.testStoredPost(2), - h.testStoredPost(1), - h.testStoredPost(0), - } - - posts[0].Series = "foo" - posts[1].Series = "bar" - posts[2].Series = "bar" - - for _, post := range posts { - _, err := h.store.Set(post.Post, now) - assert.NoError(t, err) - } - - fooPosts, err := h.store.GetBySeries("foo") - assert.NoError(t, err) - assertPostsEqual(t, posts[:1], fooPosts) - - barPosts, err := h.store.GetBySeries("bar") - assert.NoError(t, err) - assertPostsEqual(t, posts[1:3], barPosts) - - bazPosts, err := h.store.GetBySeries("baz") - assert.NoError(t, err) - assert.Empty(t, bazPosts) - }) - - t.Run("get_by_tag", func(t *testing.T) { - - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(3), - h.testStoredPost(2), - h.testStoredPost(1), - h.testStoredPost(0), - } - - posts[0].Tags = []string{"foo"} - posts[1].Tags = []string{"foo", "bar"} - posts[2].Tags = []string{"bar"} - - for _, post := range posts { - _, err := h.store.Set(post.Post, now) - assert.NoError(t, err) - } - - fooPosts, err := h.store.GetByTag("foo") - assert.NoError(t, err) - assertPostsEqual(t, posts[:2], fooPosts) - - barPosts, err := h.store.GetByTag("bar") - assert.NoError(t, err) - assertPostsEqual(t, posts[1:3], barPosts) - - bazPosts, err := h.store.GetByTag("baz") - assert.NoError(t, err) - assert.Empty(t, bazPosts) - - tags, err := h.store.GetTags() - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"foo", "bar"}, tags) - }) -} diff --git a/srv/src/post/sql.go b/srv/src/post/sql.go deleted file mode 100644 index c768c9a..0000000 --- a/srv/src/post/sql.go +++ /dev/null @@ -1,126 +0,0 @@ -package post - -import ( - "database/sql" - "fmt" - "path" - - "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" - migrate "github.com/rubenv/sql-migrate" - - _ "github.com/mattn/go-sqlite3" // we need dis -) - -var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{ - { - Id: "1", - Up: []string{ - `CREATE TABLE posts ( - id TEXT NOT NULL PRIMARY KEY, - title TEXT NOT NULL, - description TEXT NOT NULL, - series TEXT, - - published_at INTEGER NOT NULL, - last_updated_at INTEGER, - - body TEXT NOT NULL - )`, - - `CREATE TABLE post_tags ( - post_id TEXT NOT NULL, - tag TEXT NOT NULL, - UNIQUE(post_id, tag) - )`, - - `CREATE TABLE assets ( - id TEXT NOT NULL PRIMARY KEY, - body BLOB NOT NULL - )`, - }, - }, - { - Id: "2", - Up: []string{ - `CREATE TABLE post_drafts ( - id TEXT NOT NULL PRIMARY KEY, - title TEXT NOT NULL, - description TEXT NOT NULL, - tags TEXT, - series TEXT, - body TEXT NOT NULL - )`, - }, - }, -}} - -// SQLDB is a sqlite3 database which can be used by storage interfaces within -// this package. -type SQLDB struct { - db *sql.DB -} - -// NewSQLDB initializes and returns a new sqlite3 database for storage -// intefaces. The db will be created within the given data directory. -func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) { - - path := path.Join(dataDir.Path, "post.sqlite3") - - db, err := sql.Open("sqlite3", path) - if err != nil { - return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err) - } - - if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { - return nil, fmt.Errorf("running migrations: %w", err) - } - - return &SQLDB{db}, nil -} - -// NewSQLDB is like NewSQLDB, but the database will be initialized in memory. -func NewInMemSQLDB() *SQLDB { - - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - panic(fmt.Errorf("opening sqlite in memory: %w", err)) - } - - if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { - panic(fmt.Errorf("running migrations: %w", err)) - } - - return &SQLDB{db} -} - -// Close cleans up loose resources being held by the db. -func (db *SQLDB) Close() error { - return db.db.Close() -} - -func (db *SQLDB) withTx(cb func(*sql.Tx) error) error { - - tx, err := db.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if err := cb(tx); err != nil { - - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf( - "rolling back transaction: %w (original error: %v)", - rollbackErr, err, - ) - } - - return fmt.Errorf("performing transaction: %w (rolled back)", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} |