From 4f01edb9230f58ff84b0dd892c931ec8ac9aad55 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Tue, 13 Sep 2022 12:56:08 +0200 Subject: move src out of srv, clean up default.nix and Makefile --- src/post/asset.go | 114 ++++++++++++++ src/post/asset_test.go | 91 +++++++++++ src/post/draft_post.go | 187 ++++++++++++++++++++++ src/post/draft_post_test.go | 130 ++++++++++++++++ src/post/post.go | 368 ++++++++++++++++++++++++++++++++++++++++++++ src/post/post_test.go | 268 ++++++++++++++++++++++++++++++++ src/post/sql.go | 126 +++++++++++++++ 7 files changed, 1284 insertions(+) create mode 100644 src/post/asset.go create mode 100644 src/post/asset_test.go create mode 100644 src/post/draft_post.go create mode 100644 src/post/draft_post_test.go create mode 100644 src/post/post.go create mode 100644 src/post/post_test.go create mode 100644 src/post/sql.go (limited to 'src/post') diff --git a/src/post/asset.go b/src/post/asset.go new file mode 100644 index 0000000..a7b605b --- /dev/null +++ b/src/post/asset.go @@ -0,0 +1,114 @@ +package post + +import ( + "bytes" + "database/sql" + "errors" + "fmt" + "io" +) + +var ( + // ErrAssetNotFound is used to indicate an Asset could not be found in the + // AssetStore. + ErrAssetNotFound = errors.New("asset not found") +) + +// AssetStore implements the storage and retrieval of binary assets, which are +// intended to be used by posts (e.g. images). +type AssetStore interface { + + // Set sets the id to the contents of the given io.Reader. + Set(id string, from io.Reader) error + + // Get writes the id's body to the given io.Writer, or returns + // ErrAssetNotFound. + Get(id string, into io.Writer) error + + // Delete's the body stored for the id, if any. + Delete(id string) error + + // List returns all ids which are currently stored. + List() ([]string, error) +} + +type assetStore struct { + db *sql.DB +} + +// NewAssetStore initializes a new AssetStore using an existing SQLDB. +func NewAssetStore(db *SQLDB) AssetStore { + return &assetStore{ + db: db.db, + } +} + +func (s *assetStore) Set(id string, from io.Reader) error { + + body, err := io.ReadAll(from) + if err != nil { + return fmt.Errorf("reading body fully into memory: %w", err) + } + + _, err = s.db.Exec( + `INSERT INTO assets (id, body) + VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET body=excluded.body`, + id, body, + ) + + if err != nil { + return fmt.Errorf("inserting into assets: %w", err) + } + + return nil +} + +func (s *assetStore) Get(id string, into io.Writer) error { + + var body []byte + + err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body) + + if errors.Is(err, sql.ErrNoRows) { + return ErrAssetNotFound + } else if err != nil { + return fmt.Errorf("selecting from assets: %w", err) + } + + if _, err := io.Copy(into, bytes.NewReader(body)); err != nil { + return fmt.Errorf("writing body to io.Writer: %w", err) + } + + return nil +} + +func (s *assetStore) Delete(id string) error { + _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id) + return err +} + +func (s *assetStore) List() ([]string, error) { + + rows, err := s.db.Query(`SELECT id FROM assets ORDER BY id ASC`) + + if err != nil { + return nil, fmt.Errorf("querying: %w", err) + } + + defer rows.Close() + + var ids []string + + for rows.Next() { + + var id string + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + ids = append(ids, id) + } + + return ids, nil +} diff --git a/src/post/asset_test.go b/src/post/asset_test.go new file mode 100644 index 0000000..4d62d46 --- /dev/null +++ b/src/post/asset_test.go @@ -0,0 +1,91 @@ +package post + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +type assetTestHarness struct { + store AssetStore +} + +func newAssetTestHarness(t *testing.T) *assetTestHarness { + + db := NewInMemSQLDB() + t.Cleanup(func() { db.Close() }) + + store := NewAssetStore(db) + + return &assetTestHarness{ + store: store, + } +} + +func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) { + t.Helper() + buf := new(bytes.Buffer) + err := h.store.Get(id, buf) + assert.NoError(t, err) + assert.Equal(t, exp, buf.String()) +} + +func (h *assetTestHarness) assertNotFound(t *testing.T, id string) { + t.Helper() + err := h.store.Get(id, io.Discard) + assert.ErrorIs(t, ErrAssetNotFound, err) +} + +func TestAssetStore(t *testing.T) { + + testAssetStore := func(t *testing.T, h *assetTestHarness) { + t.Helper() + + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") + + err := h.store.Set("foo", bytes.NewBufferString("FOO")) + assert.NoError(t, err) + + h.assertGet(t, "FOO", "foo") + h.assertNotFound(t, "bar") + + err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) + assert.NoError(t, err) + + h.assertGet(t, "FOOFOO", "foo") + h.assertNotFound(t, "bar") + + assert.NoError(t, h.store.Delete("foo")) + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") + + assert.NoError(t, h.store.Delete("bar")) + h.assertNotFound(t, "foo") + h.assertNotFound(t, "bar") + + // test list + + ids, err := h.store.List() + assert.NoError(t, err) + assert.Empty(t, ids) + + err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) + assert.NoError(t, err) + err = h.store.Set("foo", bytes.NewBufferString("FOOFOO")) + assert.NoError(t, err) + err = h.store.Set("bar", bytes.NewBufferString("FOOFOO")) + assert.NoError(t, err) + + ids, err = h.store.List() + assert.NoError(t, err) + assert.Equal(t, []string{"bar", "foo"}, ids) + } + + t.Run("sql", func(t *testing.T) { + h := newAssetTestHarness(t) + testAssetStore(t, h) + }) +} diff --git a/src/post/draft_post.go b/src/post/draft_post.go new file mode 100644 index 0000000..61283c3 --- /dev/null +++ b/src/post/draft_post.go @@ -0,0 +1,187 @@ +package post + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" +) + +type DraftStore interface { + + // Set sets the draft Post's data into the storage, keyed by the draft + // Post's ID. + Set(post Post) error + + // Get returns count draft Posts, sorted id descending, offset by the + // given page number. The returned boolean indicates if there are more pages + // or not. + Get(page, count int) ([]Post, bool, error) + + // GetByID will return the draft Post with the given ID, or ErrPostNotFound. + GetByID(id string) (Post, error) + + // Delete will delete the draft Post with the given ID. + Delete(id string) error +} + +type draftStore struct { + db *SQLDB +} + +// NewDraftStore initializes a new DraftStore using an existing SQLDB. +func NewDraftStore(db *SQLDB) DraftStore { + return &draftStore{ + db: db, + } +} + +func (s *draftStore) Set(post Post) error { + + if post.ID == "" { + return errors.New("post ID can't be empty") + } + + tagsJSON, err := json.Marshal(post.Tags) + if err != nil { + return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err) + } + + _, err = s.db.db.Exec( + `INSERT INTO post_drafts ( + id, title, description, tags, series, body + ) + VALUES + (?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + tags=excluded.tags, + series=excluded.series, + body=excluded.body`, + post.ID, + post.Title, + post.Description, + &sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0}, + &sql.NullString{String: post.Series, Valid: post.Series != ""}, + post.Body, + ) + + if err != nil { + return fmt.Errorf("inserting into post_drafts: %w", err) + } + + return nil +} + +func (s *draftStore) get( + querier interface { + Query(string, ...interface{}) (*sql.Rows, error) + }, + limit, offset int, + where string, whereArgs ...interface{}, +) ( + []Post, error, +) { + + query := ` + SELECT + p.id, p.title, p.description, p.tags, p.series, p.body + FROM post_drafts p + ` + where + ` + ORDER BY p.id ASC` + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + + rows, err := querier.Query(query, whereArgs...) + + if err != nil { + return nil, fmt.Errorf("selecting: %w", err) + } + + defer rows.Close() + + var posts []Post + + for rows.Next() { + + var ( + post Post + tags, series sql.NullString + ) + + err := rows.Scan( + &post.ID, &post.Title, &post.Description, &tags, &series, + &post.Body, + ) + + if err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + post.Series = series.String + + if tags.String != "" { + + if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil { + return nil, fmt.Errorf("json parsing %q: %w", tags.String, err) + } + } + + posts = append(posts, post) + } + + return posts, nil +} + +func (s *draftStore) Get(page, count int) ([]Post, bool, error) { + + posts, err := s.get(s.db.db, count+1, page*count, ``) + + if err != nil { + return nil, false, fmt.Errorf("querying post_drafts: %w", err) + } + + var hasMore bool + + if len(posts) > count { + hasMore = true + posts = posts[:count] + } + + return posts, hasMore, nil +} + +func (s *draftStore) GetByID(id string) (Post, error) { + + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) + + if err != nil { + return Post{}, fmt.Errorf("querying post_drafts: %w", err) + } + + if len(posts) == 0 { + return Post{}, ErrPostNotFound + } + + if len(posts) > 1 { + panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts)) + } + + return posts[0], nil +} + +func (s *draftStore) Delete(id string) error { + + if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_drafts: %w", err) + } + + return nil +} diff --git a/src/post/draft_post_test.go b/src/post/draft_post_test.go new file mode 100644 index 0000000..f404bb0 --- /dev/null +++ b/src/post/draft_post_test.go @@ -0,0 +1,130 @@ +package post + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +type draftStoreTestHarness struct { + store DraftStore +} + +func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness { + + db := NewInMemSQLDB() + t.Cleanup(func() { db.Close() }) + + store := NewDraftStore(db) + + return draftStoreTestHarness{ + store: store, + } +} + +func TestDraftStore(t *testing.T) { + + assertPostEqual := func(t *testing.T, exp, got Post) { + t.Helper() + sort.Strings(exp.Tags) + sort.Strings(got.Tags) + assert.Equal(t, exp, got) + } + + assertPostsEqual := func(t *testing.T, exp, got []Post) { + t.Helper() + + if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { + return + } + + for i := range exp { + assertPostEqual(t, exp[i], got[i]) + } + } + + t.Run("not_found", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + _, err := h.store.GetByID("foo") + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("set_get_delete", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + post := testPost(0) + post.Tags = []string{"foo", "bar"} + + err := h.store.Set(post) + assert.NoError(t, err) + + gotPost, err := h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, post, gotPost) + + // we will now try updating the post, and ensure it updates properly + + post.Title = "something else" + post.Series = "whatever" + post.Body = "anything" + post.Tags = []string{"bar", "baz"} + + err = h.store.Set(post) + assert.NoError(t, err) + + gotPost, err = h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, post, gotPost) + + // delete the post, it should go away + assert.NoError(t, h.store.Delete(post.ID)) + + _, err = h.store.GetByID(post.ID) + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("get", func(t *testing.T) { + h := newDraftStoreTestHarness(t) + + posts := []Post{ + testPost(0), + testPost(1), + testPost(2), + testPost(3), + } + + for _, post := range posts { + err := h.store.Set(post) + assert.NoError(t, err) + } + + gotPosts, hasMore, err := h.store.Get(0, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[:2], gotPosts) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + posts = append(posts, testPost(4)) + err = h.store.Set(posts[4]) + assert.NoError(t, err) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + gotPosts, hasMore, err = h.store.Get(2, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[4:], gotPosts) + }) + +} diff --git a/src/post/post.go b/src/post/post.go new file mode 100644 index 0000000..03bce6c --- /dev/null +++ b/src/post/post.go @@ -0,0 +1,368 @@ +// Package post deals with the storage and rendering of blog posts. +package post + +import ( + "database/sql" + "errors" + "fmt" + "regexp" + "strings" + "time" +) + +var ( + // ErrPostNotFound is used to indicate a Post could not be found in the + // Store. + ErrPostNotFound = errors.New("post not found") +) + +var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`) + +// NewID generates a (hopefully) unique ID based on the given title. +func NewID(title string) string { + title = strings.ToLower(title) + title = titleCleanRegexp.ReplaceAllString(title, "") + title = strings.ReplaceAll(title, " ", "-") + return title +} + +// Post contains all information having to do with a blog post. +type Post struct { + ID string + Title string + Description string + Tags []string // only alphanumeric supported + Series string + Body string +} + +// StoredPost is a Post which has been stored in a Store, and has been given +// some extra fields as a result. +type StoredPost struct { + Post + + PublishedAt time.Time + LastUpdatedAt time.Time +} + +// Store is used for storing posts to a persistent storage. +type Store interface { + + // Set sets the Post data into the storage, keyed by the Post's ID. If there + // was not a previously existing Post with the same ID then Set returns + // true. It overwrites the previous Post with the same ID otherwise. + Set(post Post, now time.Time) (bool, error) + + // Get returns count StoredPosts, sorted time descending, offset by the + // given page number. The returned boolean indicates if there are more pages + // or not. + Get(page, count int) ([]StoredPost, bool, error) + + // GetByID will return the StoredPost with the given ID, or ErrPostNotFound. + GetByID(id string) (StoredPost, error) + + // GetBySeries returns all StoredPosts with the given series, sorted time + // descending, or empty slice. + GetBySeries(series string) ([]StoredPost, error) + + // GetByTag returns all StoredPosts with the given tag, sorted time + // descending, or empty slice. + GetByTag(tag string) ([]StoredPost, error) + + // GetTags returns all tags which have at least one Post using them. + GetTags() ([]string, error) + + // Delete will delete the StoredPost with the given ID. + Delete(id string) error +} + +type store struct { + db *SQLDB +} + +// NewStore initializes a new Store using an existing SQLDB. +func NewStore(db *SQLDB) Store { + return &store{ + db: db, + } +} + +func (s *store) Set(post Post, now time.Time) (bool, error) { + + if post.ID == "" { + return false, errors.New("post ID can't be empty") + } + + var first bool + + err := s.db.withTx(func(tx *sql.Tx) error { + + nowTS := now.Unix() + + nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()} + + _, err := tx.Exec( + `INSERT INTO posts ( + id, title, description, series, published_at, body + ) + VALUES + (?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + series=excluded.series, + last_updated_at=?, + body=excluded.body`, + post.ID, + post.Title, + post.Description, + &sql.NullString{String: post.Series, Valid: post.Series != ""}, + nowSQL, + post.Body, + nowSQL, + ) + + if err != nil { + return fmt.Errorf("inserting into posts: %w", err) + } + + // this is a bit of a hack, but it allows us to update the tagset without + // doing a diff. + _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID) + + if err != nil { + return fmt.Errorf("clearning post tags: %w", err) + } + + for _, tag := range post.Tags { + + _, err = tx.Exec( + `INSERT INTO post_tags (post_id, tag) VALUES (?, ?) + ON CONFLICT DO NOTHING`, + post.ID, + tag, + ) + + if err != nil { + return fmt.Errorf("inserting tag %q: %w", tag, err) + } + } + + err = tx.QueryRow( + `SELECT 1 FROM posts WHERE id=? AND last_updated_at IS NULL`, + post.ID, + ).Scan(new(int)) + + first = !errors.Is(err, sql.ErrNoRows) + + return nil + }) + + return first, err +} + +func (s *store) get( + querier interface { + Query(string, ...interface{}) (*sql.Rows, error) + }, + limit, offset int, + where string, whereArgs ...interface{}, +) ( + []StoredPost, error, +) { + + query := ` + SELECT + p.id, p.title, p.description, p.series, GROUP_CONCAT(pt.tag), + p.published_at, p.last_updated_at, + p.body + FROM posts p + LEFT JOIN post_tags pt ON (p.id = pt.post_id) + ` + where + ` + GROUP BY (p.id) + ORDER BY p.published_at DESC, p.title DESC` + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + + rows, err := querier.Query(query, whereArgs...) + + if err != nil { + return nil, fmt.Errorf("selecting: %w", err) + } + + var posts []StoredPost + + for rows.Next() { + + var ( + post StoredPost + series, tag sql.NullString + publishedAt, lastUpdatedAt sql.NullInt64 + ) + + err := rows.Scan( + &post.ID, &post.Title, &post.Description, &series, &tag, + &publishedAt, &lastUpdatedAt, + &post.Body, + ) + + if err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + post.Series = series.String + + if tag.String != "" { + post.Tags = strings.Split(tag.String, ",") + } + + if publishedAt.Valid { + post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC() + } + + if lastUpdatedAt.Valid { + post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC() + } + + posts = append(posts, post) + } + + if err := rows.Close(); err != nil { + return nil, fmt.Errorf("closing row iterator: %w", err) + } + + return posts, nil +} + +func (s *store) Get(page, count int) ([]StoredPost, bool, error) { + + posts, err := s.get(s.db.db, count+1, page*count, ``) + + if err != nil { + return nil, false, fmt.Errorf("querying posts: %w", err) + } + + var hasMore bool + + if len(posts) > count { + hasMore = true + posts = posts[:count] + } + + return posts, hasMore, nil +} + +func (s *store) GetByID(id string) (StoredPost, error) { + + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) + + if err != nil { + return StoredPost{}, fmt.Errorf("querying posts: %w", err) + } + + if len(posts) == 0 { + return StoredPost{}, ErrPostNotFound + } + + if len(posts) > 1 { + panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts)) + } + + return posts[0], nil +} + +func (s *store) GetBySeries(series string) ([]StoredPost, error) { + return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) +} + +func (s *store) GetByTag(tag string) ([]StoredPost, error) { + + var posts []StoredPost + + err := s.db.withTx(func(tx *sql.Tx) error { + + rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) + + if err != nil { + return fmt.Errorf("querying post_tags by tag: %w", err) + } + + var ( + placeholders []string + whereArgs []interface{} + ) + + for rows.Next() { + + var id string + + if err := rows.Scan(&id); err != nil { + rows.Close() + return fmt.Errorf("scanning id: %w", err) + } + + whereArgs = append(whereArgs, id) + placeholders = append(placeholders, "?") + } + + if err := rows.Close(); err != nil { + return fmt.Errorf("closing row iterator: %w", err) + } + + where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ",")) + + if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil { + return fmt.Errorf("querying for ids %+v: %w", whereArgs, err) + } + + return nil + }) + + return posts, err +} + +func (s *store) GetTags() ([]string, error) { + + rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) + if err != nil { + return nil, fmt.Errorf("querying all tags: %w", err) + } + defer rows.Close() + + var tags []string + + for rows.Next() { + + var tag string + + if err := rows.Scan(&tag); err != nil { + return nil, fmt.Errorf("scanning tag: %w", err) + } + + tags = append(tags, tag) + } + + return tags, nil +} + +func (s *store) Delete(id string) error { + + return s.db.withTx(func(tx *sql.Tx) error { + + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } + + return nil + }) +} diff --git a/src/post/post_test.go b/src/post/post_test.go new file mode 100644 index 0000000..c7f9cdc --- /dev/null +++ b/src/post/post_test.go @@ -0,0 +1,268 @@ +package post + +import ( + "sort" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/tilinna/clock" +) + +func TestNewID(t *testing.T) { + + tests := [][2]string{ + { + "Why Do We Have WiFi Passwords?", + "why-do-we-have-wifi-passwords", + }, + { + "Ginger: A Small VM Update", + "ginger-a-small-vm-update", + }, + { + "Something-Weird.... woah!", + "somethingweird-woah", + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + assert.Equal(t, test[1], NewID(test[0])) + }) + } +} + +func testPost(i int) Post { + istr := strconv.Itoa(i) + return Post{ + ID: istr, + Title: istr, + Description: istr, + Body: istr, + } +} + +type storeTestHarness struct { + clock *clock.Mock + store Store +} + +func newStoreTestHarness(t *testing.T) storeTestHarness { + + clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour)) + + db := NewInMemSQLDB() + t.Cleanup(func() { db.Close() }) + + store := NewStore(db) + + return storeTestHarness{ + clock: clock, + store: store, + } +} + +func (h *storeTestHarness) testStoredPost(i int) StoredPost { + post := testPost(i) + return StoredPost{ + Post: post, + PublishedAt: h.clock.Now(), + } +} + +func TestStore(t *testing.T) { + + assertPostEqual := func(t *testing.T, exp, got StoredPost) { + t.Helper() + sort.Strings(exp.Tags) + sort.Strings(got.Tags) + assert.Equal(t, exp, got) + } + + assertPostsEqual := func(t *testing.T, exp, got []StoredPost) { + t.Helper() + + if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { + return + } + + for i := range exp { + assertPostEqual(t, exp[i], got[i]) + } + } + + t.Run("not_found", func(t *testing.T) { + h := newStoreTestHarness(t) + + _, err := h.store.GetByID("foo") + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("set_get_delete", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + post := testPost(0) + post.Tags = []string{"foo", "bar"} + + first, err := h.store.Set(post, now) + assert.NoError(t, err) + assert.True(t, first) + + gotPost, err := h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, StoredPost{ + Post: post, + PublishedAt: now, + }, gotPost) + + // we will now try updating the post on a different day, and ensure it + // updates properly + + h.clock.Add(24 * time.Hour) + newNow := h.clock.Now().UTC() + + post.Title = "something else" + post.Series = "whatever" + post.Body = "anything" + post.Tags = []string{"bar", "baz"} + + first, err = h.store.Set(post, newNow) + assert.NoError(t, err) + assert.False(t, first) + + gotPost, err = h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, StoredPost{ + Post: post, + PublishedAt: now, + LastUpdatedAt: newNow, + }, gotPost) + + // delete the post, it should go away + assert.NoError(t, h.store.Delete(post.ID)) + + _, err = h.store.GetByID(post.ID) + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("get", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(3), + h.testStoredPost(2), + h.testStoredPost(1), + h.testStoredPost(0), + } + + for _, post := range posts { + _, err := h.store.Set(post.Post, now) + assert.NoError(t, err) + } + + gotPosts, hasMore, err := h.store.Get(0, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[:2], gotPosts) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + posts = append([]StoredPost{h.testStoredPost(4)}, posts...) + _, err = h.store.Set(posts[0].Post, now) + assert.NoError(t, err) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + gotPosts, hasMore, err = h.store.Get(2, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[4:], gotPosts) + }) + + t.Run("get_by_series", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(3), + h.testStoredPost(2), + h.testStoredPost(1), + h.testStoredPost(0), + } + + posts[0].Series = "foo" + posts[1].Series = "bar" + posts[2].Series = "bar" + + for _, post := range posts { + _, err := h.store.Set(post.Post, now) + assert.NoError(t, err) + } + + fooPosts, err := h.store.GetBySeries("foo") + assert.NoError(t, err) + assertPostsEqual(t, posts[:1], fooPosts) + + barPosts, err := h.store.GetBySeries("bar") + assert.NoError(t, err) + assertPostsEqual(t, posts[1:3], barPosts) + + bazPosts, err := h.store.GetBySeries("baz") + assert.NoError(t, err) + assert.Empty(t, bazPosts) + }) + + t.Run("get_by_tag", func(t *testing.T) { + + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(3), + h.testStoredPost(2), + h.testStoredPost(1), + h.testStoredPost(0), + } + + posts[0].Tags = []string{"foo"} + posts[1].Tags = []string{"foo", "bar"} + posts[2].Tags = []string{"bar"} + + for _, post := range posts { + _, err := h.store.Set(post.Post, now) + assert.NoError(t, err) + } + + fooPosts, err := h.store.GetByTag("foo") + assert.NoError(t, err) + assertPostsEqual(t, posts[:2], fooPosts) + + barPosts, err := h.store.GetByTag("bar") + assert.NoError(t, err) + assertPostsEqual(t, posts[1:3], barPosts) + + bazPosts, err := h.store.GetByTag("baz") + assert.NoError(t, err) + assert.Empty(t, bazPosts) + + tags, err := h.store.GetTags() + assert.NoError(t, err) + assert.ElementsMatch(t, []string{"foo", "bar"}, tags) + }) +} diff --git a/src/post/sql.go b/src/post/sql.go new file mode 100644 index 0000000..c768c9a --- /dev/null +++ b/src/post/sql.go @@ -0,0 +1,126 @@ +package post + +import ( + "database/sql" + "fmt" + "path" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" + migrate "github.com/rubenv/sql-migrate" + + _ "github.com/mattn/go-sqlite3" // we need dis +) + +var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{ + { + Id: "1", + Up: []string{ + `CREATE TABLE posts ( + id TEXT NOT NULL PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + series TEXT, + + published_at INTEGER NOT NULL, + last_updated_at INTEGER, + + body TEXT NOT NULL + )`, + + `CREATE TABLE post_tags ( + post_id TEXT NOT NULL, + tag TEXT NOT NULL, + UNIQUE(post_id, tag) + )`, + + `CREATE TABLE assets ( + id TEXT NOT NULL PRIMARY KEY, + body BLOB NOT NULL + )`, + }, + }, + { + Id: "2", + Up: []string{ + `CREATE TABLE post_drafts ( + id TEXT NOT NULL PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + tags TEXT, + series TEXT, + body TEXT NOT NULL + )`, + }, + }, +}} + +// SQLDB is a sqlite3 database which can be used by storage interfaces within +// this package. +type SQLDB struct { + db *sql.DB +} + +// NewSQLDB initializes and returns a new sqlite3 database for storage +// intefaces. The db will be created within the given data directory. +func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) { + + path := path.Join(dataDir.Path, "post.sqlite3") + + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err) + } + + if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { + return nil, fmt.Errorf("running migrations: %w", err) + } + + return &SQLDB{db}, nil +} + +// NewSQLDB is like NewSQLDB, but the database will be initialized in memory. +func NewInMemSQLDB() *SQLDB { + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(fmt.Errorf("opening sqlite in memory: %w", err)) + } + + if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { + panic(fmt.Errorf("running migrations: %w", err)) + } + + return &SQLDB{db} +} + +// Close cleans up loose resources being held by the db. +func (db *SQLDB) Close() error { + return db.db.Close() +} + +func (db *SQLDB) withTx(cb func(*sql.Tx) error) error { + + tx, err := db.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if err := cb(tx); err != nil { + + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf( + "rolling back transaction: %w (original error: %v)", + rollbackErr, err, + ) + } + + return fmt.Errorf("performing transaction: %w (rolled back)", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} -- cgit v1.2.3