summaryrefslogtreecommitdiff
path: root/srv/src/post/post.go
diff options
context:
space:
mode:
Diffstat (limited to 'srv/src/post/post.go')
-rw-r--r--srv/src/post/post.go69
1 files changed, 17 insertions, 52 deletions
diff --git a/srv/src/post/post.go b/srv/src/post/post.go
index a39af61..03bce6c 100644
--- a/srv/src/post/post.go
+++ b/srv/src/post/post.go
@@ -77,44 +77,16 @@ type Store interface {
}
type store struct {
- db *sql.DB
+ db *SQLDB
}
// NewStore initializes a new Store using an existing SQLDB.
func NewStore(db *SQLDB) Store {
return &store{
- db: db.db,
+ db: db,
}
}
-// if the callback returns an error then the transaction is aborted.
-func (s *store) withTx(cb func(*sql.Tx) error) error {
-
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if err := cb(tx); err != nil {
-
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf(
- "rolling back transaction: %w (original error: %v)",
- rollbackErr, err,
- )
- }
-
- return fmt.Errorf("performing transaction: %w (rolled back)", err)
- }
-
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
-
- return nil
-}
-
func (s *store) Set(post Post, now time.Time) (bool, error) {
if post.ID == "" {
@@ -123,7 +95,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) {
var first bool
- err := s.withTx(func(tx *sql.Tx) error {
+ err := s.db.withTx(func(tx *sql.Tx) error {
nowTS := now.Unix()
@@ -270,7 +242,7 @@ func (s *store) get(
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
- posts, err := s.get(s.db, count+1, page*count, ``)
+ posts, err := s.get(s.db.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying posts: %w", err)
@@ -288,7 +260,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
func (s *store) GetByID(id string) (StoredPost, error) {
- posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
+ posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
@@ -306,14 +278,14 @@ func (s *store) GetByID(id string) (StoredPost, error) {
}
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
- return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
+ return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
}
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
var posts []StoredPost
- err := s.withTx(func(tx *sql.Tx) error {
+ err := s.db.withTx(func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
@@ -357,7 +329,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) {
func (s *store) GetTags() ([]string, error) {
- rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
+ rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
if err != nil {
return nil, fmt.Errorf("querying all tags: %w", err)
}
@@ -381,23 +353,16 @@ func (s *store) GetTags() ([]string, error) {
func (s *store) Delete(id string) error {
- tx, err := s.db.Begin()
-
- if err != nil {
- return fmt.Errorf("starting transaction: %w", err)
- }
-
- if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
- return fmt.Errorf("deleting from post_tags: %w", err)
- }
+ return s.db.withTx(func(tx *sql.Tx) error {
- if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
- return fmt.Errorf("deleting from posts: %w", err)
- }
+ if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from post_tags: %w", err)
+ }
- if err := tx.Commit(); err != nil {
- return fmt.Errorf("committing transaction: %w", err)
- }
+ if _, err := tx.Exec(`DELETE FROM posts WHERE id = ?`, id); err != nil {
+ return fmt.Errorf("deleting from posts: %w", err)
+ }
- return nil
+ return nil
+ })
}