From f7d72adfb594980b0442c0fa7b5586e6248f96ac Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 6 May 2022 17:22:17 -0600 Subject: Implement post.Store --- srv/src/post/store.go | 403 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) (limited to 'srv/src/post/store.go') diff --git a/srv/src/post/store.go b/srv/src/post/store.go index 235520f..2898315 100644 --- a/srv/src/post/store.go +++ b/srv/src/post/store.go @@ -1 +1,404 @@ package post + +import ( + "database/sql" + "errors" + "fmt" + "path" + "strings" + + _ "github.com/mattn/go-sqlite3" + migrate "github.com/rubenv/sql-migrate" + "github.com/tilinna/clock" +) + +var ( + // ErrNotFound is used to indicate a Post could not be found in the + // database. + ErrNotFound = errors.New("not found") +) + +// StoredPost is a Post which has been stored in a Store, and has been given +// some extra fields as a result. +type StoredPost struct { + Post + + PublishedAt Date + LastUpdatedAt Date +} + +// URL returns the relative URL of the StoredPost. +func (p StoredPost) URL() string { + return path.Join( + fmt.Sprintf( + "%d/%0d/%0d", + p.PublishedAt.Year, + p.PublishedAt.Month, + p.PublishedAt.Day, + ), + p.ID+".html", + ) +} + +// Store is used for storing posts to a persistent storage. +type Store interface { + + // Set sets the Post data into the storage, keyed by the Post's ID. It + // overwrites a previous Post with the same ID, if there was one. + Set(Post) error + + // Get returns count StoredPosts, sorted time descending, offset by the given page + // number. The returned boolean indicates if there are more pages or not. + Get(page, count int) ([]StoredPost, bool, error) + + // GetByID will return the StoredPost with the given ID, or ErrNotFound. + GetByID(id string) (StoredPost, error) + + // GetBySeries returns all StoredPosts with the given series, sorted time + // ascending, or empty slice. + GetBySeries(series string) ([]StoredPost, error) + + // GetByTag returns all StoredPosts with the given tag, sorted time + // ascending, or empty slice. + GetByTag(tag string) ([]StoredPost, error) + + // Delete will delete the StoredPost with the given ID. + Delete(id string) error + + Close() error +} + +var migrations = []*migrate.Migration{ + &migrate.Migration{ + Id: "1", + Up: []string{ + `CREATE TABLE posts ( + id TEXT NOT NULL PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + series TEXT, + + published_at INTEGER NOT NULL, + last_updated_at INTEGER, + + body TEXT NOT NULL + )`, + `CREATE TABLE post_tags ( + post_id TEXT NOT NULL, + tag TEXT NOT NULL, + UNIQUE(post_id, tag) + )`, + }, + Down: []string{ + "DROP TABLE post_tags", + "DROP TABLE posts", + }, + }, +} + +// Params are parameters used to initialize a new Store. All fields are required +// unless otherwise noted. +type StoreParams struct { + + // Path to the file the database will be stored at. + DBFilePath string + + Clock clock.Clock +} + +type store struct { + params StoreParams + db *sql.DB +} + +// NewStore initializes a new Store using a sqlite3 database at the given file +// path. +func NewStore(params StoreParams) (Store, error) { + + db, err := sql.Open("sqlite3", params.DBFilePath) + if err != nil { + return nil, fmt.Errorf("opening sqlite file: %w", err) + } + + migrations := &migrate.MemoryMigrationSource{Migrations: migrations} + + if _, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up); err != nil { + return nil, fmt.Errorf("running migrations: %w", err) + } + + return &store{ + params: params, + db: db, + }, nil +} + +func (s *store) Close() error { + return s.db.Close() +} + +// if the callback returns an error then the transaction is aborted +func (s *store) withTx(cb func(*sql.Tx) error) error { + + tx, err := s.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if err := cb(tx); err != nil { + + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf( + "rolling back transaction: %w (original error: %v)", + rollbackErr, err, + ) + } + + return fmt.Errorf("performing transaction: %w (rolled back)", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} + +func (s *store) Set(post Post) error { + return s.withTx(func(tx *sql.Tx) error { + currentDate := DateFromTime(s.params.Clock.Now()) + + _, err := tx.Exec( + `INSERT INTO posts ( + id, title, description, series, published_at, body + ) + VALUES + (?, ?, ?, ?, ?, ?) + ON CONFLICT (id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + series=excluded.series, + last_updated_at=?, + body=excluded.body`, + post.ID, + post.Title, + post.Description, + &sql.NullString{String: post.Series, Valid: post.Series != ""}, + currentDate, + post.Body, + currentDate, + ) + + if err != nil { + return fmt.Errorf("inserting into posts: %w", err) + } + + // this is a bit of a hack, but it allows us to update the tagset without + // doing a diff. + _, err = tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, post.ID) + + if err != nil { + return fmt.Errorf("clearning post tags: %w", err) + } + + for _, tag := range post.Tags { + + _, err = tx.Exec( + `INSERT INTO post_tags (post_id, tag) VALUES (?, ?) + ON CONFLICT DO NOTHING`, + post.ID, + tag, + ) + + if err != nil { + return fmt.Errorf("inserting tag %q: %w", tag, err) + } + } + + return nil + }) +} + +func (s *store) get( + querier interface { + Query(string, ...interface{}) (*sql.Rows, error) + }, + limit, offset int, + where string, whereArgs ...interface{}, +) ( + []StoredPost, error, +) { + + query := `SELECT + p.id, p.title, p.description, p.series, pt.tag, + p.published_at, p.last_updated_at, + p.body + FROM posts p + LEFT JOIN post_tags pt ON (p.id = pt.post_id) + ` + where + ` + ORDER BY p.published_at ASC, p.title ASC` + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + if offset > 0 { + query += fmt.Sprintf(" OFFSET %d", offset) + } + + rows, err := querier.Query(query, whereArgs...) + + if err != nil { + return nil, fmt.Errorf("selecting: %w", err) + } + + var posts []StoredPost + + for rows.Next() { + + var ( + post StoredPost + series, tag sql.NullString + ) + + err := rows.Scan( + &post.ID, &post.Title, &post.Description, &series, &tag, + &post.PublishedAt, &post.LastUpdatedAt, + &post.Body, + ) + + if err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + if tag.Valid { + + if l := len(posts); l > 0 && posts[l-1].ID == post.ID { + posts[l-1].Tags = append(posts[l-1].Tags, tag.String) + continue + } + + post.Tags = append(post.Tags, tag.String) + } + + post.Series = series.String + + posts = append(posts, post) + } + + if err := rows.Close(); err != nil { + return nil, fmt.Errorf("closing row iterator: %w", err) + } + + return posts, nil +} + +func (s *store) Get(page, count int) ([]StoredPost, bool, error) { + + posts, err := s.get(s.db, count+1, page*count, ``) + + if err != nil { + return nil, false, fmt.Errorf("querying posts: %w", err) + } + + var hasMore bool + + if len(posts) > count { + hasMore = true + posts = posts[:count] + } + + return posts, hasMore, nil +} + +func (s *store) GetByID(id string) (StoredPost, error) { + + posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) + + if err != nil { + return StoredPost{}, fmt.Errorf("querying posts: %w", err) + } + + if len(posts) == 0 { + return StoredPost{}, ErrNotFound + } + + if len(posts) > 1 { + panic(fmt.Sprintf("got back multiple posts querying id %q: %+v", id, posts)) + } + + return posts[0], nil +} + +func (s *store) GetBySeries(series string) ([]StoredPost, error) { + return s.get(s.db, 0, 0, `WHERE p.series=?`, series) +} + +func (s *store) GetByTag(tag string) ([]StoredPost, error) { + + var posts []StoredPost + + err := s.withTx(func(tx *sql.Tx) error { + + rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) + + if err != nil { + return fmt.Errorf("querying post_tags by tag: %w", err) + } + + var ( + placeholders []string + whereArgs []interface{} + ) + + for rows.Next() { + + var id string + + if err := rows.Scan(&id); err != nil { + rows.Close() + return fmt.Errorf("scanning id: %w", err) + } + + whereArgs = append(whereArgs, id) + placeholders = append(placeholders, "?") + } + + if err := rows.Close(); err != nil { + return fmt.Errorf("closing row iterator: %w", err) + } + + where := fmt.Sprintf("WHERE p.id IN (%s)", strings.Join(placeholders, ",")) + + if posts, err = s.get(tx, 0, 0, where, whereArgs...); err != nil { + return fmt.Errorf("querying for ids %+v: %w", whereArgs, err) + } + + return nil + }) + + return posts, err +} + +func (s *store) Delete(id string) error { + + tx, err := s.db.Begin() + + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} -- cgit v1.2.3