diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2022-08-18 22:25:17 -0600 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2022-08-18 22:25:17 -0600 |
commit | 7ac2f5ebb32a6098bc0d590130cc4b933db08771 (patch) | |
tree | 95d43d28bdd964fd461f8a2d3b72898ef5210d2b /srv/src/post/post.go | |
parent | 76ff79f47035c11c1e0dfcd283034b738b5c3f0d (diff) |
implement DraftStore
Diffstat (limited to 'srv/src/post/post.go')
-rw-r--r-- | srv/src/post/post.go | 69 |
1 files changed, 17 insertions, 52 deletions
diff --git a/srv/src/post/post.go b/srv/src/post/post.go index a39af61..03bce6c 100644 --- a/srv/src/post/post.go +++ b/srv/src/post/post.go @@ -77,44 +77,16 @@ type Store interface { } type store struct { - db *sql.DB + db *SQLDB } // NewStore initializes a new Store using an existing SQLDB. func NewStore(db *SQLDB) Store { return &store{ - db: db.db, + db: db, } } -// if the callback returns an error then the transaction is aborted. -func (s *store) withTx(cb func(*sql.Tx) error) error { - - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if err := cb(tx); err != nil { - - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf( - "rolling back transaction: %w (original error: %v)", - rollbackErr, err, - ) - } - - return fmt.Errorf("performing transaction: %w (rolled back)", err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - func (s *store) Set(post Post, now time.Time) (bool, error) { if post.ID == "" { @@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) { var first bool - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { nowTS := now.Unix() @@ -270,7 +242,7 @@ func (s *store) get( func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - posts, err := s.get(s.db, count+1, page*count, ``) + posts, err := s.get(s.db.db, count+1, page*count, ``) if err != nil { return nil, false, fmt.Errorf("querying posts: %w", err) @@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) { func (s *store) GetByID(id string) (StoredPost, error) { - posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) + posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) if err != nil { return StoredPost{}, fmt.Errorf("querying posts: %w", err) @@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) { } func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db, 0, 0, `WHERE p.series=?`, series) + return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) } func (s *store) GetByTag(tag string) ([]StoredPost, error) { var posts []StoredPost - err := s.withTx(func(tx *sql.Tx) error { + err := s.db.withTx(func(tx *sql.Tx) error { rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) @@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) { func (s *store) GetTags() ([]string, error) { - rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) + rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) if err != nil { return nil, fmt.Errorf("querying all tags: %w", err) } @@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) { func (s *store) Delete(id string) error { - tx, err := s.db.Begin() - - if err != nil { - return fmt.Errorf("starting transaction: %w", err) - } - - if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { - return fmt.Errorf("deleting from post_tags: %w", err) - } + return s.db.withTx(func(tx *sql.Tx) error { - if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { - return fmt.Errorf("deleting from posts: %w", err) - } + if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { + return fmt.Errorf("deleting from post_tags: %w", err) + } - if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) - } + if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil { + return fmt.Errorf("deleting from posts: %w", err) + } - return nil + return nil + }) } |