From cfb633b3b5eed1a65b23cc641b1b250adffdbc8f Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sat, 7 May 2022 17:58:20 -0600 Subject: Cleanup various small issues with post package --- srv/src/post/date.go | 59 -------- srv/src/post/post.go | 352 ++++++++++++++++++++++++++++++++++++++++++- srv/src/post/post_test.go | 224 ++++++++++++++++++++++++++++ srv/src/post/sql.go | 34 ++++- srv/src/post/store.go | 362 --------------------------------------------- srv/src/post/store_test.go | 245 ------------------------------ 6 files changed, 603 insertions(+), 673 deletions(-) delete mode 100644 srv/src/post/date.go delete mode 100644 srv/src/post/store.go delete mode 100644 srv/src/post/store_test.go diff --git a/srv/src/post/date.go b/srv/src/post/date.go deleted file mode 100644 index 34fe109..0000000 --- a/srv/src/post/date.go +++ /dev/null @@ -1,59 +0,0 @@ -package post - -import ( - "database/sql/driver" - "fmt" - "time" -) - -// Date represents a calendar date with no timezone information attached. -type Date struct { - Year int - Month time.Month - Day int -} - -// DateFromTime converts a Time into a Date, truncating all non-date -// information. -func DateFromTime(t time.Time) Date { - t = t.UTC() - return Date{ - Year: t.Year(), - Month: t.Month(), - Day: t.Day(), - } -} - -// ToTime converts a Date into a Time. The returned time will be UTC midnight of -// the Date. -func (d *Date) ToTime() time.Time { - return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, time.UTC) -} - -// Scan implements the sql.Scanner interface. -func (d *Date) Scan(src interface{}) error { - - if src == nil { - *d = Date{} - return nil - } - - ts, ok := src.(int64) - - if !ok { - return fmt.Errorf("cannot scan value %#v into Date", src) - } - - *d = DateFromTime(time.Unix(ts, 0)) - return nil -} - -// Value implements the driver.Valuer interface. -func (d Date) Value() (driver.Value, error) { - - if d == (Date{}) { - return nil, nil - } - - return d.ToTime().Unix(), nil -} diff --git a/srv/src/post/post.go b/srv/src/post/post.go index 7803c82..bdc48af 100644 --- a/srv/src/post/post.go +++ b/srv/src/post/post.go @@ -1,9 +1,20 @@ -// Package post deals with the storage and rending of blog post. +// Package post deals with the storage and rendering of blog posts. package post import ( + "database/sql" + "errors" + "fmt" + "path" "regexp" "strings" + "time" +) + +var ( + // ErrPostNotFound is used to indicate a Post could not be found in the + // Store. + ErrPostNotFound = errors.New("post not found") ) var titleCleanRegexp = regexp.MustCompile(`[^a-z ]`) @@ -25,3 +36,342 @@ type Post struct { Series string Body string } + +// StoredPost is a Post which has been stored in a Store, and has been given +// some extra fields as a result. +type StoredPost struct { + Post + + PublishedAt time.Time + LastUpdatedAt time.Time +} + +// URL returns the relative URL of the StoredPost. +func (p StoredPost) URL() string { + return path.Join( + fmt.Sprintf( + "%d/%0d/%0d", + p.PublishedAt.Year(), + p.PublishedAt.Month(), + p.PublishedAt.Day(), + ), + p.ID+".html", + ) +} + +// Store is used for storing posts to a persistent storage. +type Store interface { + + // Set sets the Post data into the storage, keyed by the Post's ID. It + // overwrites a previous Post with the same ID, if there was one. + Set(post Post, now time.Time) error + + // Get returns count StoredPosts, sorted time descending, offset by the given page + // number. The returned boolean indicates if there are more pages or not. + Get(page, count int) ([]StoredPost, bool, error) + + // GetByID will return the StoredPost with the given ID, or ErrPostNotFound. + GetByID(id string) (StoredPost, error) + + // GetBySeries returns all StoredPosts with the given series, sorted time + // ascending, or empty slice. + GetBySeries(series string) ([]StoredPost, error) + + // GetByTag returns all StoredPosts with the given tag, sorted time + // ascending, or empty slice. + GetByTag(tag string) ([]StoredPost, error) + + // Delete will delete the StoredPost with the given ID. + Delete(id string) error +} + +type store struct { + db *sql.DB +} + +// NewStore initializes a new Store using an existing SQLDB. +func NewStore(db *SQLDB) Store { + return &store{ + db: db.db, + } +} + +// if the callback returns an error then the transaction is aborted. +func (s *store) withTx(cb func(*sql.Tx) error) error { + + tx, err := s.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if err := cb(tx); err != nil { + + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf( + "rolling back transaction: %w (original error: %v)", + rollbackErr, err, + ) + } + + return fmt.Errorf("performing transaction: %w (rolled back)", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} + +func (s *store) Set(post Post, now time.Time) error { + return s.withTx(func(tx *sql.Tx) error { + + nowTS := now.Unix() + + nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()} + + _, err := tx.Exec( + `INSERT INTO posts ( + id, title, description, series, published_at, body + ) + VALUES + (?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + series=excluded.series, + last_updated_at=?, + body=excluded.body`, + post.ID, + post.Title, + post.Description, + &sql.NullString{String: post.Series, Valid: post.Series != ""}, + nowSQL, + post.Body, + nowSQL, + ) + + if err != nil { + return fmt.Errorf("inserting into posts: %w", err) + } + + // this is a bit of a hack, but it allows us to update the tagset without + // doing a diff. + _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID) + + if err != nil { + return fmt.Errorf("clearning post tags: %w", err) + } + + for _, tag := range post.Tags { + + _, err = tx.Exec( + `INSERT INTO post_tags (post_id, tag) VALUES (?, ?) + ON CONFLICT DO NOTHING`, + post.ID, + tag, + ) + + if err != nil { + return fmt.Errorf("inserting tag %q: %w", tag, err) + } + } + + return nil + }) +} + +func (s *store) get( + querier interface { + Query(string, ...interface{}) (*sql.Rows, error) + }, + limit, offset int, + where string, whereArgs ...interface{}, +) ( + []StoredPost, error, +) { + + query := ` + SELECT + p.id, p.title, p.description, p.series, pt.tag, + p.published_at, p.last_updated_at, + p.body + FROM posts p + LEFT JOIN post_tags pt ON (p.id = pt.post_id) + ` + where + ` + ORDER BY p.published_at ASC, p.title ASC` + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + + rows, err := querier.Query(query, whereArgs...) + + if err != nil { + return nil, fmt.Errorf("selecting: %w", err) + } + + var posts []StoredPost + + for rows.Next() { + + var ( + post StoredPost + series, tag sql.NullString + publishedAt, lastUpdatedAt sql.NullInt64 + ) + + err := rows.Scan( + &post.ID, &post.Title, &post.Description, &series, &tag, + &publishedAt, &lastUpdatedAt, + &post.Body, + ) + + if err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + if tag.Valid { + + if l := len(posts); l > 0 && posts[l-1].ID == post.ID { + posts[l-1].Tags = append(posts[l-1].Tags, tag.String) + continue + } + + post.Tags = append(post.Tags, tag.String) + } + + post.Series = series.String + + if publishedAt.Valid { + post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC() + } + + if lastUpdatedAt.Valid { + post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC() + } + + posts = append(posts, post) + } + + if err := rows.Close(); err != nil { + return nil, fmt.Errorf("closing row iterator: %w", err) + } + + return posts, nil +} + +func (s *store) Get(page, count int) ([]StoredPost, bool, error) { + + posts, err := s.get(s.db, count+1, page*count, ``) + + if err != nil { + return nil, false, fmt.Errorf("querying posts: %w", err) + } + + var hasMore bool + + if len(posts) > count { + hasMore = true + posts = posts[:count] + } + + return posts, hasMore, nil +} + +func (s *store) GetByID(id string) (StoredPost, error) { + + posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) + + if err != nil { + return StoredPost{}, fmt.Errorf("querying posts: %w", err) + } + + if len(posts) == 0 { + return StoredPost{}, ErrPostNotFound + } + + if len(posts) > 1 { + panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts)) + } + + return posts[0], nil +} + +func (s *store) GetBySeries(series string) ([]StoredPost, error) { + return s.get(s.db, 0, 0, `WHERE p.series=?`, series) +} + +func (s *store) GetByTag(tag string) ([]StoredPost, error) { + + var posts []StoredPost + + err := s.withTx(func(tx *sql.Tx) error { + + rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) + + if err != nil { + return fmt.Errorf("querying post_tags by tag: %w", err) + } + + var ( + placeholders []string + whereArgs []interface{} + ) + + for rows.Next() { + + var id string + + if err := rows.Scan(&id); err != nil { + rows.Close() + return fmt.Errorf("scanning id: %w", err) + } + + whereArgs = append(whereArgs, id) + placeholders = append(placeholders, "?") + } + + if err := rows.Close(); err != nil { + return fmt.Errorf("closing row iterator: %w", err) + } + + where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ",")) + + if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil { + return fmt.Errorf("querying for ids %+v: %w", whereArgs, err) + } + + return nil + }) + + return posts, err +} + +func (s *store) Delete(id string) error { + + tx, err := s.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} diff --git a/srv/src/post/post_test.go b/srv/src/post/post_test.go index 47c9ae8..55a29ea 100644 --- a/srv/src/post/post_test.go +++ b/srv/src/post/post_test.go @@ -1,10 +1,13 @@ package post import ( + "sort" "strconv" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/tilinna/clock" ) func TestNewID(t *testing.T) { @@ -30,3 +33,224 @@ func TestNewID(t *testing.T) { }) } } + +func testPost(i int) Post { + istr := strconv.Itoa(i) + return Post{ + ID: istr, + Title: istr, + Description: istr, + Body: istr, + } +} + +type storeTestHarness struct { + clock *clock.Mock + store Store +} + +func newStoreTestHarness(t *testing.T) storeTestHarness { + + clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour)) + + db := NewInMemSQLDB() + t.Cleanup(func() { db.Close() }) + + store := NewStore(db) + + return storeTestHarness{ + clock: clock, + store: store, + } +} + +func (h *storeTestHarness) testStoredPost(i int) StoredPost { + post := testPost(i) + return StoredPost{ + Post: post, + PublishedAt: h.clock.Now(), + } +} + +func TestStore(t *testing.T) { + + assertPostEqual := func(t *testing.T, exp, got StoredPost) { + t.Helper() + sort.Strings(exp.Tags) + sort.Strings(got.Tags) + assert.Equal(t, exp, got) + } + + assertPostsEqual := func(t *testing.T, exp, got []StoredPost) { + t.Helper() + + if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { + return + } + + for i := range exp { + assertPostEqual(t, exp[i], got[i]) + } + } + + t.Run("not_found", func(t *testing.T) { + h := newStoreTestHarness(t) + + _, err := h.store.GetByID("foo") + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("set_get_delete", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + post := testPost(0) + post.Tags = []string{"foo", "bar"} + + assert.NoError(t, h.store.Set(post, now)) + + gotPost, err := h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, StoredPost{ + Post: post, + PublishedAt: now, + }, gotPost) + + // we will now try updating the post on a different day, and ensure it + // updates properly + + h.clock.Add(24 * time.Hour) + newNow := h.clock.Now().UTC() + + post.Title = "something else" + post.Series = "whatever" + post.Body = "anything" + post.Tags = []string{"bar", "baz"} + + assert.NoError(t, h.store.Set(post, newNow)) + + gotPost, err = h.store.GetByID(post.ID) + assert.NoError(t, err) + + assertPostEqual(t, StoredPost{ + Post: post, + PublishedAt: now, + LastUpdatedAt: newNow, + }, gotPost) + + // delete the post, it should go away + assert.NoError(t, h.store.Delete(post.ID)) + + _, err = h.store.GetByID(post.ID) + assert.ErrorIs(t, err, ErrPostNotFound) + }) + + t.Run("get", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(0), + h.testStoredPost(1), + h.testStoredPost(2), + h.testStoredPost(3), + } + + for _, post := range posts { + assert.NoError(t, h.store.Set(post.Post, now)) + } + + gotPosts, hasMore, err := h.store.Get(0, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[:2], gotPosts) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + posts = append(posts, h.testStoredPost(4)) + assert.NoError(t, h.store.Set(posts[4].Post, now)) + + gotPosts, hasMore, err = h.store.Get(1, 2) + assert.NoError(t, err) + assert.True(t, hasMore) + assertPostsEqual(t, posts[2:4], gotPosts) + + gotPosts, hasMore, err = h.store.Get(2, 2) + assert.NoError(t, err) + assert.False(t, hasMore) + assertPostsEqual(t, posts[4:], gotPosts) + }) + + t.Run("get_by_series", func(t *testing.T) { + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(0), + h.testStoredPost(1), + h.testStoredPost(2), + h.testStoredPost(3), + } + + posts[0].Series = "foo" + posts[1].Series = "bar" + posts[2].Series = "bar" + + for _, post := range posts { + assert.NoError(t, h.store.Set(post.Post, now)) + } + + fooPosts, err := h.store.GetBySeries("foo") + assert.NoError(t, err) + assertPostsEqual(t, posts[:1], fooPosts) + + barPosts, err := h.store.GetBySeries("bar") + assert.NoError(t, err) + assertPostsEqual(t, posts[1:3], barPosts) + + bazPosts, err := h.store.GetBySeries("baz") + assert.NoError(t, err) + assert.Empty(t, bazPosts) + }) + + t.Run("get_by_tag", func(t *testing.T) { + + h := newStoreTestHarness(t) + + now := h.clock.Now().UTC() + + posts := []StoredPost{ + h.testStoredPost(0), + h.testStoredPost(1), + h.testStoredPost(2), + h.testStoredPost(3), + } + + posts[0].Tags = []string{"foo"} + posts[1].Tags = []string{"foo", "bar"} + posts[2].Tags = []string{"bar"} + + for _, post := range posts { + assert.NoError(t, h.store.Set(post.Post, now)) + } + + fooPosts, err := h.store.GetByTag("foo") + assert.NoError(t, err) + assertPostsEqual(t, posts[:2], fooPosts) + + barPosts, err := h.store.GetByTag("bar") + assert.NoError(t, err) + assertPostsEqual(t, posts[1:3], barPosts) + + bazPosts, err := h.store.GetByTag("baz") + assert.NoError(t, err) + assert.Empty(t, bazPosts) + }) +} diff --git a/srv/src/post/sql.go b/srv/src/post/sql.go index fb9468f..16cdc95 100644 --- a/srv/src/post/sql.go +++ b/srv/src/post/sql.go @@ -7,10 +7,12 @@ import ( "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" migrate "github.com/rubenv/sql-migrate" + + _ "github.com/mattn/go-sqlite3" // we need dis ) -var migrations = []*migrate.Migration{ - &migrate.Migration{ +var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration{ + { Id: "1", Up: []string{ `CREATE TABLE posts ( @@ -24,18 +26,25 @@ var migrations = []*migrate.Migration{ body TEXT NOT NULL )`, + `CREATE TABLE post_tags ( post_id TEXT NOT NULL, tag TEXT NOT NULL, UNIQUE(post_id, tag) )`, + + `CREATE TABLE assets ( + id TEXT NOT NULL PRIMARY KEY, + body BLOB NOT NULL + )`, }, Down: []string{ + "DROP TABLE assets", "DROP TABLE post_tags", "DROP TABLE posts", }, }, -} +}} // SQLDB is a sqlite3 database which can be used by storage interfaces within // this package. @@ -43,7 +52,7 @@ type SQLDB struct { db *sql.DB } -// NewSQLDB initializes and returns a new sqlite3 database for post storage +// NewSQLDB initializes and returns a new sqlite3 database for storage // intefaces. The db will be created within the given data directory. func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) { @@ -54,8 +63,6 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) { return nil, fmt.Errorf("opening sqlite file at %q: %w", path, err) } - migrations := &migrate.MemoryMigrationSource{Migrations: migrations} - if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { return nil, fmt.Errorf("running migrations: %w", err) } @@ -63,6 +70,21 @@ func NewSQLDB(dataDir cfg.DataDir) (*SQLDB, error) { return &SQLDB{db}, nil } +// NewSQLDB is like NewSQLDB, but the database will be initialized in memory. +func NewInMemSQLDB() *SQLDB { + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + panic(fmt.Errorf("opening sqlite in memory: %w", err)) + } + + if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { + panic(fmt.Errorf("running migrations: %w", err)) + } + + return &SQLDB{db} +} + // Close cleans up loose resources being held by the db. func (db *SQLDB) Close() error { return db.db.Close() diff --git a/srv/src/post/store.go b/srv/src/post/store.go deleted file mode 100644 index 5da17a9..0000000 --- a/srv/src/post/store.go +++ /dev/null @@ -1,362 +0,0 @@ -package post - -import ( - "database/sql" - "errors" - "fmt" - "path" - "strings" - "time" - - _ "github.com/mattn/go-sqlite3" // we need dis -) - -var ( - // ErrNotFound is used to indicate a Post could not be found in the - // database. - ErrNotFound = errors.New("not found") -) - -// StoredPost is a Post which has been stored in a Store, and has been given -// some extra fields as a result. -type StoredPost struct { - Post - - PublishedAt time.Time - LastUpdatedAt time.Time -} - -// URL returns the relative URL of the StoredPost. -func (p StoredPost) URL() string { - return path.Join( - fmt.Sprintf( - "%d/%0d/%0d", - p.PublishedAt.Year(), - p.PublishedAt.Month(), - p.PublishedAt.Day(), - ), - p.ID+".html", - ) -} - -// Store is used for storing posts to a persistent storage. -type Store interface { - - // Set sets the Post data into the storage, keyed by the Post's ID. It - // overwrites a previous Post with the same ID, if there was one. - Set(post Post, now time.Time) error - - // Get returns count StoredPosts, sorted time descending, offset by the given page - // number. The returned boolean indicates if there are more pages or not. - Get(page, count int) ([]StoredPost, bool, error) - - // GetByID will return the StoredPost with the given ID, or ErrNotFound. - GetByID(id string) (StoredPost, error) - - // GetBySeries returns all StoredPosts with the given series, sorted time - // ascending, or empty slice. - GetBySeries(series string) ([]StoredPost, error) - - // GetByTag returns all StoredPosts with the given tag, sorted time - // ascending, or empty slice. - GetByTag(tag string) ([]StoredPost, error) - - // Delete will delete the StoredPost with the given ID. - Delete(id string) error - - Close() error -} - -type store struct { - db *sql.DB -} - -// NewStore initializes a new Store using an existing SQLDB. -func NewStore(db *SQLDB) Store { - return &store{ - db: db.db, - } -} - -func (s *store) Close() error { - return s.db.Close() -} - -// if the callback returns an error then the transaction is aborted. -func (s *store) withTx(cb func(*sql.Tx) error) error { - - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if err := cb(tx); err != nil { - - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf( - "rolling back transaction: %w (original error: %v)", - rollbackErr, err, - ) - } - - return fmt.Errorf("performing transaction: %w (rolled back)", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - -func (s *store) Set(post Post, now time.Time) error { - return s.withTx(func(tx *sql.Tx) error { - - nowTS := now.Unix() - - nowSQL := sql.NullInt64{Int64: nowTS, Valid: !now.IsZero()} - - _, err := tx.Exec( - `INSERT INTO posts ( - id, title, description, series, published_at, body - ) - VALUES - (?, ?, ?, ?, ?, ?) - ON CONFLICT (id) DO UPDATE SET - title=excluded.title, - description=excluded.description, - series=excluded.series, - last_updated_at=?, - body=excluded.body`, - post.ID, - post.Title, - post.Description, - &sql.NullString{String: post.Series, Valid: post.Series != ""}, - nowSQL, - post.Body, - nowSQL, - ) - - if err != nil { - return fmt.Errorf("inserting into posts: %w", err) - } - - // this is a bit of a hack, but it allows us to update the tagset without - // doing a diff. - _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID) - - if err != nil { - return fmt.Errorf("clearning post tags: %w", err) - } - - for _, tag := range post.Tags { - - _, err = tx.Exec( - `INSERT INTO post_tags (post_id, tag) VALUES (?, ?) - ON CONFLICT DO NOTHING`, - post.ID, - tag, - ) - - if err != nil { - return fmt.Errorf("inserting tag %q: %w", tag, err) - } - } - - return nil - }) -} - -func (s *store) get( - querier interface { - Query(string, ...interface{}) (*sql.Rows, error) - }, - limit, offset int, - where string, whereArgs ...interface{}, -) ( - []StoredPost, error, -) { - - query := `SELECT - p.id, p.title, p.description, p.series, pt.tag, - p.published_at, p.last_updated_at, - p.body - FROM posts p - LEFT JOIN post_tags pt ON (p.id = pt.post_id) - ` + where + ` - ORDER BY p.published_at ASC, p.title ASC` - - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - - if offset > 0 { - query += fmt.Sprintf(" OFFSET %d", offset) - } - - rows, err := querier.Query(query, whereArgs...) - - if err != nil { - return nil, fmt.Errorf("selecting: %w", err) - } - - var posts []StoredPost - - for rows.Next() { - - var ( - post StoredPost - series, tag sql.NullString - publishedAt, lastUpdatedAt sql.NullInt64 - ) - - err := rows.Scan( - &post.ID, &post.Title, &post.Description, &series, &tag, - &publishedAt, &lastUpdatedAt, - &post.Body, - ) - - if err != nil { - return nil, fmt.Errorf("scanning row: %w", err) - } - - if tag.Valid { - - if l := len(posts); l > 0 && posts[l-1].ID == post.ID { - posts[l-1].Tags = append(posts[l-1].Tags, tag.String) - continue - } - - post.Tags = append(post.Tags, tag.String) - } - - post.Series = series.String - - if publishedAt.Valid { - post.PublishedAt = time.Unix(publishedAt.Int64, 0).UTC() - } - - if lastUpdatedAt.Valid { - post.LastUpdatedAt = time.Unix(lastUpdatedAt.Int64, 0).UTC() - } - - posts = append(posts, post) - } - - if err := rows.Close(); err != nil { - return nil, fmt.Errorf("closing row iterator: %w", err) - } - - return posts, nil -} - -func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - - posts, err := s.get(s.db, count+1, page*count, ``) - - if err != nil { - return nil, false, fmt.Errorf("querying posts: %w", err) - } - - var hasMore bool - - if len(posts) > count { - hasMore = true - posts = posts[:count] - } - - return posts, hasMore, nil -} - -func (s *store) GetByID(id string) (StoredPost, error) { - - posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) - - if err != nil { - return StoredPost{}, fmt.Errorf("querying posts: %w", err) - } - - if len(posts) == 0 { - return StoredPost{}, ErrNotFound - } - - if len(posts) > 1 { - panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts)) - } - - return posts[0], nil -} - -func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db, 0, 0, `WHERE p.series=?`, series) -} - -func (s *store) GetByTag(tag string) ([]StoredPost, error) { - - var posts []StoredPost - - err := s.withTx(func(tx *sql.Tx) error { - - rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) - - if err != nil { - return fmt.Errorf("querying post_tags by tag: %w", err) - } - - var ( - placeholders []string - whereArgs []interface{} - ) - - for rows.Next() { - - var id string - - if err := rows.Scan(&id); err != nil { - rows.Close() - return fmt.Errorf("scanning id: %w", err) - } - - whereArgs = append(whereArgs, id) - placeholders = append(placeholders, "?") - } - - if err := rows.Close(); err != nil { - return fmt.Errorf("closing row iterator: %w", err) - } - - where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ",")) - - if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil { - return fmt.Errorf("querying for ids %+v: %w", whereArgs, err) - } - - return nil - }) - - return posts, err -} - -func (s *store) Delete(id string) error { - - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_tags: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from posts: %w", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} diff --git a/srv/src/post/store_test.go b/srv/src/post/store_test.go deleted file mode 100644 index 3ccf5d7..0000000 --- a/srv/src/post/store_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package post - -import ( - "sort" - "strconv" - "testing" - "time" - - "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" - "github.com/stretchr/testify/assert" - "github.com/tilinna/clock" -) - -func testPost(i int) Post { - istr := strconv.Itoa(i) - return Post{ - ID: istr, - Title: istr, - Description: istr, - Body: istr, - } -} - -type storeTestHarness struct { - clock *clock.Mock - store Store -} - -func newStoreTestHarness(t *testing.T) storeTestHarness { - - var dataDir cfg.DataDir - - if err := dataDir.Init(); err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { dataDir.Close() }) - - clock := clock.NewMock(time.Now().UTC().Truncate(1 * time.Hour)) - - db, err := NewSQLDB(dataDir) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { db.Close() }) - - store := NewStore(db) - - return storeTestHarness{ - clock: clock, - store: store, - } -} - -func (h *storeTestHarness) testStoredPost(i int) StoredPost { - post := testPost(i) - return StoredPost{ - Post: post, - PublishedAt: h.clock.Now(), - } -} - -func TestStore(t *testing.T) { - - assertPostEqual := func(t *testing.T, exp, got StoredPost) { - t.Helper() - sort.Strings(exp.Tags) - sort.Strings(got.Tags) - assert.Equal(t, exp, got) - } - - assertPostsEqual := func(t *testing.T, exp, got []StoredPost) { - t.Helper() - - if !assert.Len(t, got, len(exp), "exp:%+v\ngot: %+v", exp, got) { - return - } - - for i := range exp { - assertPostEqual(t, exp[i], got[i]) - } - } - - t.Run("not_found", func(t *testing.T) { - h := newStoreTestHarness(t) - - _, err := h.store.GetByID("foo") - assert.ErrorIs(t, err, ErrNotFound) - }) - - t.Run("set_get_delete", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - post := testPost(0) - post.Tags = []string{"foo", "bar"} - - assert.NoError(t, h.store.Set(post, now)) - - gotPost, err := h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, StoredPost{ - Post: post, - PublishedAt: now, - }, gotPost) - - // we will now try updating the post on a different day, and ensure it - // updates properly - - h.clock.Add(24 * time.Hour) - newNow := h.clock.Now().UTC() - - post.Title = "something else" - post.Series = "whatever" - post.Body = "anything" - post.Tags = []string{"bar", "baz"} - - assert.NoError(t, h.store.Set(post, newNow)) - - gotPost, err = h.store.GetByID(post.ID) - assert.NoError(t, err) - - assertPostEqual(t, StoredPost{ - Post: post, - PublishedAt: now, - LastUpdatedAt: newNow, - }, gotPost) - - // delete the post, it should go away - assert.NoError(t, h.store.Delete(post.ID)) - - _, err = h.store.GetByID(post.ID) - assert.ErrorIs(t, err, ErrNotFound) - }) - - t.Run("get", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(0), - h.testStoredPost(1), - h.testStoredPost(2), - h.testStoredPost(3), - } - - for _, post := range posts { - assert.NoError(t, h.store.Set(post.Post, now)) - } - - gotPosts, hasMore, err := h.store.Get(0, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[:2], gotPosts) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - posts = append(posts, h.testStoredPost(4)) - assert.NoError(t, h.store.Set(posts[4].Post, now)) - - gotPosts, hasMore, err = h.store.Get(1, 2) - assert.NoError(t, err) - assert.True(t, hasMore) - assertPostsEqual(t, posts[2:4], gotPosts) - - gotPosts, hasMore, err = h.store.Get(2, 2) - assert.NoError(t, err) - assert.False(t, hasMore) - assertPostsEqual(t, posts[4:], gotPosts) - }) - - t.Run("get_by_series", func(t *testing.T) { - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(0), - h.testStoredPost(1), - h.testStoredPost(2), - h.testStoredPost(3), - } - - posts[0].Series = "foo" - posts[1].Series = "bar" - posts[2].Series = "bar" - - for _, post := range posts { - assert.NoError(t, h.store.Set(post.Post, now)) - } - - fooPosts, err := h.store.GetBySeries("foo") - assert.NoError(t, err) - assertPostsEqual(t, posts[:1], fooPosts) - - barPosts, err := h.store.GetBySeries("bar") - assert.NoError(t, err) - assertPostsEqual(t, posts[1:3], barPosts) - - bazPosts, err := h.store.GetBySeries("baz") - assert.NoError(t, err) - assert.Empty(t, bazPosts) - }) - - t.Run("get_by_tag", func(t *testing.T) { - - h := newStoreTestHarness(t) - - now := h.clock.Now().UTC() - - posts := []StoredPost{ - h.testStoredPost(0), - h.testStoredPost(1), - h.testStoredPost(2), - h.testStoredPost(3), - } - - posts[0].Tags = []string{"foo"} - posts[1].Tags = []string{"foo", "bar"} - posts[2].Tags = []string{"bar"} - - for _, post := range posts { - assert.NoError(t, h.store.Set(post.Post, now)) - } - - fooPosts, err := h.store.GetByTag("foo") - assert.NoError(t, err) - assertPostsEqual(t, posts[:2], fooPosts) - - barPosts, err := h.store.GetByTag("bar") - assert.NoError(t, err) - assertPostsEqual(t, posts[1:3], barPosts) - - bazPosts, err := h.store.GetByTag("baz") - assert.NoError(t, err) - assert.Empty(t, bazPosts) - }) -} -- cgit v1.2.3