summaryrefslogtreecommitdiff
path: root/srv
diff options
context:
space:
mode:
Diffstat (limited to 'srv')
-rw-r--r--srv/src/post/asset.go67
-rw-r--r--srv/src/post/asset_test.go51
2 files changed, 99 insertions, 19 deletions
diff --git a/srv/src/post/asset.go b/srv/src/post/asset.go
index 3e6ae28..18af8f6 100644
--- a/srv/src/post/asset.go
+++ b/srv/src/post/asset.go
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
+ "sync"
)
var (
@@ -84,3 +85,69 @@ func (s *assetStore) Delete(id string) error {
_, err := s.db.Exec(`DELETE FROM assets WHERE id = ?`, id)
return err
}
+
+////////////////////////////////////////////////////////////////////////////////
+
+type cachedAssetStore struct {
+ inner AssetStore
+ m sync.Map
+}
+
+// NewCachedAssetStore wraps an AssetStore in an in-memory cache.
+func NewCachedAssetStore(assetStore AssetStore) AssetStore {
+ return &cachedAssetStore{
+ inner: assetStore,
+ }
+}
+
+func (s *cachedAssetStore) Set(id string, from io.Reader) error {
+
+ buf := new(bytes.Buffer)
+ from = io.TeeReader(from, buf)
+
+ if err := s.inner.Set(id, from); err != nil {
+ return err
+ }
+
+ s.m.Store(id, buf.Bytes())
+ return nil
+}
+
+func (s *cachedAssetStore) Get(id string, into io.Writer) error {
+
+ if bodyI, ok := s.m.Load(id); ok {
+
+ if err, ok := bodyI.(error); ok {
+ return err
+ }
+
+ if _, err := io.Copy(into, bytes.NewReader(bodyI.([]byte))); err != nil {
+ return fmt.Errorf("writing body to io.Writer: %w", err)
+ }
+
+ return nil
+ }
+
+ buf := new(bytes.Buffer)
+ into = io.MultiWriter(into, buf)
+
+ if err := s.inner.Get(id, into); errors.Is(err, ErrAssetNotFound) {
+ s.m.Store(id, err)
+ return err
+ } else if err != nil {
+ return err
+ }
+
+ s.m.Store(id, buf.Bytes())
+ return nil
+}
+
+func (s *cachedAssetStore) Delete(id string) error {
+
+ if err := s.inner.Delete(id); err != nil {
+ return err
+ }
+
+ s.m.Delete(id)
+ return nil
+}
diff --git a/srv/src/post/asset_test.go b/srv/src/post/asset_test.go
index 4b88000..d0cff48 100644
--- a/srv/src/post/asset_test.go
+++ b/srv/src/post/asset_test.go
@@ -12,14 +12,14 @@ type assetTestHarness struct {
store AssetStore
}
-func newAssetTestHarness(t *testing.T) assetTestHarness {
+func newAssetTestHarness(t *testing.T) *assetTestHarness {
db := NewInMemSQLDB()
t.Cleanup(func() { db.Close() })
store := NewAssetStore(db)
- return assetTestHarness{
+ return &assetTestHarness{
store: store,
}
}
@@ -40,28 +40,41 @@ func (h *assetTestHarness) assertNotFound(t *testing.T, id string) {
func TestAssetStore(t *testing.T) {
- h := newAssetTestHarness(t)
+ testAssetStore := func(t *testing.T, h *assetTestHarness) {
+ t.Helper()
- h.assertNotFound(t, "foo")
- h.assertNotFound(t, "bar")
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
- err := h.store.Set("foo", bytes.NewBufferString("FOO"))
- assert.NoError(t, err)
+ err := h.store.Set("foo", bytes.NewBufferString("FOO"))
+ assert.NoError(t, err)
- h.assertGet(t, "FOO", "foo")
- h.assertNotFound(t, "bar")
+ h.assertGet(t, "FOO", "foo")
+ h.assertNotFound(t, "bar")
- err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
- assert.NoError(t, err)
+ err = h.store.Set("foo", bytes.NewBufferString("FOOFOO"))
+ assert.NoError(t, err)
+
+ h.assertGet(t, "FOOFOO", "foo")
+ h.assertNotFound(t, "bar")
- h.assertGet(t, "FOOFOO", "foo")
- h.assertNotFound(t, "bar")
+ assert.NoError(t, h.store.Delete("foo"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+
+ assert.NoError(t, h.store.Delete("bar"))
+ h.assertNotFound(t, "foo")
+ h.assertNotFound(t, "bar")
+ }
- assert.NoError(t, h.store.Delete("foo"))
- h.assertNotFound(t, "foo")
- h.assertNotFound(t, "bar")
+ t.Run("sql", func(t *testing.T) {
+ h := newAssetTestHarness(t)
+ testAssetStore(t, h)
+ })
- assert.NoError(t, h.store.Delete("bar"))
- h.assertNotFound(t, "foo")
- h.assertNotFound(t, "bar")
+ t.Run("mem", func(t *testing.T) {
+ h := newAssetTestHarness(t)
+ h.store = NewCachedAssetStore(h.store)
+ testAssetStore(t, h)
+ })
}