diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2023-04-15 21:07:16 +0200 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2023-04-15 21:07:16 +0200 |
commit | 7872296b838f4d1b26c6a0a01d79d27fe5ab44cc (patch) | |
tree | 9487f5abb93d88ab3b52700d2b3002b7dda373d6 | |
parent | 68f3215df6e2e4f345076dd5b20b9bf5867353cf (diff) |
Move asset store into its own package
-rw-r--r-- | src/cmd/load-test-data/main.go | 3 | ||||
-rw-r--r-- | src/cmd/mediocre-blog/main.go | 3 | ||||
-rw-r--r-- | src/gmi/gmi.go | 5 | ||||
-rw-r--r-- | src/http/assets.go | 8 | ||||
-rw-r--r-- | src/http/http.go | 3 | ||||
-rw-r--r-- | src/post/asset/asset.go (renamed from src/post/asset.go) | 38 | ||||
-rw-r--r-- | src/post/asset/asset_test.go (renamed from src/post/asset_test.go) | 29 | ||||
-rw-r--r-- | src/post/draft_post.go | 8 | ||||
-rw-r--r-- | src/post/post.go | 14 | ||||
-rw-r--r-- | src/post/sql.go | 10 |
10 files changed, 65 insertions, 56 deletions
diff --git a/src/cmd/load-test-data/main.go b/src/cmd/load-test-data/main.go index 5ebee32..850b9fd 100644 --- a/src/cmd/load-test-data/main.go +++ b/src/cmd/load-test-data/main.go @@ -9,6 +9,7 @@ import ( cfgpkg "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" "github.com/mediocregopher/blog.mediocregopher.com/srv/post" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" "gopkg.in/yaml.v3" @@ -99,7 +100,7 @@ func main() { } { - assetStore := post.NewAssetStore(postDB) + assetStore := asset.NewStore(postDB) setAsset := func(assetID, assetPath string) error { assetFullPath := filepath.Join(testDataDir, assetPath) diff --git a/src/cmd/mediocre-blog/main.go b/src/cmd/mediocre-blog/main.go index d8ba768..aff0f8e 100644 --- a/src/cmd/mediocre-blog/main.go +++ b/src/cmd/mediocre-blog/main.go @@ -13,6 +13,7 @@ import ( "github.com/mediocregopher/blog.mediocregopher.com/srv/http" "github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist" "github.com/mediocregopher/blog.mediocregopher.com/srv/post" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset" "github.com/mediocregopher/blog.mediocregopher.com/srv/pow" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" @@ -100,7 +101,7 @@ func main() { defer postSQLDB.Close() postStore := post.NewStore(postSQLDB) - postAssetStore := post.NewAssetStore(postSQLDB) + postAssetStore := asset.NewStore(postSQLDB) postDraftStore := post.NewDraftStore(postSQLDB) cache := cache.New(5000) diff --git a/src/gmi/gmi.go b/src/gmi/gmi.go index 115f379..127f6b5 100644 --- a/src/gmi/gmi.go +++ b/src/gmi/gmi.go @@ -16,6 +16,7 @@ import ( "github.com/mediocregopher/blog.mediocregopher.com/srv/cache" "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" "github.com/mediocregopher/blog.mediocregopher.com/srv/post" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" ) @@ -27,7 +28,7 @@ type Params struct { Cache cache.Cache PostStore post.Store - PostAssetStore post.AssetStore + PostAssetStore asset.Store PublicURL *url.URL ListenAddr string @@ -218,7 +219,7 @@ func (a *api) assetsMiddleware() gemini.Handler { err := a.params.PostAssetStore.Get(id, rw) - if errors.Is(err, post.ErrAssetNotFound) { + if errors.Is(err, asset.ErrNotFound) { rw.WriteHeader(gemini.StatusNotFound, "Asset not found, sorry!") return diff --git a/src/http/assets.go b/src/http/assets.go index 5b26a2e..1f5f0d6 100644 --- a/src/http/assets.go +++ b/src/http/assets.go @@ -16,7 +16,7 @@ import ( "time" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" - "github.com/mediocregopher/blog.mediocregopher.com/srv/post" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset" "github.com/omeid/go-tarfs" "golang.org/x/image/draw" ) @@ -170,7 +170,7 @@ func (a *api) handleGetPostAssetArchive( err := a.params.PostAssetStore.Get(info.id, buf) - if errors.Is(err, post.ErrAssetNotFound) { + if errors.Is(err, asset.ErrNotFound) { http.Error(rw, "asset not found", 404) return } else if err != nil { @@ -244,7 +244,7 @@ func (a *api) getPostAssetHandler() http.Handler { err := a.params.PostAssetStore.Get(id, buf) - if errors.Is(err, post.ErrAssetNotFound) { + if errors.Is(err, asset.ErrNotFound) { http.Error(rw, "Asset not found", 404) return } else if err != nil { @@ -297,7 +297,7 @@ func (a *api) deletePostAssetHandler() http.Handler { err := a.params.PostAssetStore.Delete(id) - if errors.Is(err, post.ErrAssetNotFound) { + if errors.Is(err, asset.ErrNotFound) { http.Error(rw, "Asset not found", 404) return } else if err != nil { diff --git a/src/http/http.go b/src/http/http.go index da404dc..dc2569a 100644 --- a/src/http/http.go +++ b/src/http/http.go @@ -20,6 +20,7 @@ import ( "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" "github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist" "github.com/mediocregopher/blog.mediocregopher.com/srv/post" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post/asset" "github.com/mediocregopher/blog.mediocregopher.com/srv/pow" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" @@ -36,7 +37,7 @@ type Params struct { Cache cache.Cache PostStore post.Store - PostAssetStore post.AssetStore + PostAssetStore asset.Store PostDraftStore post.DraftStore MailingList mailinglist.MailingList diff --git a/src/post/asset.go b/src/post/asset/asset.go index a7b605b..d20b347 100644 --- a/src/post/asset.go +++ b/src/post/asset/asset.go @@ -1,4 +1,4 @@ -package post +package asset import ( "bytes" @@ -6,23 +6,25 @@ import ( "errors" "fmt" "io" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/post" ) var ( - // ErrAssetNotFound is used to indicate an Asset could not be found in the - // AssetStore. - ErrAssetNotFound = errors.New("asset not found") + // ErrNotFound is used to indicate an Asset could not be found in the + // Store. + ErrNotFound = errors.New("asset not found") ) -// AssetStore implements the storage and retrieval of binary assets, which are +// Store implements the storage and retrieval of binary assets, which are // intended to be used by posts (e.g. images). -type AssetStore interface { +type Store interface { // Set sets the id to the contents of the given io.Reader. Set(id string, from io.Reader) error // Get writes the id's body to the given io.Writer, or returns - // ErrAssetNotFound. + // ErrNotFound. Get(id string, into io.Writer) error // Delete's the body stored for the id, if any. @@ -32,18 +34,18 @@ type AssetStore interface { List() ([]string, error) } -type assetStore struct { - db *sql.DB +type store struct { + db *post.SQLDB } -// NewAssetStore initializes a new AssetStore using an existing SQLDB. -func NewAssetStore(db *SQLDB) AssetStore { - return &assetStore{ - db: db.db, +// NewStore initializes a new Store using an existing SQLDB. +func NewStore(db *post.SQLDB) Store { + return &store{ + db: db, } } -func (s *assetStore) Set(id string, from io.Reader) error { +func (s *store) Set(id string, from io.Reader) error { body, err := io.ReadAll(from) if err != nil { @@ -64,14 +66,14 @@ func (s *assetStore) Set(id string, from io.Reader) error { return nil } -func (s *assetStore) Get(id string, into io.Writer) error { +func (s *store) Get(id string, into io.Writer) error { var body []byte err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body) if errors.Is(err, sql.ErrNoRows) { - return ErrAssetNotFound + return ErrNotFound } else if err != nil { return fmt.Errorf("selecting from assets: %w", err) } @@ -83,12 +85,12 @@ func (s *assetStore) Get(id string, into io.Writer) error { return nil } -func (s *assetStore) Delete(id string) error { +func (s *store) Delete(id string) error { _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id) return err } -func (s *assetStore) List() ([]string, error) { +func (s *store) List() ([]string, error) { rows, err := s.db.Query(`SELECT id FROM assets ORDER BY id ASC`) diff --git a/src/post/asset_test.go b/src/post/asset/asset_test.go index 4d62d46..574cc08 100644 --- a/src/post/asset_test.go +++ b/src/post/asset/asset_test.go @@ -1,30 +1,31 @@ -package post +package asset import ( "bytes" "io" "testing" + "github.com/mediocregopher/blog.mediocregopher.com/srv/post" "github.com/stretchr/testify/assert" ) -type assetTestHarness struct { - store AssetStore +type testHarness struct { + store Store } -func newAssetTestHarness(t *testing.T) *assetTestHarness { +func newTestHarness(t *testing.T) *testHarness { - db := NewInMemSQLDB() + db := post.NewInMemSQLDB() t.Cleanup(func() { db.Close() }) - store := NewAssetStore(db) + store := NewStore(db) - return &assetTestHarness{ + return &testHarness{ store: store, } } -func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) { +func (h *testHarness) assertGet(t *testing.T, exp, id string) { t.Helper() buf := new(bytes.Buffer) err := h.store.Get(id, buf) @@ -32,15 +33,15 @@ func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) { assert.Equal(t, exp, buf.String()) } -func (h *assetTestHarness) assertNotFound(t *testing.T, id string) { +func (h *testHarness) assertNotFound(t *testing.T, id string) { t.Helper() err := h.store.Get(id, io.Discard) - assert.ErrorIs(t, ErrAssetNotFound, err) + assert.ErrorIs(t, ErrNotFound, err) } -func TestAssetStore(t *testing.T) { +func TestStore(t *testing.T) { - testAssetStore := func(t *testing.T, h *assetTestHarness) { + testStore := func(t *testing.T, h *testHarness) { t.Helper() h.assertNotFound(t, "foo") @@ -85,7 +86,7 @@ func TestAssetStore(t *testing.T) { } t.Run("sql", func(t *testing.T) { - h := newAssetTestHarness(t) - testAssetStore(t, h) + h := newTestHarness(t) + testStore(t, h) }) } diff --git a/src/post/draft_post.go b/src/post/draft_post.go index 1c0e075..0576258 100644 --- a/src/post/draft_post.go +++ b/src/post/draft_post.go @@ -47,7 +47,7 @@ func (s *draftStore) Set(post Post) error { return fmt.Errorf("json marshaling tags %#v: %w", post.Tags, err) } - _, err = s.db.db.Exec( + _, err = s.db.Exec( `INSERT INTO post_drafts ( id, title, description, tags, series, body, format ) @@ -145,7 +145,7 @@ func (s *draftStore) get( func (s *draftStore) Get(page, count int) ([]Post, bool, error) { - posts, err := s.get(s.db.db, count+1, page*count, ``) + posts, err := s.get(s.db, count+1, page*count, ``) if err != nil { return nil, false, fmt.Errorf("querying post_drafts: %w", err) @@ -163,7 +163,7 @@ func (s *draftStore) Get(page, count int) ([]Post, bool, error) { func (s *draftStore) GetByID(id string) (Post, error) { - posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) + posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) if err != nil { return Post{}, fmt.Errorf("querying post_drafts: %w", err) @@ -182,7 +182,7 @@ func (s *draftStore) GetByID(id string) (Post, error) { func (s *draftStore) Delete(id string) error { - if _, err := s.db.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil { + if _, err := s.db.Exec(`DELETE FROM post_drafts WHERE id = ?`, id); err != nil { return fmt.Errorf("deleting from post_drafts: %w", err) } diff --git a/src/post/post.go b/src/post/post.go index 03326ad..9c8f0cf 100644 --- a/src/post/post.go +++ b/src/post/post.go @@ -96,7 +96,7 @@ func (s *store) Set(post Post, now time.Time) (bool, error) { var first bool - err := s.db.withTx(func(tx *sql.Tx) error { + err := s.db.WithTx(func(tx *sql.Tx) error { nowTS := now.Unix() @@ -244,7 +244,7 @@ func (s *store) get( func (s *store) Get(page, count int) ([]StoredPost, bool, error) { - posts, err := s.get(s.db.db, count+1, page*count, ``) + posts, err := s.get(s.db, count+1, page*count, ``) if err != nil { return nil, false, fmt.Errorf("querying posts: %w", err) @@ -262,7 +262,7 @@ func (s *store) Get(page, count int) ([]StoredPost, bool, error) { func (s *store) GetByID(id string) (StoredPost, error) { - posts, err := s.get(s.db.db, 0, 0, `WHERE p.id=?`, id) + posts, err := s.get(s.db, 0, 0, `WHERE p.id=?`, id) if err != nil { return StoredPost{}, fmt.Errorf("querying posts: %w", err) @@ -280,14 +280,14 @@ func (s *store) GetByID(id string) (StoredPost, error) { } func (s *store) GetBySeries(series string) ([]StoredPost, error) { - return s.get(s.db.db, 0, 0, `WHERE p.series=?`, series) + return s.get(s.db, 0, 0, `WHERE p.series=?`, series) } func (s *store) GetByTag(tag string) ([]StoredPost, error) { var posts []StoredPost - err := s.db.withTx(func(tx *sql.Tx) error { + err := s.db.WithTx(func(tx *sql.Tx) error { rows, err := tx.Query(`SELECT post_id FROM post_tags WHERE tag = ?`, tag) @@ -331,7 +331,7 @@ func (s *store) GetByTag(tag string) ([]StoredPost, error) { func (s *store) GetTags() ([]string, error) { - rows, err := s.db.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) + rows, err := s.db.Query(`SELECT tag FROM post_tags GROUP BY tag`) if err != nil { return nil, fmt.Errorf("querying all tags: %w", err) } @@ -355,7 +355,7 @@ func (s *store) GetTags() ([]string, error) { func (s *store) Delete(id string) error { - return s.db.withTx(func(tx *sql.Tx) error { + return s.db.WithTx(func(tx *sql.Tx) error { if _, err := tx.Exec(`DELETE FROM post_tags WHERE post_id = ?`, id); err != nil { return fmt.Errorf("deleting from post_tags: %w", err) diff --git a/src/post/sql.go b/src/post/sql.go index c7b726f..46c7b9a 100644 --- a/src/post/sql.go +++ b/src/post/sql.go @@ -78,7 +78,7 @@ var migrations = &migrate.MemoryMigrationSource{Migrations: []*migrate.Migration // SQLDB is a sqlite3 database which can be used by storage interfaces within // this package. type SQLDB struct { - db *sql.DB + *sql.DB } // NewSQLDB initializes and returns a new sqlite3 database for storage @@ -116,12 +116,14 @@ func NewInMemSQLDB() *SQLDB { // Close cleans up loose resources being held by the db. func (db *SQLDB) Close() error { - return db.db.Close() + return db.DB.Close() } -func (db *SQLDB) withTx(cb func(*sql.Tx) error) error { +// WithTx initializes a transaction, runs the callback using it, and either +// commits or rolls it back depending on if the callback returns an error. +func (db *SQLDB) WithTx(cb func(*sql.Tx) error) error { - tx, err := db.db.Begin() + tx, err := db.DB.Begin() if err != nil { return fmt.Errorf("starting transaction: %w", err) |