summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Picciano <mediocregopher@gmail.com>2022-08-18 22:25:17 -0600
committerBrian Picciano <mediocregopher@gmail.com>2022-08-18 22:25:17 -0600
commit7ac2f5ebb32a6098bc0d590130cc4b933db08771 (patch)
tree95d43d28bdd964fd461f8a2d3b72898ef5210d2b
parent76ff79f47035c11c1e0dfcd283034b738b5c3f0d (diff)
implement DraftStore
-rw-r--r--srv/src/post/draft_post.go186
-rw-r--r--srv/src/post/draft_post_test.go130
-rw-r--r--srv/src/post/post.go69
-rw-r--r--srv/src/post/sql.go43
4 files changed, 372 insertions, 56 deletions
diff --git a/srv/src/post/draft_post.go b/srv/src/post/draft_post.go
new file mode 100644
index 0000000..af52965
--- /dev/null
+++ b/srv/src/post/draft_post.go
@@ -0,0 +1,186 @@
+package post
+
+import (
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+)
+
+type DraftStore interface {
+
+ // Set sets the draft Post's data into the storage, keyed by the draft
+ // Post's ID.
+ Set(post Post) error
+
+ // Get returns count draft Posts, sorted id descending, offset by the
+ // given page number. The returned boolean indicates if there are more pages
+ // or not.
+ Get(page, count int) ([]Post, bool, error)
+
+ // GetByID will return the draft Post with the given ID, or ErrPostNotFound.
+ GetByID(id string) (Post, error)
+
+ // Delete will delete the draft Post with the given ID.
+ Delete(id string) error
+}
+
+type draftStore struct {
+ db *SQLDB
+}
+
+// NewDraftStore initializes a new DraftStore using an existing SQLDB.
+func NewDraftStore(db *SQLDB) DraftStore {
+ return &draftStore{
+ db: db,
+ }
+}
+
+func (s *draftStore) Set(post Post) error {
+
+ if post.ID == "" {
+ return errors.New("post ID can't be empty")
+ }
+
+ tagsJSON, err := json.Marshal(post.Tags)
+ if err != nil {
+ return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
+ }
+
+ _, err = s.db.db.Exec(
+ `INSERT INTO post_drafts (
+ id, title, description, tags, series, body
+ )
+ VALUES
+ (?, ?, ?, ?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET
+ title=excluded.title,
+ description=excluded.description,
+ tags=excluded.tags,
+ series=excluded.series,
+ body=excluded.body`,
+ post.ID,
+ post.Title,
+ post.Description,
+ &sql.NullString{String: string(tagsJSON), Valid: len(post.Tags) > 0},
+ &sql.NullString{String: post.Series, Valid: post.Series != ""},
+ post.Body,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into post_drafts: %w", err)
+ }
+
+ return nil
+}
+
+func (s *draftStore) get(
+ querier interface {
+ Query(string, ...interface{}) (*sql.Rows, error)
+ },
+ limit, offset int,
+ where string, whereArgs ...interface{},
+) (
+ []Post, error,
+) {
+
+ query := `
+ SELECT
+ p.id, p.title, p.description, p.tags, p.series, p.body
+ FROM post_drafts p
+ ORDER BY p.id ASC`
+
+ if limit > 0 {
+ query += fmt.Sprintf(" LIMIT %d", limit)
+ }
+
+ if offset > 0 {
+ query += fmt.Sprintf(" OFFSET %d", offset)
+ }
+
+ rows, err := querier.Query(query, whereArgs...)
+
+ if err != nil {
+ return nil, fmt.Errorf("selecting: %w", err)
+ }
+
+ defer rows.Close()
+
+ var posts []Post
+
+ for rows.Next() {
+
+ var (
+ post Post
+ tags, series sql.NullString
+ )
+
+ err := rows.Scan(
+ &post.ID, &post.Title, &post.Description, &tags, &series,
+ &post.Body,
+ )
+
+ if err != nil {
+ return nil, fmt.Errorf("scanning row: %w", err)
+ }
+
+ post.Series = series.String
+
+ if tags.String != "" {
+
+ if err := json.Unmarshal([]byte(tags.String), &post.Tags); err != nil {
+ return nil, fmt.Errorf("json parsing %q: %w", tags.String, err)
+ }
+ }
+
+ posts = append(posts, post)
+ }
+
+ return posts, nil
+}
+
+func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
+
+ posts, err := s.get(s.db.db, count+1, page*count, ``)
+
+ if err != nil {
+ return nil, false, fmt.Errorf("querying post_drafts: %w", err)
+ }
+
+ var hasMore bool
+
+ if len(posts) > count {
+ hasMore = true
+ posts = posts[:count]
+ }
+
+ return posts, hasMore, nil
+}
+
+func (s *draftStore) GetByID(id string) (Post, error) {
+
+ posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
+
+ if err != nil {
+ return Post{}, fmt.Errorf("querying post_drafts: %w", err)
+ }
+
+ if len(posts) == 0 {
+ return Post{}, ErrPostNotFound
+ }
+
+ if len(posts) > 1 {
+ panic(fmt.Sprintf("got back multiple draft posts querying id %q: %+v", id, posts))
+ }
+
+ return posts[0], nil
+}
+
+func (s *draftStore) Delete(id string) error {
+
+ if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_drafts: %w", err)
+ }
+
+ return nil
+}
diff --git a/srv/src/post/draft_post_test.go b/srv/src/post/draft_post_test.go
new file mode 100644
index 0000000..f404bb0
--- /dev/null
+++ b/srv/src/post/draft_post_test.go
@@ -0,0 +1,130 @@
+package post
+
+import (
+ "sort"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type draftStoreTestHarness struct {
+ store DraftStore
+}
+
+func newDraftStoreTestHarness(t *testing.T) draftStoreTestHarness {
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewDraftStore(db)
+
+ return draftStoreTestHarness{
+ store: store,
+ }
+}
+
+func TestDraftStore(t *testing.T) {
+
+ assertPostEqual := func(t *testing.T, exp, got Post) {
+ t.Helper()
+ sort.Strings(exp.Tags)
+ sort.Strings(got.Tags)
+ assert.Equal(t, exp, got)
+ }
+
+ assertPostsEqual := func(t *testing.T, exp, got []Post) {
+ t.Helper()
+
+ if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
+ return
+ }
+
+ for i := range exp {
+ assertPostEqual(t, exp[i], got[i])
+ }
+ }
+
+ t.Run("not_found", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ _, err := h.store.GetByID("foo")
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("set_get_delete", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ post := testPost(0)
+ post.Tags = []string{"foo", "bar"}
+
+ err := h.store.Set(post)
+ assert.NoError(t, err)
+
+ gotPost, err := h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, post, gotPost)
+
+ // we will now try updating the post, and ensure it updates properly
+
+ post.Title = "something else"
+ post.Series = "whatever"
+ post.Body = "anything"
+ post.Tags = []string{"bar", "baz"}
+
+ err = h.store.Set(post)
+ assert.NoError(t, err)
+
+ gotPost, err = h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, post, gotPost)
+
+ // delete the post, it should go away
+ assert.NoError(t, h.store.Delete(post.ID))
+
+ _, err = h.store.GetByID(post.ID)
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("get", func(t *testing.T) {
+ h := newDraftStoreTestHarness(t)
+
+ posts := []Post{
+ testPost(0),
+ testPost(1),
+ testPost(2),
+ testPost(3),
+ }
+
+ for _, post := range posts {
+ err := h.store.Set(post)
+ assert.NoError(t, err)
+ }
+
+ gotPosts, hasMore, err := h.store.Get(0, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[:2], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ posts = append(posts, testPost(4))
+ err = h.store.Set(posts[4])
+ assert.NoError(t, err)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(2, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[4:], gotPosts)
+ })
+
+}
diff --git a/srv/src/post/post.go b/srv/src/post/post.go
index a39af61..03bce6c 100644
--- a/srv/src/post/post.go
+++ b/srv/src/post/post.go
@@ -77,44 +77,16 @@ type Store interface {
}
type store struct {
- db *sql.DB
+ db *SQLDB
}
// NewStore initializes a new Store using an existing SQLDB.
func NewStore(db *SQLDB) Store {
return &store{
- db: db.db,
+ db: db,
}
}
-// if the callback returns an error then the transaction is aborted.
-func (s *store) withTx(cb func(*sql.Tx) error) error {
-
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if err := cb(tx); err != nil {
-
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf(
- "rolling back transaction: %w (original error: %v)",
- rollbackErr, err,
- )
- }
-
- return fmt.Errorf("performing transaction: %w (rolled back)", err)
- }
-
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
-
- return nil
-}
-
func (s *store) Set(post Post, now time.Time) (bool, error) {
if post.ID == "" {
@@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) {
var first bool
- err := s.withTx(func(tx *sql.Tx) error {
+ err := s.db.withTx(func(tx *sql.Tx) error {
nowTS := now.Unix()
@@ -270,7 +242,7 @@ func (s *store) get(
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
- posts, err := s.get(s.db, count+1, page*count, ``)
+ posts, err := s.get(s.db.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying posts: %w", err)
@@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
func (s *store) GetByID(id string) (StoredPost, error) {
- posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
+ posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
@@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) {
}
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
- return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
+ return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
}
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
var posts []StoredPost
- err := s.withTx(func(tx *sql.Tx) error {
+ err := s.db.withTx(func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
@@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) {
func (s *store) GetTags() ([]string, error) {
- rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
+ rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
if err != nil {
return nil, fmt.Errorf("querying all tags: %w", err)
}
@@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) {
func (s *store) Delete(id string) error {
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
- return fmt.Errorf("deleting from post_tags: %w", err)
- }
+ return s.db.withTx(func(tx *sql.Tx) error {
- if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
- return fmt.Errorf("deleting from posts: %w", err)
- }
+ if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_tags: %w", err)
+ }
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
+ if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from posts: %w", err)
+ }
- return nil
+ return nil
+ })
}
diff --git a/srv/src/post/sql.go b/srv/src/post/sql.go
index 16cdc95..c768c9a 100644
--- a/srv/src/post/sql.go
+++ b/srv/src/post/sql.go
@@ -38,10 +38,18 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration
body BLOB NOT NULL
)`,
},
- Down: []string{
- "DROP TABLE assets",
- "DROP TABLE post_tags",
- "DROP TABLE posts",
+ },
+ {
+ Id: "2",
+ Up: []string{
+ `CREATE TABLE post_drafts (
+ id TEXT NOT NULL PRIMARY KEY,
+ title TEXT NOT NULL,
+ description TEXT NOT NULL,
+ tags TEXT,
+ series TEXT,
+ body TEXT NOT NULL
+ )`,
},
},
}}
@@ -89,3 +97,30 @@ func NewInMemSQLDB() *SQLDB {
func (db *SQLDB) Close() error {
return db.db.Close()
}
+
+func (db *SQLDB) withTx(cb func(*sql.Tx) error) error {
+
+ tx, err := db.db.Begin()
+
+ if err != nil {
+ return fmt.Errorf("starting transaction: %w", err)
+ }
+
+ if err := cb(tx); err != nil {
+
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return fmt.Errorf(
+ "rolling back transaction: %w (original error: %v)",
+ rollbackErr, err,
+ )
+ }
+
+ return fmt.Errorf("performing transaction: %w (rolled back)", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("committing transaction: %w", err)
+ }
+
+ return nil
+}