summaryrefslogtreecommitdiff
path: root/srv/pow
diff options
context:
space:
mode:
Diffstat (limited to 'srv/pow')
-rw-r--r--srv/pow/pow.go288
-rw-r--r--srv/pow/pow_test.go120
-rw-r--r--srv/pow/store.go92
-rw-r--r--srv/pow/store_test.go52
4 files changed, 552 insertions, 0 deletions
diff --git a/srv/pow/pow.go b/srv/pow/pow.go
new file mode 100644
index 0000000..3de1450
--- /dev/null
+++ b/srv/pow/pow.go
@@ -0,0 +1,288 @@
+// Package pow creates proof-of-work challenges and validates their solutions.
+package pow
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/md5"
+ "crypto/rand"
+ "crypto/sha512"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "time"
+
+ "github.com/tilinna/clock"
+)
+
+type challengeParams struct {
+ Target uint32
+ ExpiresAt int64
+ Random []byte
+}
+
+func (c challengeParams) MarshalBinary() ([]byte, error) {
+ buf := new(bytes.Buffer)
+
+ var err error
+ write := func(v interface{}) {
+ if err != nil {
+ return
+ }
+ err = binary.Write(buf, binary.BigEndian, v)
+ }
+
+ write(c.Target)
+ write(c.ExpiresAt)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if _, err := buf.Write(c.Random); err != nil {
+ panic(err)
+ }
+
+ return buf.Bytes(), nil
+}
+
+func (c *challengeParams) UnmarshalBinary(b []byte) error {
+ buf := bytes.NewBuffer(b)
+
+ var err error
+ read := func(into interface{}) {
+ if err != nil {
+ return
+ }
+ err = binary.Read(buf, binary.BigEndian, into)
+ }
+
+ read(&c.Target)
+ read(&c.ExpiresAt)
+
+ if buf.Len() > 0 {
+ c.Random = buf.Bytes() // whatever is left
+ }
+
+ return err
+}
+
+// The seed takes the form:
+//
+// (version)+(signature of challengeParams)+(challengeParams)
+//
+// Version is currently always 0.
+func newSeed(c challengeParams, secret []byte) ([]byte, error) {
+ buf := new(bytes.Buffer)
+ buf.WriteByte(0) // version
+
+ cb, err := c.MarshalBinary()
+ if err != nil {
+ return nil, err
+ }
+
+ h := hmac.New(md5.New, secret)
+ h.Write(cb)
+ buf.Write(h.Sum(nil))
+
+ buf.Write(cb)
+
+ return buf.Bytes(), nil
+}
+
+var errMalformedSeed = errors.New("malformed seed")
+
+func challengeParamsFromSeed(seed, secret []byte) (challengeParams, error) {
+ h := hmac.New(md5.New, secret)
+ hSize := h.Size()
+
+ if len(seed) < hSize+1 || seed[0] != 0 {
+ return challengeParams{}, errMalformedSeed
+ }
+ seed = seed[1:]
+
+ sig, cb := seed[:hSize], seed[hSize:]
+
+ // check signature
+ h.Write(cb)
+ if !hmac.Equal(sig, h.Sum(nil)) {
+ return challengeParams{}, errMalformedSeed
+ }
+
+ var c challengeParams
+ if err := c.UnmarshalBinary(cb); err != nil {
+ return challengeParams{}, fmt.Errorf("unmarshaling challenge parameters: %w", err)
+ }
+
+ return c, nil
+}
+
+// Challenge is a set of fields presented to a client, with which they must
+// generate a solution.
+//
+// Generating a solution is done by:
+//
+// - Collect up to len(Seed) random bytes. These will be the potential
+// solution.
+//
+// - Calculate the sha512 of the concatenation of Seed and PotentialSolution.
+//
+// - Parse the first 4 bytes of the sha512 result as a big-endian uint32.
+//
+// - If the resulting number is _less_ than Target, the solution has been
+// found. Otherwise go back to step 1 and try again.
+//
+type Challenge struct {
+ Seed []byte
+ Target uint32
+}
+
+// Errors which may be produced by a Manager.
+var (
+ ErrInvalidSolution = errors.New("invalid solution")
+ ErrExpiredSolution = errors.New("expired solution")
+)
+
+// Manager is used to both produce proof-of-work challenges and check their
+// solutions.
+type Manager interface {
+ NewChallenge() Challenge
+
+ // Will produce ErrInvalidSolution if the solution is invalid, or
+ // ErrExpiredSolution if the solution has expired.
+ CheckSolution(seed, solution []byte) error
+}
+
+// ManagerParams are used to initialize a new Manager instance. All fields are
+// required unless otherwise noted.
+type ManagerParams struct {
+ Clock clock.Clock
+ Store Store
+
+ // Secret is used to sign each Challenge's Seed, it should _not_ be shared
+ // with clients.
+ Secret []byte
+
+ // The Target which Challenges should hit. Lower is more difficult.
+ //
+ // Defaults to 0x00FFFFFF
+ Target uint32
+
+ // ChallengeTimeout indicates how long before Challenges are considered
+ // expired and cannot be solved.
+ //
+ // Defaults to 1 minute.
+ ChallengeTimeout time.Duration
+}
+
+func (p ManagerParams) withDefaults() ManagerParams {
+ if p.Target == 0 {
+ p.Target = 0x00FFFFFF
+ }
+ if p.ChallengeTimeout == 0 {
+ p.ChallengeTimeout = 1 * time.Minute
+ }
+ return p
+}
+
+type manager struct {
+ params ManagerParams
+}
+
+// NewManager initializes and returns a Manager instance using the given
+// parameters.
+func NewManager(params ManagerParams) Manager {
+ return &manager{
+ params: params,
+ }
+}
+
+func (m *manager) NewChallenge() Challenge {
+ target := m.params.Target
+
+ c := challengeParams{
+ Target: target,
+ ExpiresAt: m.params.Clock.Now().Add(m.params.ChallengeTimeout).Unix(),
+ Random: make([]byte, 8),
+ }
+
+ if _, err := rand.Read(c.Random); err != nil {
+ panic(err)
+ }
+
+ seed, err := newSeed(c, m.params.Secret)
+ if err != nil {
+ panic(err)
+ }
+
+ return Challenge{
+ Seed: seed,
+ Target: target,
+ }
+}
+
+// SolutionChecker can be used to check possible Challenge solutions. It will
+// cache certain values internally to save on allocations when used in a loop
+// (e.g. when generating a solution).
+//
+// SolutionChecker is not thread-safe.
+type SolutionChecker struct {
+ h hash.Hash // sha512
+ sum []byte
+}
+
+// Check returns true if the given bytes are a solution to the given Challenge.
+func (s SolutionChecker) Check(challenge Challenge, solution []byte) bool {
+ if s.h == nil {
+ s.h = sha512.New()
+ }
+ s.h.Reset()
+
+ s.h.Write(challenge.Seed)
+ s.h.Write(solution)
+ s.sum = s.h.Sum(s.sum[:0])
+
+ i := binary.BigEndian.Uint32(s.sum[:4])
+ return i < challenge.Target
+}
+
+func (m *manager) CheckSolution(seed, solution []byte) error {
+ c, err := challengeParamsFromSeed(seed, m.params.Secret)
+ if err != nil {
+ return fmt.Errorf("parsing challenge parameters from seed: %w", err)
+
+ } else if c.ExpiresAt <= m.params.Clock.Now().Unix() {
+ return ErrExpiredSolution
+ }
+
+ ok := (SolutionChecker{}).Check(
+ Challenge{Seed: seed, Target: c.Target}, solution,
+ )
+
+ if !ok {
+ return ErrInvalidSolution
+ }
+
+ expiresAt := time.Unix(c.ExpiresAt, 0)
+ if err := m.params.Store.MarkSolved(seed, expiresAt.Add(1*time.Minute)); err != nil {
+ return fmt.Errorf("marking solution as solved: %w", err)
+ }
+
+ return nil
+}
+
+// Solve returns a solution for the given Challenge. This may take a while.
+func Solve(challenge Challenge) []byte {
+
+ chk := SolutionChecker{}
+ b := make([]byte, len(challenge.Seed))
+
+ for {
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ } else if chk.Check(challenge, b) {
+ return b
+ }
+ }
+}
diff --git a/srv/pow/pow_test.go b/srv/pow/pow_test.go
new file mode 100644
index 0000000..4bc4141
--- /dev/null
+++ b/srv/pow/pow_test.go
@@ -0,0 +1,120 @@
+package pow
+
+import (
+ "encoding/hex"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tilinna/clock"
+)
+
+func TestChallengeParams(t *testing.T) {
+ tests := []challengeParams{
+ {},
+ {
+ Target: 1,
+ ExpiresAt: 3,
+ },
+ {
+ Target: 2,
+ ExpiresAt: -10,
+ Random: []byte{0, 1, 2},
+ },
+ {
+ Random: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
+ },
+ }
+
+ t.Run("marshal_unmarshal", func(t *testing.T) {
+ for i, test := range tests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ b, err := test.MarshalBinary()
+ assert.NoError(t, err)
+
+ var c2 challengeParams
+ assert.NoError(t, c2.UnmarshalBinary(b))
+ assert.Equal(t, test, c2)
+
+ b2, err := c2.MarshalBinary()
+ assert.NoError(t, err)
+ assert.Equal(t, b, b2)
+ })
+ }
+ })
+
+ secret := []byte("shhh")
+
+ t.Run("to_from_seed", func(t *testing.T) {
+
+ for i, test := range tests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ seed, err := newSeed(test, secret)
+ assert.NoError(t, err)
+
+ // generating seed should be deterministic
+ seed2, err := newSeed(test, secret)
+ assert.NoError(t, err)
+ assert.Equal(t, seed, seed2)
+
+ c, err := challengeParamsFromSeed(seed, secret)
+ assert.NoError(t, err)
+ assert.Equal(t, test, c)
+ })
+ }
+ })
+
+ t.Run("malformed_seed", func(t *testing.T) {
+ tests := []string{
+ "",
+ "01",
+ "0000",
+ "00374a1ad84d6b7a93e68042c1f850cbb100000000000000000000000000000102030405060708A0", // changed one byte from a good seed
+ }
+
+ for i, test := range tests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ seed, err := hex.DecodeString(test)
+ if err != nil {
+ panic(err)
+ }
+
+ _, err = challengeParamsFromSeed(seed, secret)
+ assert.ErrorIs(t, errMalformedSeed, err)
+ })
+ }
+ })
+}
+
+func TestManager(t *testing.T) {
+ clock := clock.NewMock(time.Now().Truncate(time.Hour))
+
+ store := NewMemoryStore(clock)
+ defer store.Close()
+
+ mgr := NewManager(ManagerParams{
+ Clock: clock,
+ Store: store,
+ Secret: []byte("shhhh"),
+ Target: 0x00FFFFFF,
+ ChallengeTimeout: 1 * time.Second,
+ })
+
+ {
+ c := mgr.NewChallenge()
+ solution := Solve(c)
+ assert.NoError(t, mgr.CheckSolution(c.Seed, solution))
+
+ // doing again should fail, the seed should already be marked as solved
+ assert.ErrorIs(t, mgr.CheckSolution(c.Seed, solution), ErrSeedSolved)
+ }
+
+ {
+ c := mgr.NewChallenge()
+ solution := Solve(c)
+ clock.Add(2 * time.Second)
+ assert.ErrorIs(t, mgr.CheckSolution(c.Seed, solution), ErrExpiredSolution)
+ }
+
+}
diff --git a/srv/pow/store.go b/srv/pow/store.go
new file mode 100644
index 0000000..0b5e7d0
--- /dev/null
+++ b/srv/pow/store.go
@@ -0,0 +1,92 @@
+package pow
+
+import (
+ "errors"
+ "sync"
+ "time"
+
+ "github.com/tilinna/clock"
+)
+
+// ErrSeedSolved is used to indicate a seed has already been solved.
+var ErrSeedSolved = errors.New("seed already solved")
+
+// Store is used to track information related to proof-of-work challenges and
+// solutions.
+type Store interface {
+
+ // MarkSolved will return ErrSeedSolved if the seed was already marked. The
+ // seed will be cleared from the Store once expiresAt is reached.
+ MarkSolved(seed []byte, expiresAt time.Time) error
+
+ Close() error
+}
+
+type inMemStore struct {
+ clock clock.Clock
+
+ m map[string]time.Time
+ l sync.Mutex
+ closeCh chan struct{}
+ spinLoopCh chan struct{} // only used by tests
+}
+
+const inMemStoreGCPeriod = 5 * time.Second
+
+// NewMemoryStore initializes and returns an in-memory Store implementation.
+func NewMemoryStore(clock clock.Clock) Store {
+ s := &inMemStore{
+ clock: clock,
+ m: map[string]time.Time{},
+ closeCh: make(chan struct{}),
+ spinLoopCh: make(chan struct{}, 1),
+ }
+ go s.spin(s.clock.NewTicker(inMemStoreGCPeriod))
+ return s
+}
+
+func (s *inMemStore) spin(ticker *clock.Ticker) {
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ now := s.clock.Now()
+
+ s.l.Lock()
+ for seed, expiresAt := range s.m {
+ if !now.Before(expiresAt) {
+ delete(s.m, seed)
+ }
+ }
+ s.l.Unlock()
+
+ case <-s.closeCh:
+ return
+ }
+
+ select {
+ case s.spinLoopCh <- struct{}{}:
+ default:
+ }
+ }
+}
+
+func (s *inMemStore) MarkSolved(seed []byte, expiresAt time.Time) error {
+ seedStr := string(seed)
+
+ s.l.Lock()
+ defer s.l.Unlock()
+
+ if _, ok := s.m[seedStr]; ok {
+ return ErrSeedSolved
+ }
+
+ s.m[seedStr] = expiresAt
+ return nil
+}
+
+func (s *inMemStore) Close() error {
+ close(s.closeCh)
+ return nil
+}
diff --git a/srv/pow/store_test.go b/srv/pow/store_test.go
new file mode 100644
index 0000000..324a40c
--- /dev/null
+++ b/srv/pow/store_test.go
@@ -0,0 +1,52 @@
+package pow
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tilinna/clock"
+)
+
+func TestStore(t *testing.T) {
+ clock := clock.NewMock(time.Now().Truncate(time.Hour))
+ now := clock.Now()
+
+ s := NewMemoryStore(clock)
+ defer s.Close()
+
+ seed := []byte{0}
+
+ // mark solved should work
+ err := s.MarkSolved(seed, now.Add(time.Second))
+ assert.NoError(t, err)
+
+ // mark again, should not work
+ err = s.MarkSolved(seed, now.Add(time.Hour))
+ assert.ErrorIs(t, err, ErrSeedSolved)
+
+ // marking a different seed should still work
+ seed2 := []byte{1}
+ err = s.MarkSolved(seed2, now.Add(inMemStoreGCPeriod*2))
+ assert.NoError(t, err)
+ err = s.MarkSolved(seed2, now.Add(time.Hour))
+ assert.ErrorIs(t, err, ErrSeedSolved)
+
+ now = clock.Add(inMemStoreGCPeriod)
+ <-s.(*inMemStore).spinLoopCh
+
+ // first one should be markable again, second shouldnt
+ err = s.MarkSolved(seed, now.Add(time.Second))
+ assert.NoError(t, err)
+ err = s.MarkSolved(seed2, now.Add(time.Hour))
+ assert.ErrorIs(t, err, ErrSeedSolved)
+
+ now = clock.Add(inMemStoreGCPeriod)
+ <-s.(*inMemStore).spinLoopCh
+
+ // now both should be expired
+ err = s.MarkSolved(seed, now.Add(time.Second))
+ assert.NoError(t, err)
+ err = s.MarkSolved(seed2, now.Add(time.Second))
+ assert.NoError(t, err)
+}