summaryrefslogtreecommitdiff
path: root/srv
diff options
context:
space:
mode:
authorBrian Picciano <mediocregopher@gmail.com>2022-05-07 17:58:20 -0600
committerBrian Picciano <mediocregopher@gmail.com>2022-05-07 18:07:13 -0600
commitcfb633b3b5eed1a65b23cc641b1b250adffdbc8f (patch)
tree0f10762026e5fca27f3cbf3612cb3b8104e5d31c /srv
parent07806c694269f6226a0f42c9f2cfb8c7655afad9 (diff)
Cleanup various small issues with post package
Diffstat (limited to 'srv')
-rw-r--r--srv/src/post/date.go59
-rw-r--r--srv/src/post/post.go352
-rw-r--r--srv/src/post/post_test.go224
-rw-r--r--srv/src/post/sql.go34
-rw-r--r--srv/src/post/store.go362
-rw-r--r--srv/src/post/store_test.go245
6 files changed, 603 insertions, 673 deletions
diff --git a/srv/src/post/date.go b/srv/src/post/date.go
deleted file mode 100644
index 34fe109..0000000
--- a/srv/src/post/date.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package post
-
-import (
- "database/sql/driver"
- "fmt"
- "time"
-)
-
-// Date represents a calendar date with no timezone information attached.
-type Date struct {
- Year int
- Month time.Month
- Day int
-}
-
-// DateFromTime converts a Time into a Date, truncating all non-date
-// information.
-func DateFromTime(t time.Time) Date {
- t = t.UTC()
- return Date{
- Year: t.Year(),
- Month: t.Month(),
- Day: t.Day(),
- }
-}
-
-// ToTime converts a Date into a Time. The returned time will be UTC midnight of
-// the Date.
-func (d *Date) ToTime() time.Time {
- return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, time.UTC)
-}
-
-// Scan implements the sql.Scanner interface.
-func (d *Date) Scan(src interface{}) error {
-
- if src == nil {
- *d = Date{}
- return nil
- }
-
- ts, ok := src.(int64)
-
- if !ok {
- return fmt.Errorf("cannot scan value %#v into Date", src)
- }
-
- *d = DateFromTime(time.Unix(ts, 0))
- return nil
-}
-
-// Value implements the driver.Valuer interface.
-func (d Date) Value() (driver.Value, error) {
-
- if d == (Date{}) {
- return nil, nil
- }
-
- return d.ToTime().Unix(), nil
-}
diff --git a/srv/src/post/post.go b/srv/src/post/post.go
index 7803c82..bdc48af 100644
--- a/srv/src/post/post.go
+++ b/srv/src/post/post.go
@@ -1,9 +1,20 @@
-// Package post deals with the storage and rending of blog post.
+// Package post deals with the storage and rendering of blog posts.
package post
import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "path"
"regexp"
"strings"
+ "time"
+)
+
+var (
+ // ErrPostNotFound is used to indicate a Post could not be found in the
+ // Store.
+ ErrPostNotFound = errors.New("post not found")
)
var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`)
@@ -25,3 +36,342 @@ type Post struct {
Series string
Body string
}
+
+// StoredPost is a Post which has been stored in a Store, and has been given
+// some extra fields as a result.
+type StoredPost struct {
+ Post
+
+ PublishedAt time.Time
+ LastUpdatedAt time.Time
+}
+
+// URL returns the relative URL of the StoredPost.
+func (p StoredPost) URL() string {
+ return path.Join(
+ fmt.Sprintf(
+ "%d/%0d/%0d",
+ p.PublishedAt.Year(),
+ p.PublishedAt.Month(),
+ p.PublishedAt.Day(),
+ ),
+ p.ID+".html",
+ )
+}
+
+// Store is used for storing posts to a persistent storage.
+type Store interface {
+
+ // Set sets the Post data into the storage, keyed by the Post's ID. It
+ // overwrites a previous Post with the same ID, if there was one.
+ Set(post Post, now time.Time) error
+
+ // Get returns count StoredPosts, sorted time descending, offset by the given page
+ // number. The returned boolean indicates if there are more pages or not.
+ Get(page, count int) ([]StoredPost, bool, error)
+
+ // GetByID will return the StoredPost with the given ID, or ErrPostNotFound.
+ GetByID(id string) (StoredPost, error)
+
+ // GetBySeries returns all StoredPosts with the given series, sorted time
+ // ascending, or empty slice.
+ GetBySeries(series string) ([]StoredPost, error)
+
+ // GetByTag returns all StoredPosts with the given tag, sorted time
+ // ascending, or empty slice.
+ GetByTag(tag string) ([]StoredPost, error)
+
+ // Delete will delete the StoredPost with the given ID.
+ Delete(id string) error
+}
+
+type store struct {
+ db *sql.DB
+}
+
+// NewStore initializes a new Store using an existing SQLDB.
+func NewStore(db *SQLDB) Store {
+ return &store{
+ db: db.db,
+ }
+}
+
+// if the callback returns an error then the transaction is aborted.
+func (s *store) withTx(cb func(*sql.Tx) error) error {
+
+ tx, err := s.db.Begin()
+
+ if err != nil {
+ return fmt.Errorf("starting transaction: %w", err)
+ }
+
+ if err := cb(tx); err != nil {
+
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ return fmt.Errorf(
+ "rolling back transaction: %w (original error: %v)",
+ rollbackErr, err,
+ )
+ }
+
+ return fmt.Errorf("performing transaction: %w (rolled back)", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("committing transaction: %w", err)
+ }
+
+ return nil
+}
+
+func (s *store) Set(post Post, now time.Time) error {
+ return s.withTx(func(tx *sql.Tx) error {
+
+ nowTS := now.Unix()
+
+ nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()}
+
+ _, err := tx.Exec(
+ `INSERT INTO posts (
+ id, title, description, series, published_at, body
+ )
+ VALUES
+ (?, ?, ?, ?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET
+ title=excluded.title,
+ description=excluded.description,
+ series=excluded.series,
+ last_updated_at=?,
+ body=excluded.body`,
+ post.ID,
+ post.Title,
+ post.Description,
+ &sql.NullString{String: post.Series, Valid: post.Series != ""},
+ nowSQL,
+ post.Body,
+ nowSQL,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into posts: %w", err)
+ }
+
+ // this is a bit of a hack, but it allows us to update the tagset without
+ // doing a diff.
+ _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
+
+ if err != nil {
+ return fmt.Errorf("clearning post tags: %w", err)
+ }
+
+ for _, tag := range post.Tags {
+
+ _, err = tx.Exec(
+ `INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
+ ON CONFLICT DO NOTHING`,
+ post.ID,
+ tag,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting tag %q: %w", tag, err)
+ }
+ }
+
+ return nil
+ })
+}
+
+func (s *store) get(
+ querier interface {
+ Query(string, ...interface{}) (*sql.Rows, error)
+ },
+ limit, offset int,
+ where string, whereArgs ...interface{},
+) (
+ []StoredPost, error,
+) {
+
+ query := `
+ SELECT
+ p.id, p.title, p.description, p.series, pt.tag,
+ p.published_at, p.last_updated_at,
+ p.body
+ FROM posts p
+ LEFT JOIN post_tags pt ON (p.id = pt.post_id)
+ ` + where + `
+ ORDER BY p.published_at ASC, p.title ASC`
+
+ if limit > 0 {
+ query += fmt.Sprintf(" LIMIT %d", limit)
+ }
+
+ if offset > 0 {
+ query += fmt.Sprintf(" OFFSET %d", offset)
+ }
+
+ rows, err := querier.Query(query, whereArgs...)
+
+ if err != nil {
+ return nil, fmt.Errorf("selecting: %w", err)
+ }
+
+ var posts []StoredPost
+
+ for rows.Next() {
+
+ var (
+ post StoredPost
+ series, tag sql.NullString
+ publishedAt, lastUpdatedAt sql.NullInt64
+ )
+
+ err := rows.Scan(
+ &post.ID, &post.Title, &post.Description, &series, &tag,
+ &publishedAt, &lastUpdatedAt,
+ &post.Body,
+ )
+
+ if err != nil {
+ return nil, fmt.Errorf("scanning row: %w", err)
+ }
+
+ if tag.Valid {
+
+ if l := len(posts); l > 0 && posts[l-1].ID == post.ID {
+ posts[l-1].Tags = append(posts[l-1].Tags, tag.String)
+ continue
+ }
+
+ post.Tags = append(post.Tags, tag.String)
+ }
+
+ post.Series = series.String
+
+ if publishedAt.Valid {
+ post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC()
+ }
+
+ if lastUpdatedAt.Valid {
+ post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC()
+ }
+
+ posts = append(posts, post)
+ }
+
+ if err := rows.Close(); err != nil {
+ return nil, fmt.Errorf("closing row iterator: %w", err)
+ }
+
+ return posts, nil
+}
+
+func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
+
+ posts, err := s.get(s.db, count+1, page*count, ``)
+
+ if err != nil {
+ return nil, false, fmt.Errorf("querying posts: %w", err)
+ }
+
+ var hasMore bool
+
+ if len(posts) > count {
+ hasMore = true
+ posts = posts[:count]
+ }
+
+ return posts, hasMore, nil
+}
+
+func (s *store) GetByID(id string) (StoredPost, error) {
+
+ posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
+
+ if err != nil {
+ return StoredPost{}, fmt.Errorf("querying posts: %w", err)
+ }
+
+ if len(posts) == 0 {
+ return StoredPost{}, ErrPostNotFound
+ }
+
+ if len(posts) > 1 {
+ panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
+ }
+
+ return posts[0], nil
+}
+
+func (s *store) GetBySeries(series string) ([]StoredPost, error) {
+ return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
+}
+
+func (s *store) GetByTag(tag string) ([]StoredPost, error) {
+
+ var posts []StoredPost
+
+ err := s.withTx(func(tx *sql.Tx) error {
+
+ rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
+
+ if err != nil {
+ return fmt.Errorf("querying post_tags by tag: %w", err)
+ }
+
+ var (
+ placeholders []string
+ whereArgs []interface{}
+ )
+
+ for rows.Next() {
+
+ var id string
+
+ if err := rows.Scan(&id); err != nil {
+ rows.Close()
+ return fmt.Errorf("scanning id: %w", err)
+ }
+
+ whereArgs = append(whereArgs, id)
+ placeholders = append(placeholders, "?")
+ }
+
+ if err := rows.Close(); err != nil {
+ return fmt.Errorf("closing row iterator: %w", err)
+ }
+
+ where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
+
+ if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
+ return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
+ }
+
+ return nil
+ })
+
+ return posts, err
+}
+
+func (s *store) Delete(id string) error {
+
+ tx, err := s.db.Begin()
+
+ if err != nil {
+ return fmt.Errorf("starting transaction: %w", err)
+ }
+
+ if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_tags: %w", err)
+ }
+
+ if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from posts: %w", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("committing transaction: %w", err)
+ }
+
+ return nil
+}
diff --git a/srv/src/post/post_test.go b/srv/src/post/post_test.go
index 47c9ae8..55a29ea 100644
--- a/srv/src/post/post_test.go
+++ b/srv/src/post/post_test.go
@@ -1,10 +1,13 @@
package post
import (
+ "sort"
"strconv"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
+ "github.com/tilinna/clock"
)
func TestNewID(t *testing.T) {
@@ -30,3 +33,224 @@ func TestNewID(t *testing.T) {
})
}
}
+
+func testPost(i int) Post {
+ istr := strconv.Itoa(i)
+ return Post{
+ ID: istr,
+ Title: istr,
+ Description: istr,
+ Body: istr,
+ }
+}
+
+type storeTestHarness struct {
+ clock *clock.Mock
+ store Store
+}
+
+func newStoreTestHarness(t *testing.T) storeTestHarness {
+
+ clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour))
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewStore(db)
+
+ return storeTestHarness{
+ clock: clock,
+ store: store,
+ }
+}
+
+func (h *storeTestHarness) testStoredPost(i int) StoredPost {
+ post := testPost(i)
+ return StoredPost{
+ Post: post,
+ PublishedAt: h.clock.Now(),
+ }
+}
+
+func TestStore(t *testing.T) {
+
+ assertPostEqual := func(t *testing.T, exp, got StoredPost) {
+ t.Helper()
+ sort.Strings(exp.Tags)
+ sort.Strings(got.Tags)
+ assert.Equal(t, exp, got)
+ }
+
+ assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
+ t.Helper()
+
+ if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
+ return
+ }
+
+ for i := range exp {
+ assertPostEqual(t, exp[i], got[i])
+ }
+ }
+
+ t.Run("not_found", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ _, err := h.store.GetByID("foo")
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("set_get_delete", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ post := testPost(0)
+ post.Tags = []string{"foo", "bar"}
+
+ assert.NoError(t, h.store.Set(post, now))
+
+ gotPost, err := h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, StoredPost{
+ Post: post,
+ PublishedAt: now,
+ }, gotPost)
+
+ // we will now try updating the post on a different day, and ensure it
+ // updates properly
+
+ h.clock.Add(24 * time.Hour)
+ newNow := h.clock.Now().UTC()
+
+ post.Title = "something else"
+ post.Series = "whatever"
+ post.Body = "anything"
+ post.Tags = []string{"bar", "baz"}
+
+ assert.NoError(t, h.store.Set(post, newNow))
+
+ gotPost, err = h.store.GetByID(post.ID)
+ assert.NoError(t, err)
+
+ assertPostEqual(t, StoredPost{
+ Post: post,
+ PublishedAt: now,
+ LastUpdatedAt: newNow,
+ }, gotPost)
+
+ // delete the post, it should go away
+ assert.NoError(t, h.store.Delete(post.ID))
+
+ _, err = h.store.GetByID(post.ID)
+ assert.ErrorIs(t, err, ErrPostNotFound)
+ })
+
+ t.Run("get", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(0),
+ h.testStoredPost(1),
+ h.testStoredPost(2),
+ h.testStoredPost(3),
+ }
+
+ for _, post := range posts {
+ assert.NoError(t, h.store.Set(post.Post, now))
+ }
+
+ gotPosts, hasMore, err := h.store.Get(0, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[:2], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ posts = append(posts, h.testStoredPost(4))
+ assert.NoError(t, h.store.Set(posts[4].Post, now))
+
+ gotPosts, hasMore, err = h.store.Get(1, 2)
+ assert.NoError(t, err)
+ assert.True(t, hasMore)
+ assertPostsEqual(t, posts[2:4], gotPosts)
+
+ gotPosts, hasMore, err = h.store.Get(2, 2)
+ assert.NoError(t, err)
+ assert.False(t, hasMore)
+ assertPostsEqual(t, posts[4:], gotPosts)
+ })
+
+ t.Run("get_by_series", func(t *testing.T) {
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(0),
+ h.testStoredPost(1),
+ h.testStoredPost(2),
+ h.testStoredPost(3),
+ }
+
+ posts[0].Series = "foo"
+ posts[1].Series = "bar"
+ posts[2].Series = "bar"
+
+ for _, post := range posts {
+ assert.NoError(t, h.store.Set(post.Post, now))
+ }
+
+ fooPosts, err := h.store.GetBySeries("foo")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[:1], fooPosts)
+
+ barPosts, err := h.store.GetBySeries("bar")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[1:3], barPosts)
+
+ bazPosts, err := h.store.GetBySeries("baz")
+ assert.NoError(t, err)
+ assert.Empty(t, bazPosts)
+ })
+
+ t.Run("get_by_tag", func(t *testing.T) {
+
+ h := newStoreTestHarness(t)
+
+ now := h.clock.Now().UTC()
+
+ posts := []StoredPost{
+ h.testStoredPost(0),
+ h.testStoredPost(1),
+ h.testStoredPost(2),
+ h.testStoredPost(3),
+ }
+
+ posts[0].Tags = []string{"foo"}
+ posts[1].Tags = []string{"foo", "bar"}
+ posts[2].Tags = []string{"bar"}
+
+ for _, post := range posts {
+ assert.NoError(t, h.store.Set(post.Post, now))
+ }
+
+ fooPosts, err := h.store.GetByTag("foo")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[:2], fooPosts)
+
+ barPosts, err := h.store.GetByTag("bar")
+ assert.NoError(t, err)
+ assertPostsEqual(t, posts[1:3], barPosts)
+
+ bazPosts, err := h.store.GetByTag("baz")
+ assert.NoError(t, err)
+ assert.Empty(t, bazPosts)
+ })
+}
diff --git a/srv/src/post/sql.go b/srv/src/post/sql.go
index fb9468f..16cdc95 100644
--- a/srv/src/post/sql.go
+++ b/srv/src/post/sql.go
@@ -7,10 +7,12 @@ import (
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
migrate "github.com/rubenv/sql-migrate"
+
+ _ "github.com/mattn/go-sqlite3" // we need dis
)
-var migrations = []*migrate.Migration{
- &migrate.Migration{
+var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{
+ {
Id: "1",
Up: []string{
`CREATE TABLE posts (
@@ -24,18 +26,25 @@ var migrations = []*migrate.Migration{
body TEXT NOT NULL
)`,
+
`CREATE TABLE post_tags (
post_id TEXT NOT NULL,
tag TEXT NOT NULL,
UNIQUE(post_id, tag)
)`,
+
+ `CREATE TABLE assets (
+ id TEXT NOT NULL PRIMARY KEY,
+ body BLOB NOT NULL
+ )`,
},
Down: []string{
+ "DROP TABLE assets",
"DROP TABLE post_tags",
"DROP TABLE posts",
},
},
-}
+}}
// SQLDB is a sqlite3 database which can be used by storage interfaces within
// this package.
@@ -43,7 +52,7 @@ type SQLDB struct {
db *sql.DB
}
-// NewSQLDB initializes and returns a new sqlite3 database for post storage
+// NewSQLDB initializes and returns a new sqlite3 database for storage
// intefaces. The db will be created within the given data directory.
func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
@@ -54,8 +63,6 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err)
}
- migrations := &migrate.MemoryMigrationSource{Migrations: migrations}
-
if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
return nil, fmt.Errorf("running migrations: %w", err)
}
@@ -63,6 +70,21 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) {
return &SQLDB{db}, nil
}
+// NewSQLDB is like NewSQLDB, but the database will be initialized in memory.
+func NewInMemSQLDB() *SQLDB {
+
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ panic(fmt.Errorf("opening sqlite in memory: %w", err))
+ }
+
+ if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil {
+ panic(fmt.Errorf("running migrations: %w", err))
+ }
+
+ return &SQLDB{db}
+}
+
// Close cleans up loose resources being held by the db.
func (db *SQLDB) Close() error {
return db.db.Close()
diff --git a/srv/src/post/store.go b/srv/src/post/store.go
deleted file mode 100644
index 5da17a9..0000000
--- a/srv/src/post/store.go
+++ /dev/null
@@ -1,362 +0,0 @@
-package post
-
-import (
- "database/sql"
- "errors"
- "fmt"
- "path"
- "strings"
- "time"
-
- _ "github.com/mattn/go-sqlite3" // we need dis
-)
-
-var (
- // ErrNotFound is used to indicate a Post could not be found in the
- // database.
- ErrNotFound = errors.New("not found")
-)
-
-// StoredPost is a Post which has been stored in a Store, and has been given
-// some extra fields as a result.
-type StoredPost struct {
- Post
-
- PublishedAt time.Time
- LastUpdatedAt time.Time
-}
-
-// URL returns the relative URL of the StoredPost.
-func (p StoredPost) URL() string {
- return path.Join(
- fmt.Sprintf(
- "%d/%0d/%0d",
- p.PublishedAt.Year(),
- p.PublishedAt.Month(),
- p.PublishedAt.Day(),
- ),
- p.ID+".html",
- )
-}
-
-// Store is used for storing posts to a persistent storage.
-type Store interface {
-
- // Set sets the Post data into the storage, keyed by the Post's ID. It
- // overwrites a previous Post with the same ID, if there was one.
- Set(post Post, now time.Time) error
-
- // Get returns count StoredPosts, sorted time descending, offset by the given page
- // number. The returned boolean indicates if there are more pages or not.
- Get(page, count int) ([]StoredPost, bool, error)
-
- // GetByID will return the StoredPost with the given ID, or ErrNotFound.
- GetByID(id string) (StoredPost, error)
-
- // GetBySeries returns all StoredPosts with the given series, sorted time
- // ascending, or empty slice.
- GetBySeries(series string) ([]StoredPost, error)
-
- // GetByTag returns all StoredPosts with the given tag, sorted time
- // ascending, or empty slice.
- GetByTag(tag string) ([]StoredPost, error)
-
- // Delete will delete the StoredPost with the given ID.
- Delete(id string) error
-
- Close() error
-}
-
-type store struct {
- db *sql.DB
-}
-
-// NewStore initializes a new Store using an existing SQLDB.
-func NewStore(db *SQLDB) Store {
- return &store{
- db: db.db,
- }
-}
-
-func (s *store) Close() error {
- return s.db.Close()
-}
-
-// if the callback returns an error then the transaction is aborted.
-func (s *store) withTx(cb func(*sql.Tx) error) error {
-
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if err := cb(tx); err != nil {
-
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf(
- "rolling back transaction: %w (original error: %v)",
- rollbackErr, err,
- )
- }
-
- return fmt.Errorf("performing transaction: %w (rolled back)", err)
- }
-
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
-
- return nil
-}
-
-func (s *store) Set(post Post, now time.Time) error {
- return s.withTx(func(tx *sql.Tx) error {
-
- nowTS := now.Unix()
-
- nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()}
-
- _, err := tx.Exec(
- `INSERT INTO posts (
- id, title, description, series, published_at, body
- )
- VALUES
- (?, ?, ?, ?, ?, ?)
- ON CONFLICT (id) DO UPDATE SET
- title=excluded.title,
- description=excluded.description,
- series=excluded.series,
- last_updated_at=?,
- body=excluded.body`,
- post.ID,
- post.Title,
- post.Description,
- &sql.NullString{String: post.Series, Valid: post.Series != ""},
- nowSQL,
- post.Body,
- nowSQL,
- )
-
- if err != nil {
- return fmt.Errorf("inserting into posts: %w", err)
- }
-
- // this is a bit of a hack, but it allows us to update the tagset without
- // doing a diff.
- _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID)
-
- if err != nil {
- return fmt.Errorf("clearning post tags: %w", err)
- }
-
- for _, tag := range post.Tags {
-
- _, err = tx.Exec(
- `INSERT INTO post_tags (post_id, tag) VALUES (?, ?)
- ON CONFLICT DO NOTHING`,
- post.ID,
- tag,
- )
-
- if err != nil {
- return fmt.Errorf("inserting tag %q: %w", tag, err)
- }
- }
-
- return nil
- })
-}
-
-func (s *store) get(
- querier interface {
- Query(string, ...interface{}) (*sql.Rows, error)
- },
- limit, offset int,
- where string, whereArgs ...interface{},
-) (
- []StoredPost, error,
-) {
-
- query := `SELECT
- p.id, p.title, p.description, p.series, pt.tag,
- p.published_at, p.last_updated_at,
- p.body
- FROM posts p
- LEFT JOIN post_tags pt ON (p.id = pt.post_id)
- ` + where + `
- ORDER BY p.published_at ASC, p.title ASC`
-
- if limit > 0 {
- query += fmt.Sprintf(" LIMIT %d", limit)
- }
-
- if offset > 0 {
- query += fmt.Sprintf(" OFFSET %d", offset)
- }
-
- rows, err := querier.Query(query, whereArgs...)
-
- if err != nil {
- return nil, fmt.Errorf("selecting: %w", err)
- }
-
- var posts []StoredPost
-
- for rows.Next() {
-
- var (
- post StoredPost
- series, tag sql.NullString
- publishedAt, lastUpdatedAt sql.NullInt64
- )
-
- err := rows.Scan(
- &post.ID, &post.Title, &post.Description, &series, &tag,
- &publishedAt, &lastUpdatedAt,
- &post.Body,
- )
-
- if err != nil {
- return nil, fmt.Errorf("scanning row: %w", err)
- }
-
- if tag.Valid {
-
- if l := len(posts); l > 0 && posts[l-1].ID == post.ID {
- posts[l-1].Tags = append(posts[l-1].Tags, tag.String)
- continue
- }
-
- post.Tags = append(post.Tags, tag.String)
- }
-
- post.Series = series.String
-
- if publishedAt.Valid {
- post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC()
- }
-
- if lastUpdatedAt.Valid {
- post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC()
- }
-
- posts = append(posts, post)
- }
-
- if err := rows.Close(); err != nil {
- return nil, fmt.Errorf("closing row iterator: %w", err)
- }
-
- return posts, nil
-}
-
-func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
-
- posts, err := s.get(s.db, count+1, page*count, ``)
-
- if err != nil {
- return nil, false, fmt.Errorf("querying posts: %w", err)
- }
-
- var hasMore bool
-
- if len(posts) > count {
- hasMore = true
- posts = posts[:count]
- }
-
- return posts, hasMore, nil
-}
-
-func (s *store) GetByID(id string) (StoredPost, error) {
-
- posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
-
- if err != nil {
- return StoredPost{}, fmt.Errorf("querying posts: %w", err)
- }
-
- if len(posts) == 0 {
- return StoredPost{}, ErrNotFound
- }
-
- if len(posts) > 1 {
- panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts))
- }
-
- return posts[0], nil
-}
-
-func (s *store) GetBySeries(series string) ([]StoredPost, error) {
- return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
-}
-
-func (s *store) GetByTag(tag string) ([]StoredPost, error) {
-
- var posts []StoredPost
-
- err := s.withTx(func(tx *sql.Tx) error {
-
- rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
-
- if err != nil {
- return fmt.Errorf("querying post_tags by tag: %w", err)
- }
-
- var (
- placeholders []string
- whereArgs []interface{}
- )
-
- for rows.Next() {
-
- var id string
-
- if err := rows.Scan(&id); err != nil {
- rows.Close()
- return fmt.Errorf("scanning id: %w", err)
- }
-
- whereArgs = append(whereArgs, id)
- placeholders = append(placeholders, "?")
- }
-
- if err := rows.Close(); err != nil {
- return fmt.Errorf("closing row iterator: %w", err)
- }
-
- where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ","))
-
- if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil {
- return fmt.Errorf("querying for ids %+v: %w", whereArgs, err)
- }
-
- return nil
- })
-
- return posts, err
-}
-
-func (s *store) Delete(id string) error {
-
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
- return fmt.Errorf("deleting from post_tags: %w", err)
- }
-
- if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
- return fmt.Errorf("deleting from posts: %w", err)
- }
-
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
-
- return nil
-}
diff --git a/srv/src/post/store_test.go b/srv/src/post/store_test.go
deleted file mode 100644
index 3ccf5d7..0000000
--- a/srv/src/post/store_test.go
+++ /dev/null
@@ -1,245 +0,0 @@
-package post
-
-import (
- "sort"
- "strconv"
- "testing"
- "time"
-
- "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
- "github.com/stretchr/testify/assert"
- "github.com/tilinna/clock"
-)
-
-func testPost(i int) Post {
- istr := strconv.Itoa(i)
- return Post{
- ID: istr,
- Title: istr,
- Description: istr,
- Body: istr,
- }
-}
-
-type storeTestHarness struct {
- clock *clock.Mock
- store Store
-}
-
-func newStoreTestHarness(t *testing.T) storeTestHarness {
-
- var dataDir cfg.DataDir
-
- if err := dataDir.Init(); err != nil {
- t.Fatal(err)
- }
-
- t.Cleanup(func() { dataDir.Close() })
-
- clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour))
-
- db, err := NewSQLDB(dataDir)
- if err != nil {
- t.Fatal(err)
- }
-
- t.Cleanup(func() { db.Close() })
-
- store := NewStore(db)
-
- return storeTestHarness{
- clock: clock,
- store: store,
- }
-}
-
-func (h *storeTestHarness) testStoredPost(i int) StoredPost {
- post := testPost(i)
- return StoredPost{
- Post: post,
- PublishedAt: h.clock.Now(),
- }
-}
-
-func TestStore(t *testing.T) {
-
- assertPostEqual := func(t *testing.T, exp, got StoredPost) {
- t.Helper()
- sort.Strings(exp.Tags)
- sort.Strings(got.Tags)
- assert.Equal(t, exp, got)
- }
-
- assertPostsEqual := func(t *testing.T, exp, got []StoredPost) {
- t.Helper()
-
- if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) {
- return
- }
-
- for i := range exp {
- assertPostEqual(t, exp[i], got[i])
- }
- }
-
- t.Run("not_found", func(t *testing.T) {
- h := newStoreTestHarness(t)
-
- _, err := h.store.GetByID("foo")
- assert.ErrorIs(t, err, ErrNotFound)
- })
-
- t.Run("set_get_delete", func(t *testing.T) {
- h := newStoreTestHarness(t)
-
- now := h.clock.Now().UTC()
-
- post := testPost(0)
- post.Tags = []string{"foo", "bar"}
-
- assert.NoError(t, h.store.Set(post, now))
-
- gotPost, err := h.store.GetByID(post.ID)
- assert.NoError(t, err)
-
- assertPostEqual(t, StoredPost{
- Post: post,
- PublishedAt: now,
- }, gotPost)
-
- // we will now try updating the post on a different day, and ensure it
- // updates properly
-
- h.clock.Add(24 * time.Hour)
- newNow := h.clock.Now().UTC()
-
- post.Title = "something else"
- post.Series = "whatever"
- post.Body = "anything"
- post.Tags = []string{"bar", "baz"}
-
- assert.NoError(t, h.store.Set(post, newNow))
-
- gotPost, err = h.store.GetByID(post.ID)
- assert.NoError(t, err)
-
- assertPostEqual(t, StoredPost{
- Post: post,
- PublishedAt: now,
- LastUpdatedAt: newNow,
- }, gotPost)
-
- // delete the post, it should go away
- assert.NoError(t, h.store.Delete(post.ID))
-
- _, err = h.store.GetByID(post.ID)
- assert.ErrorIs(t, err, ErrNotFound)
- })
-
- t.Run("get", func(t *testing.T) {
- h := newStoreTestHarness(t)
-
- now := h.clock.Now().UTC()
-
- posts := []StoredPost{
- h.testStoredPost(0),
- h.testStoredPost(1),
- h.testStoredPost(2),
- h.testStoredPost(3),
- }
-
- for _, post := range posts {
- assert.NoError(t, h.store.Set(post.Post, now))
- }
-
- gotPosts, hasMore, err := h.store.Get(0, 2)
- assert.NoError(t, err)
- assert.True(t, hasMore)
- assertPostsEqual(t, posts[:2], gotPosts)
-
- gotPosts, hasMore, err = h.store.Get(1, 2)
- assert.NoError(t, err)
- assert.False(t, hasMore)
- assertPostsEqual(t, posts[2:4], gotPosts)
-
- posts = append(posts, h.testStoredPost(4))
- assert.NoError(t, h.store.Set(posts[4].Post, now))
-
- gotPosts, hasMore, err = h.store.Get(1, 2)
- assert.NoError(t, err)
- assert.True(t, hasMore)
- assertPostsEqual(t, posts[2:4], gotPosts)
-
- gotPosts, hasMore, err = h.store.Get(2, 2)
- assert.NoError(t, err)
- assert.False(t, hasMore)
- assertPostsEqual(t, posts[4:], gotPosts)
- })
-
- t.Run("get_by_series", func(t *testing.T) {
- h := newStoreTestHarness(t)
-
- now := h.clock.Now().UTC()
-
- posts := []StoredPost{
- h.testStoredPost(0),
- h.testStoredPost(1),
- h.testStoredPost(2),
- h.testStoredPost(3),
- }
-
- posts[0].Series = "foo"
- posts[1].Series = "bar"
- posts[2].Series = "bar"
-
- for _, post := range posts {
- assert.NoError(t, h.store.Set(post.Post, now))
- }
-
- fooPosts, err := h.store.GetBySeries("foo")
- assert.NoError(t, err)
- assertPostsEqual(t, posts[:1], fooPosts)
-
- barPosts, err := h.store.GetBySeries("bar")
- assert.NoError(t, err)
- assertPostsEqual(t, posts[1:3], barPosts)
-
- bazPosts, err := h.store.GetBySeries("baz")
- assert.NoError(t, err)
- assert.Empty(t, bazPosts)
- })
-
- t.Run("get_by_tag", func(t *testing.T) {
-
- h := newStoreTestHarness(t)
-
- now := h.clock.Now().UTC()
-
- posts := []StoredPost{
- h.testStoredPost(0),
- h.testStoredPost(1),
- h.testStoredPost(2),
- h.testStoredPost(3),
- }
-
- posts[0].Tags = []string{"foo"}
- posts[1].Tags = []string{"foo", "bar"}
- posts[2].Tags = []string{"bar"}
-
- for _, post := range posts {
- assert.NoError(t, h.store.Set(post.Post, now))
- }
-
- fooPosts, err := h.store.GetByTag("foo")
- assert.NoError(t, err)
- assertPostsEqual(t, posts[:2], fooPosts)
-
- barPosts, err := h.store.GetByTag("bar")
- assert.NoError(t, err)
- assertPostsEqual(t, posts[1:3], barPosts)
-
- bazPosts, err := h.store.GetByTag("baz")
- assert.NoError(t, err)
- assert.Empty(t, bazPosts)
- })
-}