diff options
Diffstat (limited to 'srv/chat')
-rw-r--r-- | srv/chat/chat.go | 467 | ||||
-rw-r--r-- | srv/chat/chat_test.go | 200 | ||||
-rw-r--r-- | srv/chat/user.go | 68 | ||||
-rw-r--r-- | srv/chat/user_test.go | 26 | ||||
-rw-r--r-- | srv/chat/util.go | 28 |
5 files changed, 0 insertions, 789 deletions
diff --git a/srv/chat/chat.go b/srv/chat/chat.go deleted file mode 100644 index 0a88d3b..0000000 --- a/srv/chat/chat.go +++ /dev/null @@ -1,467 +0,0 @@ -// Package chat implements a simple chatroom system. -package chat - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "strconv" - "sync" - "time" - - "github.com/mediocregopher/mediocre-go-lib/v2/mctx" - "github.com/mediocregopher/mediocre-go-lib/v2/mlog" - "github.com/mediocregopher/radix/v4" -) - -// ErrInvalidArg is returned from methods in this package when a call fails due -// to invalid input. -type ErrInvalidArg struct { - Err error -} - -func (e ErrInvalidArg) Error() string { - return fmt.Sprintf("invalid argument: %v", e.Err) -} - -var ( - errInvalidMessageID = ErrInvalidArg{Err: errors.New("invalid Message ID")} -) - -// Message describes a message which has been posted to a Room. -type Message struct { - ID string `json:"id"` - UserID UserID `json:"userID"` - Body string `json:"body"` - CreatedAt int64 `json:"createdAt,omitempty"` -} - -func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) { - - // NOTE this should probably be a shortcut in radix - var bodyStr string - for _, field := range entry.Fields { - if field[0] == "json" { - bodyStr = field[1] - break - } - } - - if bodyStr == "" { - return Message{}, errors.New("no 'json' field") - } - - var msg Message - if err := json.Unmarshal([]byte(bodyStr), &msg); err != nil { - return Message{}, fmt.Errorf( - "json unmarshaling body %q: %w", bodyStr, err, - ) - } - - msg.ID = entry.ID.String() - msg.CreatedAt = int64(entry.ID.Time / 1000) - return msg, nil -} - -// MessageIterator returns a sequence of Messages which may or may not be -// unbounded. -type MessageIterator interface { - - // Next blocks until it returns the next Message in the sequence, or the - // context error if the context is cancelled, or io.EOF if the sequence has - // been exhausted. - Next(context.Context) (Message, error) - - // Close should always be called once Next has returned an error or the - // MessageIterator will no longer be used. - Close() error -} - -// HistoryOpts are passed into Room's History method in order to affect its -// result. All fields are optional. -type HistoryOpts struct { - Limit int // defaults to, and is capped at, 100. - Cursor string // If not given then the most recent Messages are returned. -} - -func (o HistoryOpts) sanitize() (HistoryOpts, error) { - if o.Limit <= 0 || o.Limit > 100 { - o.Limit = 100 - } - - if o.Cursor != "" { - id, err := parseStreamEntryID(o.Cursor) - if err != nil { - return HistoryOpts{}, fmt.Errorf("parsing Cursor: %w", err) - } - o.Cursor = id.String() - } - - return o, nil -} - -// Room implements functionality related to a single, unique chat room. -type Room interface { - - // Append accepts a new Message and stores it at the end of the room's - // history. The original Message is returned with any relevant fields (e.g. - // ID) updated. - Append(context.Context, Message) (Message, error) - - // Returns a cursor and the list of historical Messages in time descending - // order. The cursor can be passed into the next call to History to receive - // the next set of Messages. - History(context.Context, HistoryOpts) (string, []Message, error) - - // Listen returns a MessageIterator which will return all Messages appended - // to the Room since the given ID. Once all existing messages are iterated - // through then the MessageIterator will begin blocking until a new Message - // is posted. - Listen(ctx context.Context, sinceID string) (MessageIterator, error) - - // Delete deletes a Message from the Room. - Delete(ctx context.Context, id string) error - - // Close is used to clean up all resources created by the Room. - Close() error -} - -// RoomParams are used to instantiate a new Room. All fields are required unless -// otherwise noted. -type RoomParams struct { - Logger *mlog.Logger - Redis radix.Client - ID string - MaxMessages int -} - -func (p RoomParams) streamKey() string { - return fmt.Sprintf("chat:{%s}:stream", p.ID) -} - -type room struct { - params RoomParams - - closeCtx context.Context - closeCancel context.CancelFunc - wg sync.WaitGroup - - listeningL sync.Mutex - listening map[chan Message]struct{} - listeningLastID radix.StreamEntryID -} - -// NewRoom initializes and returns a new Room instance. -func NewRoom(ctx context.Context, params RoomParams) (Room, error) { - - params.Logger = params.Logger.WithNamespace("chat-room") - - r := &room{ - params: params, - listening: map[chan Message]struct{}{}, - } - - r.closeCtx, r.closeCancel = context.WithCancel(context.Background()) - - // figure out the most recent message, if any. - lastEntryID, err := r.mostRecentMsgID(ctx) - if err != nil { - return nil, fmt.Errorf("discovering most recent entry ID in stream: %w", err) - } - r.listeningLastID = lastEntryID - - r.wg.Add(1) - go func() { - defer r.wg.Done() - r.readStreamLoop(r.closeCtx) - }() - - return r, nil -} - -func (r *room) Close() error { - r.closeCancel() - r.wg.Wait() - return nil -} - -func (r *room) mostRecentMsgID(ctx context.Context) (radix.StreamEntryID, error) { - - var entries []radix.StreamEntry - err := r.params.Redis.Do(ctx, radix.Cmd( - &entries, - "XREVRANGE", r.params.streamKey(), "+", "-", "COUNT", "1", - )) - - if err != nil || len(entries) == 0 { - return radix.StreamEntryID{}, err - } - - return entries[0].ID, nil -} - -func (r *room) Append(ctx context.Context, msg Message) (Message, error) { - msg.ID = "" // just in case - - b, err := json.Marshal(msg) - if err != nil { - return Message{}, fmt.Errorf("json marshaling Message: %w", err) - } - - key := r.params.streamKey() - maxLen := strconv.Itoa(r.params.MaxMessages) - body := string(b) - - var id radix.StreamEntryID - - err = r.params.Redis.Do(ctx, radix.Cmd( - &id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body, - )) - - if err != nil { - return Message{}, fmt.Errorf("posting message to redis: %w", err) - } - - msg.ID = id.String() - msg.CreatedAt = int64(id.Time / 1000) - return msg, nil -} - -const zeroCursor = "0-0" - -func (r *room) History(ctx context.Context, opts HistoryOpts) (string, []Message, error) { - opts, err := opts.sanitize() - if err != nil { - return "", nil, err - } - - key := r.params.streamKey() - end := opts.Cursor - if end == "" { - end = "+" - } - start := "-" - count := strconv.Itoa(opts.Limit) - - msgs := make([]Message, 0, opts.Limit) - streamEntries := make([]radix.StreamEntry, 0, opts.Limit) - - err = r.params.Redis.Do(ctx, radix.Cmd( - &streamEntries, - "XREVRANGE", key, end, start, "COUNT", count, - )) - - if err != nil { - return "", nil, fmt.Errorf("calling XREVRANGE: %w", err) - } - - var oldestEntryID radix.StreamEntryID - - for _, entry := range streamEntries { - oldestEntryID = entry.ID - - msg, err := msgFromStreamEntry(entry) - if err != nil { - return "", nil, fmt.Errorf( - "parsing stream entry %q: %w", entry.ID, err, - ) - } - msgs = append(msgs, msg) - } - - if len(msgs) < opts.Limit { - return zeroCursor, msgs, nil - } - - cursor := oldestEntryID.Prev() - return cursor.String(), msgs, nil -} - -func (r *room) readStream(ctx context.Context) error { - - r.listeningL.Lock() - lastEntryID := r.listeningLastID - r.listeningL.Unlock() - - redisAddr := r.params.Redis.Addr() - redisConn, err := radix.Dial(ctx, redisAddr.Network(), redisAddr.String()) - if err != nil { - return fmt.Errorf("creating redis connection: %w", err) - } - defer redisConn.Close() - - streamReader := (radix.StreamReaderConfig{}).New( - redisConn, - map[string]radix.StreamConfig{ - r.params.streamKey(): {After: lastEntryID}, - }, - ) - - for { - dlCtx, dlCtxCancel := context.WithTimeout(ctx, 10*time.Second) - _, streamEntry, err := streamReader.Next(dlCtx) - dlCtxCancel() - - if errors.Is(err, radix.ErrNoStreamEntries) { - continue - } else if err != nil { - return fmt.Errorf("fetching next entry from stream: %w", err) - } - - msg, err := msgFromStreamEntry(streamEntry) - if err != nil { - return fmt.Errorf("parsing stream entry %q: %w", streamEntry, err) - } - - r.listeningL.Lock() - - var dropped int - for ch := range r.listening { - select { - case ch <- msg: - default: - dropped++ - } - } - - if dropped > 0 { - ctx := mctx.Annotate(ctx, "msgID", msg.ID, "dropped", dropped) - r.params.Logger.WarnString(ctx, "some listening channels full, messages dropped") - } - - r.listeningLastID = streamEntry.ID - - r.listeningL.Unlock() - } -} - -func (r *room) readStreamLoop(ctx context.Context) { - for { - err := r.readStream(ctx) - if errors.Is(err, context.Canceled) { - return - } else if err != nil { - r.params.Logger.Error(ctx, "reading from redis stream", err) - } - } -} - -type listenMsgIterator struct { - ch <-chan Message - missedMsgs []Message - sinceEntryID radix.StreamEntryID - cleanup func() -} - -func (i *listenMsgIterator) Next(ctx context.Context) (Message, error) { - - if len(i.missedMsgs) > 0 { - msg := i.missedMsgs[0] - i.missedMsgs = i.missedMsgs[1:] - return msg, nil - } - - for { - select { - case <-ctx.Done(): - return Message{}, ctx.Err() - case msg := <-i.ch: - - entryID, err := parseStreamEntryID(msg.ID) - if err != nil { - return Message{}, fmt.Errorf("parsing Message ID %q: %w", msg.ID, err) - - } else if !i.sinceEntryID.Before(entryID) { - // this can happen if someone Appends a Message at the same time - // as another calls Listen. The Listener might have already seen - // the Message by calling History prior to the stream reader - // having processed it and updating listeningLastID. - continue - } - - return msg, nil - } - } -} - -func (i *listenMsgIterator) Close() error { - i.cleanup() - return nil -} - -func (r *room) Listen( - ctx context.Context, sinceID string, -) ( - MessageIterator, error, -) { - - var sinceEntryID radix.StreamEntryID - - if sinceID != "" { - var err error - if sinceEntryID, err = parseStreamEntryID(sinceID); err != nil { - return nil, fmt.Errorf("parsing sinceID: %w", err) - } - } - - ch := make(chan Message, 32) - - r.listeningL.Lock() - lastEntryID := r.listeningLastID - r.listening[ch] = struct{}{} - r.listeningL.Unlock() - - cleanup := func() { - r.listeningL.Lock() - defer r.listeningL.Unlock() - delete(r.listening, ch) - } - - key := r.params.streamKey() - start := sinceEntryID.Next().String() - end := "+" - if lastEntryID != (radix.StreamEntryID{}) { - end = lastEntryID.String() - } - - var streamEntries []radix.StreamEntry - - err := r.params.Redis.Do(ctx, radix.Cmd( - &streamEntries, - "XRANGE", key, start, end, - )) - - if err != nil { - cleanup() - return nil, fmt.Errorf("retrieving missed stream entries: %w", err) - } - - missedMsgs := make([]Message, len(streamEntries)) - - for i := range streamEntries { - - msg, err := msgFromStreamEntry(streamEntries[i]) - if err != nil { - cleanup() - return nil, fmt.Errorf( - "parsing stream entry %q: %w", streamEntries[i].ID, err, - ) - } - - missedMsgs[i] = msg - } - - return &listenMsgIterator{ - ch: ch, - missedMsgs: missedMsgs, - sinceEntryID: sinceEntryID, - cleanup: cleanup, - }, nil -} - -func (r *room) Delete(ctx context.Context, id string) error { - return r.params.Redis.Do(ctx, radix.Cmd( - nil, "XDEL", r.params.streamKey(), id, - )) -} diff --git a/srv/chat/chat_test.go b/srv/chat/chat_test.go deleted file mode 100644 index d37921c..0000000 --- a/srv/chat/chat_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package chat - -import ( - "context" - "strconv" - "testing" - "time" - - "github.com/google/uuid" - "github.com/mediocregopher/mediocre-go-lib/v2/mlog" - "github.com/mediocregopher/radix/v4" - "github.com/stretchr/testify/assert" -) - -const roomTestHarnessMaxMsgs = 10 - -type roomTestHarness struct { - ctx context.Context - room Room - allMsgs []Message -} - -func (h *roomTestHarness) newMsg(t *testing.T) Message { - msg, err := h.room.Append(h.ctx, Message{ - UserID: UserID{ - Name: uuid.New().String(), - Hash: "0000", - }, - Body: uuid.New().String(), - }) - assert.NoError(t, err) - - t.Logf("appended message %s", msg.ID) - - h.allMsgs = append([]Message{msg}, h.allMsgs...) - - if len(h.allMsgs) > roomTestHarnessMaxMsgs { - h.allMsgs = h.allMsgs[:roomTestHarnessMaxMsgs] - } - - return msg -} - -func newRoomTestHarness(t *testing.T) *roomTestHarness { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(cancel) - - redis, err := radix.Dial(ctx, "tcp", "127.0.0.1:6379") - assert.NoError(t, err) - t.Cleanup(func() { redis.Close() }) - - roomParams := RoomParams{ - Logger: mlog.NewLogger(nil), - Redis: redis, - ID: uuid.New().String(), - MaxMessages: roomTestHarnessMaxMsgs, - } - - t.Logf("creating test Room %q", roomParams.ID) - room, err := NewRoom(ctx, roomParams) - assert.NoError(t, err) - - t.Cleanup(func() { - err := redis.Do(context.Background(), radix.Cmd( - nil, "DEL", roomParams.streamKey(), - )) - assert.NoError(t, err) - }) - - return &roomTestHarness{ctx: ctx, room: room} -} - -func TestRoom(t *testing.T) { - t.Run("history", func(t *testing.T) { - - tests := []struct { - numMsgs int - limit int - }{ - {numMsgs: 0, limit: 1}, - {numMsgs: 1, limit: 1}, - {numMsgs: 2, limit: 1}, - {numMsgs: 2, limit: 10}, - {numMsgs: 9, limit: 2}, - {numMsgs: 9, limit: 3}, - {numMsgs: 9, limit: 4}, - {numMsgs: 15, limit: 3}, - } - - for i, test := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - t.Logf("test: %+v", test) - - h := newRoomTestHarness(t) - - for j := 0; j < test.numMsgs; j++ { - h.newMsg(t) - } - - var gotMsgs []Message - var cursor string - - for { - - var msgs []Message - var err error - cursor, msgs, err = h.room.History(h.ctx, HistoryOpts{ - Cursor: cursor, - Limit: test.limit, - }) - - assert.NoError(t, err) - assert.NotEmpty(t, cursor) - - if len(msgs) == 0 { - break - } - - gotMsgs = append(gotMsgs, msgs...) - } - - assert.Equal(t, h.allMsgs, gotMsgs) - }) - } - }) - - assertNextMsg := func( - t *testing.T, expMsg Message, - ctx context.Context, it MessageIterator, - ) { - t.Helper() - gotMsg, err := it.Next(ctx) - assert.NoError(t, err) - assert.Equal(t, expMsg, gotMsg) - } - - t.Run("listen/already_populated", func(t *testing.T) { - h := newRoomTestHarness(t) - - msgA, msgB, msgC := h.newMsg(t), h.newMsg(t), h.newMsg(t) - _ = msgA - _ = msgB - - itFoo, err := h.room.Listen(h.ctx, msgC.ID) - assert.NoError(t, err) - defer itFoo.Close() - - itBar, err := h.room.Listen(h.ctx, msgA.ID) - assert.NoError(t, err) - defer itBar.Close() - - msgD := h.newMsg(t) - - // itBar should get msgB and msgC before anything else. - assertNextMsg(t, msgB, h.ctx, itBar) - assertNextMsg(t, msgC, h.ctx, itBar) - - // now both iterators should give msgD - assertNextMsg(t, msgD, h.ctx, itFoo) - assertNextMsg(t, msgD, h.ctx, itBar) - - // timeout should be honored - { - timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second) - _, errFoo := itFoo.Next(timeoutCtx) - _, errBar := itBar.Next(timeoutCtx) - timeoutCancel() - - assert.ErrorIs(t, errFoo, context.DeadlineExceeded) - assert.ErrorIs(t, errBar, context.DeadlineExceeded) - } - - // new message should work - { - expMsg := h.newMsg(t) - - timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second) - gotFooMsg, errFoo := itFoo.Next(timeoutCtx) - gotBarMsg, errBar := itBar.Next(timeoutCtx) - timeoutCancel() - - assert.Equal(t, expMsg, gotFooMsg) - assert.NoError(t, errFoo) - assert.Equal(t, expMsg, gotBarMsg) - assert.NoError(t, errBar) - } - - }) - - t.Run("listen/empty", func(t *testing.T) { - h := newRoomTestHarness(t) - - it, err := h.room.Listen(h.ctx, "") - assert.NoError(t, err) - defer it.Close() - - msg := h.newMsg(t) - assertNextMsg(t, msg, h.ctx, it) - }) -} diff --git a/srv/chat/user.go b/srv/chat/user.go deleted file mode 100644 index 3f5ab95..0000000 --- a/srv/chat/user.go +++ /dev/null @@ -1,68 +0,0 @@ -package chat - -import ( - "encoding/hex" - "fmt" - "sync" - - "golang.org/x/crypto/argon2" -) - -// UserID uniquely identifies an individual user who has posted a message in a -// Room. -type UserID struct { - - // Name will be the user's chosen display name. - Name string `json:"name"` - - // Hash will be a hex string generated from a secret only the user knows. - Hash string `json:"id"` -} - -// UserIDCalculator is used to calculate UserIDs. -type UserIDCalculator struct { - - // Secret is used when calculating UserID Hash salts. - Secret []byte - - // TimeCost, MemoryCost, and Threads are used as inputs to the Argon2id - // algorithm which is used to generate the Hash. - TimeCost, MemoryCost uint32 - Threads uint8 - - // HashLen specifies the number of bytes the Hash should be. - HashLen uint32 - - // Lock, if set, forces concurrent Calculate calls to occur sequentially. - Lock *sync.Mutex -} - -// NewUserIDCalculator returns a UserIDCalculator with sane defaults. -func NewUserIDCalculator(secret []byte) *UserIDCalculator { - return &UserIDCalculator{ - Secret: secret, - TimeCost: 15, - MemoryCost: 128 * 1024, - Threads: 2, - HashLen: 16, - Lock: new(sync.Mutex), - } -} - -// Calculate accepts a name and password and returns the calculated UserID. -func (c *UserIDCalculator) Calculate(name, password string) UserID { - - input := fmt.Sprintf("%q:%q", name, password) - - hashB := argon2.IDKey( - []byte(input), - c.Secret, // salt - c.TimeCost, c.MemoryCost, c.Threads, - c.HashLen, - ) - - return UserID{ - Name: name, - Hash: hex.EncodeToString(hashB), - } -} diff --git a/srv/chat/user_test.go b/srv/chat/user_test.go deleted file mode 100644 index 2169cde..0000000 --- a/srv/chat/user_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package chat - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestUserIDCalculator(t *testing.T) { - - const name, password = "name", "password" - - c := NewUserIDCalculator([]byte("foo")) - - // calculating with same params twice should result in same UserID - userID := c.Calculate(name, password) - assert.Equal(t, userID, c.Calculate(name, password)) - - // changing either name or password should result in a different Hash - assert.NotEqual(t, userID.Hash, c.Calculate(name+"!", password).Hash) - assert.NotEqual(t, userID.Hash, c.Calculate(name, password+"!").Hash) - - // changing the secret should change the UserID - c.Secret = []byte("bar") - assert.NotEqual(t, userID, c.Calculate(name, password)) -} diff --git a/srv/chat/util.go b/srv/chat/util.go deleted file mode 100644 index 05f4830..0000000 --- a/srv/chat/util.go +++ /dev/null @@ -1,28 +0,0 @@ -package chat - -import ( - "strconv" - "strings" - - "github.com/mediocregopher/radix/v4" -) - -func parseStreamEntryID(str string) (radix.StreamEntryID, error) { - - split := strings.SplitN(str, "-", 2) - if len(split) != 2 { - return radix.StreamEntryID{}, errInvalidMessageID - } - - time, err := strconv.ParseUint(split[0], 10, 64) - if err != nil { - return radix.StreamEntryID{}, errInvalidMessageID - } - - seq, err := strconv.ParseUint(split[1], 10, 64) - if err != nil { - return radix.StreamEntryID{}, errInvalidMessageID - } - - return radix.StreamEntryID{Time: time, Seq: seq}, nil -} |