summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cmd/load-test-data/main.go3
-rw-r--r--src/cmd/mediocre-blog/main.go3
-rw-r--r--src/gmi/gmi.go5
-rw-r--r--src/http/assets.go8
-rw-r--r--src/http/http.go3
-rw-r--r--src/post/asset/asset.go (renamed from src/post/asset.go)38
-rw-r--r--src/post/asset/asset_test.go (renamed from src/post/asset_test.go)29
-rw-r--r--src/post/draft_post.go8
-rw-r--r--src/post/post.go14
-rw-r--r--src/post/sql.go10
10 files changed, 65 insertions, 56 deletions
diff --git a/src/cmd/load-test-data/main.go b/src/cmd/load-test-data/main.go
index 5ebee32..850b9fd 100644
--- a/src/cmd/load-test-data/main.go
+++ b/src/cmd/load-test-data/main.go
@@ -9,6 +9,7 @@ import (
cfgpkg "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
"gopkg.in/yaml.v3"
@@ -99,7 +100,7 @@ func main() {
}
{
- assetStore := post.NewAssetStore(postDB)
+ assetStore := asset.NewStore(postDB)
setAsset := func(assetID, assetPath string) error {
assetFullPath := filepath.Join(testDataDir, assetPath)
diff --git a/src/cmd/mediocre-blog/main.go b/src/cmd/mediocre-blog/main.go
index d8ba768..aff0f8e 100644
--- a/src/cmd/mediocre-blog/main.go
+++ b/src/cmd/mediocre-blog/main.go
@@ -13,6 +13,7 @@ import (
"github.com/mediocregopher/blog.mediocregopher.com/srv/http"
"github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset"
"github.com/mediocregopher/blog.mediocregopher.com/srv/pow"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
@@ -100,7 +101,7 @@ func main() {
defer postSQLDB.Close()
postStore := post.NewStore(postSQLDB)
- postAssetStore := post.NewAssetStore(postSQLDB)
+ postAssetStore := asset.NewStore(postSQLDB)
postDraftStore := post.NewDraftStore(postSQLDB)
cache := cache.New(5000)
diff --git a/src/gmi/gmi.go b/src/gmi/gmi.go
index 115f379..127f6b5 100644
--- a/src/gmi/gmi.go
+++ b/src/gmi/gmi.go
@@ -16,6 +16,7 @@ import (
"github.com/mediocregopher/blog.mediocregopher.com/srv/cache"
"github.com/mediocregopher/blog.mediocregopher.com/srv/cfg"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
@@ -27,7 +28,7 @@ type Params struct {
Cache cache.Cache
PostStore post.Store
- PostAssetStore post.AssetStore
+ PostAssetStore asset.Store
PublicURL *url.URL
ListenAddr string
@@ -218,7 +219,7 @@ func (a *api) assetsMiddleware() gemini.Handler {
err := a.params.PostAssetStore.Get(id, rw)
- if errors.Is(err, post.ErrAssetNotFound) {
+ if errors.Is(err, asset.ErrNotFound) {
rw.WriteHeader(gemini.StatusNotFound, "Asset not found, sorry!")
return
diff --git a/src/http/assets.go b/src/http/assets.go
index 5b26a2e..1f5f0d6 100644
--- a/src/http/assets.go
+++ b/src/http/assets.go
@@ -16,7 +16,7 @@ import (
"time"
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
- "github.com/mediocregopher/blog.mediocregopher.com/srv/post"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset"
"github.com/omeid/go-tarfs"
"golang.org/x/image/draw"
)
@@ -170,7 +170,7 @@ func (a *api) handleGetPostAssetArchive(
err := a.params.PostAssetStore.Get(info.id, buf)
- if errors.Is(err, post.ErrAssetNotFound) {
+ if errors.Is(err, asset.ErrNotFound) {
http.Error(rw, "asset not found", 404)
return
} else if err != nil {
@@ -244,7 +244,7 @@ func (a *api) getPostAssetHandler() http.Handler {
err := a.params.PostAssetStore.Get(id, buf)
- if errors.Is(err, post.ErrAssetNotFound) {
+ if errors.Is(err, asset.ErrNotFound) {
http.Error(rw, "Asset not found", 404)
return
} else if err != nil {
@@ -297,7 +297,7 @@ func (a *api) deletePostAssetHandler() http.Handler {
err := a.params.PostAssetStore.Delete(id)
- if errors.Is(err, post.ErrAssetNotFound) {
+ if errors.Is(err, asset.ErrNotFound) {
http.Error(rw, "Asset not found", 404)
return
} else if err != nil {
diff --git a/src/http/http.go b/src/http/http.go
index da404dc..dc2569a 100644
--- a/src/http/http.go
+++ b/src/http/http.go
@@ -20,6 +20,7 @@ import (
"github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil"
"github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist"
"github.com/mediocregopher/blog.mediocregopher.com/srv/post"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset"
"github.com/mediocregopher/blog.mediocregopher.com/srv/pow"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
@@ -36,7 +37,7 @@ type Params struct {
Cache cache.Cache
PostStore post.Store
- PostAssetStore post.AssetStore
+ PostAssetStore asset.Store
PostDraftStore post.DraftStore
MailingList mailinglist.MailingList
diff --git a/src/post/asset.go b/src/post/asset/asset.go
index a7b605b..d20b347 100644
--- a/src/post/asset.go
+++ b/src/post/asset/asset.go
@@ -1,4 +1,4 @@
-package post
+package asset
import (
"bytes"
@@ -6,23 +6,25 @@ import (
"errors"
"fmt"
"io"
+
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post"
)
var (
- // ErrAssetNotFound is used to indicate an Asset could not be found in the
- // AssetStore.
- ErrAssetNotFound = errors.New("asset not found")
+ // ErrNotFound is used to indicate an Asset could not be found in the
+ // Store.
+ ErrNotFound = errors.New("asset not found")
)
-// AssetStore implements the storage and retrieval of binary assets, which are
+// Store implements the storage and retrieval of binary assets, which are
// intended to be used by posts (e.g. images).
-type AssetStore interface {
+type Store interface {
// Set sets the id to the contents of the given io.Reader.
Set(id string, from io.Reader) error
// Get writes the id's body to the given io.Writer, or returns
- // ErrAssetNotFound.
+ // ErrNotFound.
Get(id string, into io.Writer) error
// Delete's the body stored for the id, if any.
@@ -32,18 +34,18 @@ type AssetStore interface {
List() ([]string, error)
}
-type assetStore struct {
- db *sql.DB
+type store struct {
+ db *post.SQLDB
}
-// NewAssetStore initializes a new AssetStore using an existing SQLDB.
-func NewAssetStore(db *SQLDB) AssetStore {
- return &assetStore{
- db: db.db,
+// NewStore initializes a new Store using an existing SQLDB.
+func NewStore(db *post.SQLDB) Store {
+ return &store{
+ db: db,
}
}
-func (s *assetStore) Set(id string, from io.Reader) error {
+func (s *store) Set(id string, from io.Reader) error {
body, err := io.ReadAll(from)
if err != nil {
@@ -64,14 +66,14 @@ func (s *assetStore) Set(id string, from io.Reader) error {
return nil
}
-func (s *assetStore) Get(id string, into io.Writer) error {
+func (s *store) Get(id string, into io.Writer) error {
var body []byte
err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body)
if errors.Is(err, sql.ErrNoRows) {
- return ErrAssetNotFound
+ return ErrNotFound
} else if err != nil {
return fmt.Errorf("selecting from assets: %w", err)
}
@@ -83,12 +85,12 @@ func (s *assetStore) Get(id string, into io.Writer) error {
return nil
}
-func (s *assetStore) Delete(id string) error {
+func (s *store) Delete(id string) error {
_, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id)
return err
}
-func (s *assetStore) List() ([]string, error) {
+func (s *store) List() ([]string, error) {
rows, err := s.db.Query(`SELECT id FROM assets ORDER BY id ASC`)
diff --git a/src/post/asset_test.go b/src/post/asset/asset_test.go
index 4d62d46..574cc08 100644
--- a/src/post/asset_test.go
+++ b/src/post/asset/asset_test.go
@@ -1,30 +1,31 @@
-package post
+package asset
import (
"bytes"
"io"
"testing"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/post"
"github.com/stretchr/testify/assert"
)
-type assetTestHarness struct {
- store AssetStore
+type testHarness struct {
+ store Store
}
-func newAssetTestHarness(t *testing.T) *assetTestHarness {
+func newTestHarness(t *testing.T) *testHarness {
- db := NewInMemSQLDB()
+ db := post.NewInMemSQLDB()
t.Cleanup(func() { db.Close() })
- store := NewAssetStore(db)
+ store := NewStore(db)
- return &assetTestHarness{
+ return &testHarness{
store: store,
}
}
-func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) {
+func (h *testHarness) assertGet(t *testing.T, exp, id string) {
t.Helper()
buf := new(bytes.Buffer)
err := h.store.Get(id, buf)
@@ -32,15 +33,15 @@ func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) {
assert.Equal(t, exp, buf.String())
}
-func (h *assetTestHarness) assertNotFound(t *testing.T, id string) {
+func (h *testHarness) assertNotFound(t *testing.T, id string) {
t.Helper()
err := h.store.Get(id, io.Discard)
- assert.ErrorIs(t, ErrAssetNotFound, err)
+ assert.ErrorIs(t, ErrNotFound, err)
}
-func TestAssetStore(t *testing.T) {
+func TestStore(t *testing.T) {
- testAssetStore := func(t *testing.T, h *assetTestHarness) {
+ testStore := func(t *testing.T, h *testHarness) {
t.Helper()
h.assertNotFound(t, "foo")
@@ -85,7 +86,7 @@ func TestAssetStore(t *testing.T) {
}
t.Run("sql", func(t *testing.T) {
- h := newAssetTestHarness(t)
- testAssetStore(t, h)
+ h := newTestHarness(t)
+ testStore(t, h)
})
}
diff --git a/src/post/draft_post.go b/src/post/draft_post.go
index 1c0e075..0576258 100644
--- a/src/post/draft_post.go
+++ b/src/post/draft_post.go
@@ -47,7 +47,7 @@ func (s *draftStore) Set(post Post) error {
return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err)
}
- _, err = s.db.db.Exec(
+ _, err = s.db.Exec(
`INSERT INTO post_drafts (
id, title, description, tags, series, body, format
)
@@ -145,7 +145,7 @@ func (s *draftStore) get(
func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
- posts, err := s.get(s.db.db, count+1, page*count, ``)
+ posts, err := s.get(s.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying post_drafts: %w", err)
@@ -163,7 +163,7 @@ func (s *draftStore) Get(page, count int) ([]Post, bool, error) {
func (s *draftStore) GetByID(id string) (Post, error) {
- posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
+ posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return Post{}, fmt.Errorf("querying post_drafts: %w", err)
@@ -182,7 +182,7 @@ func (s *draftStore) GetByID(id string) (Post, error) {
func (s *draftStore) Delete(id string) error {
- if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
+ if _, err := s.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil {
return fmt.Errorf("deleting from post_drafts: %w", err)
}
diff --git a/src/post/post.go b/src/post/post.go
index 03326ad..9c8f0cf 100644
--- a/src/post/post.go
+++ b/src/post/post.go
@@ -96,7 +96,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) {
var first bool
- err := s.db.withTx(func(tx *sql.Tx) error {
+ err := s.db.WithTx(func(tx *sql.Tx) error {
nowTS := now.Unix()
@@ -244,7 +244,7 @@ func (s *store) get(
func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
- posts, err := s.get(s.db.db, count+1, page*count, ``)
+ posts, err := s.get(s.db, count+1, page*count, ``)
if err != nil {
return nil, false, fmt.Errorf("querying posts: %w", err)
@@ -262,7 +262,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) {
func (s *store) GetByID(id string) (StoredPost, error) {
- posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id)
+ posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id)
if err != nil {
return StoredPost{}, fmt.Errorf("querying posts: %w", err)
@@ -280,14 +280,14 @@ func (s *store) GetByID(id string) (StoredPost, error) {
}
func (s *store) GetBySeries(series string) ([]StoredPost, error) {
- return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series)
+ return s.get(s.db, 0, 0, `WHERE p.series=?`, series)
}
func (s *store) GetByTag(tag string) ([]StoredPost, error) {
var posts []StoredPost
- err := s.db.withTx(func(tx *sql.Tx) error {
+ err := s.db.WithTx(func(tx *sql.Tx) error {
rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag)
@@ -331,7 +331,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) {
func (s *store) GetTags() ([]string, error) {
- rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
+ rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`)
if err != nil {
return nil, fmt.Errorf("querying all tags: %w", err)
}
@@ -355,7 +355,7 @@ func (s *store) GetTags() ([]string, error) {
func (s *store) Delete(id string) error {
- return s.db.withTx(func(tx *sql.Tx) error {
+ return s.db.WithTx(func(tx *sql.Tx) error {
if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil {
return fmt.Errorf("deleting from post_tags: %w", err)
diff --git a/src/post/sql.go b/src/post/sql.go
index c7b726f..46c7b9a 100644
--- a/src/post/sql.go
+++ b/src/post/sql.go
@@ -78,7 +78,7 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration
// SQLDB is a sqlite3 database which can be used by storage interfaces within
// this package.
type SQLDB struct {
- db *sql.DB
+ *sql.DB
}
// NewSQLDB initializes and returns a new sqlite3 database for storage
@@ -116,12 +116,14 @@ func NewInMemSQLDB() *SQLDB {
// Close cleans up loose resources being held by the db.
func (db *SQLDB) Close() error {
- return db.db.Close()
+ return db.DB.Close()
}
-func (db *SQLDB) withTx(cb func(*sql.Tx) error) error {
+// WithTx initializes a transaction, runs the callback using it, and either
+// commits or rolls it back depending on if the callback returns an error.
+func (db *SQLDB) WithTx(cb func(*sql.Tx) error) error {
- tx, err := db.db.Begin()
+ tx, err := db.DB.Begin()
if err != nil {
return fmt.Errorf("starting transaction: %w", err)