From 7ac2f5ebb32a6098bc0d590130cc4b933db08771 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Thu, 18 Aug 2022 22:25:17 -0600 Subject: implement DraftStore --- srv/src/post/post.go | 69 +++++++++++++--------------------------------------- 1 file changed, 17 insertions(+), 52 deletions(-) (limited to 'srv/src/post/post.go') diff --git a/srv/src/post/post.go b/srv/src/post/post.go index a39af61..03bce6c 100644 --- a/srv/src/post/post.go +++ b/srv/src/post/post.go @@ -77,44 +77,16 @@ type Store interface { } type store struct { - db *sql.DB + db *SQLDB } // NewStore initializes a new Store using an existing SQLDB. func NewStore(db *SQLDB) Store { return &store{ - db: db.db, + db: db, } } -// if the callback returns an error then the transaction is aborted. -func (s *store) withTx(cb func(*sql.Tx) error) error { - - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if err := cb(tx); err != nil { - - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf( - "rolling back transaction: %w (original error: %v)", - rollbackErr, err, - ) - } - - return fmt.Errorf("performing transaction: %w (rolled back)", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - func (s *store) Set(post Post, now time.Time) (bool, error) { if post.ID == "" { @@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) { var first bool - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { nowTS := now.Unix() @@ -270,7 +242,7 @@ func (s *store) get( func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - posts, err := s.get(s.db, count+1, page*count, ``) + posts, err := s.get(s.db.db, count+1, page*count, ``) if err != nil { return nil, false, fmt.Errorf("querying posts: %w", err) @@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) { func (s *store) GetByID(id string) (StoredPost, error) { - posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) if err != nil { return StoredPost{}, fmt.Errorf("querying posts: %w", err) @@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) { } func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db, 0, 0, `WHERE p.series=?`, series) + return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) } func (s *store) GetByTag(tag string) ([]StoredPost, error) { var posts []StoredPost - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) @@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) { func (s *store) GetTags() ([]string, error) { - rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) + rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) if err != nil { return nil, fmt.Errorf("querying all tags: %w", err) } @@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) { func (s *store) Delete(id string) error { - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_tags: %w", err) - } + return s.db.withTx(func(tx *sql.Tx) error { - if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from posts: %w", err) - } + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } - return nil + return nil + }) } -- cgit v1.2.3