summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Picciano <mediocregopher@gmail.com>2022-05-07 18:07:27 -0600
committerBrian Picciano <mediocregopher@gmail.com>2022-05-07 18:07:27 -0600
commit3088c9d76ca2c88996bb794b17e87cd1d5e76064 (patch)
tree9b09a41288ae65bfd20d1e408c7fdebd3358d19c
parentcfb633b3b5eed1a65b23cc641b1b250adffdbc8f (diff)
Implement AssetStore interface
-rw-r--r--srv/src/post/asset.go86
-rw-r--r--srv/src/post/asset_test.go67
2 files changed, 153 insertions, 0 deletions
diff --git a/srv/src/post/asset.go b/srv/src/post/asset.go
new file mode 100644
index 0000000..3e6ae28
--- /dev/null
+++ b/srv/src/post/asset.go
@@ -0,0 +1,86 @@
+package post
+
+import (
+ "bytes"
+ "database/sql"
+ "errors"
+ "fmt"
+ "io"
+)
+
+var (
+ // ErrAssetNotFound is used to indicate an Asset could not be found in the
+ // AssetStore.
+ ErrAssetNotFound = errors.New("asset not found")
+)
+
+// AssetStore implements the storage and retrieval of binary assets, which are
+// intended to be used by posts (e.g. images).
+type AssetStore interface {
+
+ // Set sets the id to the contents of the given io.Reader.
+ Set(id string, from io.Reader) error
+
+ // Get writes the id's body to the given io.Writer, or returns
+ // ErrAssetNotFound.
+ Get(id string, into io.Writer) error
+
+ // Delete's the body stored for the id, if any.
+ Delete(id string) error
+}
+
+type assetStore struct {
+ db *sql.DB
+}
+
+// NewAssetStore initializes a new AssetStore using an existing SQLDB.
+func NewAssetStore(db *SQLDB) AssetStore {
+ return &assetStore{
+ db: db.db,
+ }
+}
+
+func (s *assetStore) Set(id string, from io.Reader) error {
+
+ body, err := io.ReadAll(from)
+ if err != nil {
+ return fmt.Errorf("reading body fully into memory: %w", err)
+ }
+
+ _, err = s.db.Exec(
+ `INSERT INTO assets (id, body)
+ VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET body=excluded.body`,
+ id, body,
+ )
+
+ if err != nil {
+ return fmt.Errorf("inserting into assets: %w", err)
+ }
+
+ return nil
+}
+
+func (s *assetStore) Get(id string, into io.Writer) error {
+
+ var body []byte
+
+ err := s.db.QueryRow(`SELECT body FROM assets WHERE id = ?`, id).Scan(&body)
+
+ if errors.Is(err, sql.ErrNoRows) {
+ return ErrAssetNotFound
+ } else if err != nil {
+ return fmt.Errorf("selecting from assets: %w", err)
+ }
+
+ if _, err := io.Copy(into, bytes.NewReader(body)); err != nil {
+ return fmt.Errorf("writing body to io.Writer: %w", err)
+ }
+
+ return nil
+}
+
+func (s *assetStore) Delete(id string) error {
+ _, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id)
+ return err
+}
diff --git a/srv/src/post/asset_test.go b/srv/src/post/asset_test.go
new file mode 100644
index 0000000..4b88000
--- /dev/null
+++ b/srv/src/post/asset_test.go
@@ -0,0 +1,67 @@
+package post
+
+import (
+ "bytes"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type assetTestHarness struct {
+ store AssetStore
+}
+
+func newAssetTestHarness(t *testing.T) assetTestHarness {
+
+ db := NewInMemSQLDB()
+ t.Cleanup(func() { db.Close() })
+
+ store := NewAssetStore(db)
+
+ return assetTestHarness{
+ store: store,
+ }
+}
+
+func (h *assetTestHarness) assertGet(t *testing.T, exp, id string) {
+ t.Helper()
+ buf := new(bytes.Buffer)
+ err := h.store.Get(id, buf)
+ assert.NoError(t, err)
+ assert.Equal(t, exp, buf.String())
+}
+
+func (h *assetTestHarness) assertNotFound(t *testing.T, id string) {
+ t.Helper()
+ err := h.store.Get(id, io.Discard)
+ assert.ErrorIs(t, ErrAssetNotFound, err)
+}
+
+func TestAssetStore(t *testing.T) {
+
+ h := newAssetTestHarness(t)
+
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ err := h.store.Set("foo", bytes.NewBufferString("FOO"))
+ assert.NoError(t, err)
+
+ h.assertGet(t, "FOO", "foo")
+ h.assertNotFound(t, "bar")
+
+ err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+
+ h.assertGet(t, "FOOFOO", "foo")
+ h.assertNotFound(t, "bar")
+
+ assert.NoError(t, h.store.Delete("foo"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ assert.NoError(t, h.store.Delete("bar"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+}