diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2021-08-16 21:53:02 -0600 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2021-08-17 14:35:07 -0600 |
commit | eaccf41563a5696996c0c75ceff1f270e88fc207 (patch) | |
tree | 4ac478cb68ac205298793bd1afc98a759193d4f6 /srv/chat | |
parent | ac5275353c0d0f33ebe47c3e177d0b35b7bd6581 (diff) |
MVP of chat package
Diffstat (limited to 'srv/chat')
-rw-r--r-- | srv/chat/chat.go | 475 | ||||
-rw-r--r-- | srv/chat/chat_test.go | 200 | ||||
-rw-r--r-- | srv/chat/util.go | 28 |
3 files changed, 703 insertions, 0 deletions
diff --git a/srv/chat/chat.go b/srv/chat/chat.go new file mode 100644 index 0000000..44449cd --- /dev/null +++ b/srv/chat/chat.go @@ -0,0 +1,475 @@ +// Package chat implements a simple chatroom system. +package chat + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "sync" + "time" + + "github.com/mediocregopher/mediocre-go-lib/v2/mctx" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" + "github.com/mediocregopher/radix/v4" +) + +// ErrInvalidArg is returned from methods in this package when a call fails due +// to invalid input. +type ErrInvalidArg struct { + Err error +} + +func (e ErrInvalidArg) Error() string { + return fmt.Sprintf("invalid argument: %v", e.Err) +} + +var ( + errInvalidMessageID = ErrInvalidArg{Err: errors.New("invalid Message ID")} +) + +// UserID uniquely identifies an individual user who has posted a message in a +// Room. +type UserID struct { + + // Name will be the user's chosen display name. + Name string `json:"name"` + + // Hash will be a hex string generated from a secret only the user knows. + Hash string `json:"id"` +} + +// Message describes a message which has been posted to a Room. +type Message struct { + ID string `json:"id"` + UserID UserID `json:"userID"` + Body string `json:"body"` +} + +func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) { + + // NOTE this should probably be a shortcut in radix + var bodyStr string + for _, field := range entry.Fields { + if field[0] == "json" { + bodyStr = field[1] + break + } + } + + if bodyStr == "" { + return Message{}, errors.New("no 'json' field") + } + + var msg Message + if err := json.Unmarshal([]byte(bodyStr), &msg); err != nil { + return Message{}, fmt.Errorf( + "json unmarshaling body %q: %w", bodyStr, err, + ) + } + + msg.ID = entry.ID.String() + return msg, nil +} + +// MessageIterator returns a sequence of Messages which may or may not be +// unbounded. +type MessageIterator interface { + + // Next blocks until it returns the next Message in the sequence, or the + // context error if the context is cancelled, or io.EOF if the sequence has + // been exhausted. + Next(context.Context) (Message, error) + + // Close should always be called once Next has returned an error or the + // MessageIterator will no longer be used. + Close() error +} + +// HistoryOpts are passed into Room's History method in order to affect its +// result. All fields are optional. +type HistoryOpts struct { + Limit int // defaults to, and is capped at, 100. + Cursor string // If not given then the most recent Messages are returned. +} + +func (o HistoryOpts) sanitize() (HistoryOpts, error) { + if o.Limit < 0 || o.Limit > 100 { + o.Limit = 100 + } + + if o.Cursor != "" { + id, err := parseStreamEntryID(o.Cursor) + if err != nil { + return HistoryOpts{}, fmt.Errorf("parsing Cursor: %w", err) + } + o.Cursor = id.String() + } + + return o, nil +} + +// Room implements functionality related to a single, unique chat room. +type Room interface { + + // Append accepts a new Message and stores it at the end of the room's + // history. The original Message is returned with any relevant fields (e.g. + // ID) updated. + Append(context.Context, Message) (Message, error) + + // Returns a cursor and the list of historical Messages in time descending + // order. The cursor can be passed into the next call to History to receive + // the next set of Messages. + History(context.Context, HistoryOpts) (string, []Message, error) + + // Listen returns a MessageIterator which will return all Messages appended + // to the Room since the given ID. Once all existing messages are iterated + // through then the MessageIterator will begin blocking until a new Message + // is posted. + Listen(ctx context.Context, sinceID string) (MessageIterator, error) + + // Delete deletes a Message from the Room. + Delete(ctx context.Context, id string) error + + // Close is used to clean up all resources created by the Room. + Close() error +} + +// RoomParams are used to instantiate a new Room. All fields are required unless +// otherwise noted. +type RoomParams struct { + Logger *mlog.Logger + Redis radix.Client + ID string + MaxMessages int +} + +func (p RoomParams) streamKey() string { + return fmt.Sprintf("chat:{%s}:stream", p.ID) +} + +type room struct { + params RoomParams + + closeCtx context.Context + closeCancel context.CancelFunc + wg sync.WaitGroup + + listeningL sync.Mutex + listening map[chan Message]struct{} + listeningLastID radix.StreamEntryID +} + +// NewRoom initializes and returns a new Room instance. +func NewRoom(ctx context.Context, params RoomParams) (Room, error) { + + params.Logger = params.Logger.WithNamespace("chat-room") + + r := &room{ + params: params, + listening: map[chan Message]struct{}{}, + } + + r.closeCtx, r.closeCancel = context.WithCancel(context.Background()) + + // figure out the most recent message, if any. + lastEntryID, err := r.mostRecentMsgID(ctx) + if err != nil { + return nil, fmt.Errorf("discovering most recent entry ID in stream: %w", err) + } + r.listeningLastID = lastEntryID + + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.readStreamLoop(r.closeCtx) + }() + + return r, nil +} + +func (r *room) Close() error { + r.closeCancel() + r.wg.Wait() + return nil +} + +func (r *room) mostRecentMsgID(ctx context.Context) (radix.StreamEntryID, error) { + + var entries []radix.StreamEntry + err := r.params.Redis.Do(ctx, radix.Cmd( + &entries, + "XREVRANGE", r.params.streamKey(), "+", "-", "COUNT", "1", + )) + + if err != nil || len(entries) == 0 { + return radix.StreamEntryID{}, err + } + + return entries[0].ID, nil +} + +func (r *room) Append(ctx context.Context, msg Message) (Message, error) { + msg.ID = "" // just in case + + b, err := json.Marshal(msg) + if err != nil { + return Message{}, fmt.Errorf("json marshaling Message: %w", err) + } + + key := r.params.streamKey() + maxLen := strconv.Itoa(r.params.MaxMessages) + body := string(b) + + var id string + + err = r.params.Redis.Do(ctx, radix.Cmd( + &id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body, + )) + + if err != nil { + return Message{}, fmt.Errorf("posting message to redis: %w", err) + } + + msg.ID = id + return msg, nil +} + +const zeroCursor = "0-0" + +func (r *room) History(ctx context.Context, opts HistoryOpts) (string, []Message, error) { + opts, err := opts.sanitize() + if err != nil { + return "", nil, err + } + + key := r.params.streamKey() + end := opts.Cursor + if end == "" { + end = "+" + } + start := "-" + count := strconv.Itoa(opts.Limit) + + msgs := make([]Message, 0, opts.Limit) + streamEntries := make([]radix.StreamEntry, 0, opts.Limit) + + err = r.params.Redis.Do(ctx, radix.Cmd( + &streamEntries, + "XREVRANGE", key, end, start, "COUNT", count, + )) + + if err != nil { + return "", nil, fmt.Errorf("calling XREVRANGE: %w", err) + } + + var oldestEntryID radix.StreamEntryID + + for _, entry := range streamEntries { + oldestEntryID = entry.ID + + msg, err := msgFromStreamEntry(entry) + if err != nil { + return "", nil, fmt.Errorf( + "parsing stream entry %q: %w", entry.ID, err, + ) + } + msgs = append(msgs, msg) + } + + if len(msgs) < opts.Limit { + return zeroCursor, msgs, nil + } + + cursor := oldestEntryID.Prev() + return cursor.String(), msgs, nil +} + +func (r *room) readStream(ctx context.Context) error { + + r.listeningL.Lock() + lastEntryID := r.listeningLastID + r.listeningL.Unlock() + + redisAddr := r.params.Redis.Addr() + redisConn, err := radix.Dial(ctx, redisAddr.Network(), redisAddr.String()) + if err != nil { + return fmt.Errorf("creating redis connection: %w", err) + } + defer redisConn.Close() + + streamReader := (radix.StreamReaderConfig{}).New( + redisConn, + map[string]radix.StreamConfig{ + r.params.streamKey(): {After: lastEntryID}, + }, + ) + + for { + dlCtx, dlCtxCancel := context.WithTimeout(ctx, 10*time.Second) + _, streamEntry, err := streamReader.Next(dlCtx) + dlCtxCancel() + + if errors.Is(err, radix.ErrNoStreamEntries) { + continue + } else if err != nil { + return fmt.Errorf("fetching next entry from stream: %w", err) + } + + msg, err := msgFromStreamEntry(streamEntry) + if err != nil { + return fmt.Errorf("parsing stream entry %q: %w", streamEntry, err) + } + + r.listeningL.Lock() + + var dropped int + for ch := range r.listening { + select { + case ch <- msg: + default: + dropped++ + } + } + + if dropped > 0 { + ctx := mctx.Annotate(ctx, "msgID", msg.ID, "dropped", dropped) + r.params.Logger.WarnString(ctx, "some listening channels full, messages dropped") + } + + r.listeningLastID = streamEntry.ID + + r.listeningL.Unlock() + } +} + +func (r *room) readStreamLoop(ctx context.Context) { + for { + err := r.readStream(ctx) + if errors.Is(err, context.Canceled) { + return + } else if err != nil { + r.params.Logger.Error(ctx, "reading from redis stream", err) + } + } +} + +type listenMsgIterator struct { + ch <-chan Message + missedMsgs []Message + sinceEntryID radix.StreamEntryID + cleanup func() +} + +func (i *listenMsgIterator) Next(ctx context.Context) (Message, error) { + + if len(i.missedMsgs) > 0 { + msg := i.missedMsgs[0] + i.missedMsgs = i.missedMsgs[1:] + return msg, nil + } + + for { + select { + case <-ctx.Done(): + return Message{}, ctx.Err() + case msg := <-i.ch: + + entryID, err := parseStreamEntryID(msg.ID) + if err != nil { + return Message{}, fmt.Errorf("parsing Message ID %q: %w", msg.ID, err) + + } else if !i.sinceEntryID.Before(entryID) { + // this can happen if someone Appends a Message at the same time + // as another calls Listen. The Listener might have already seen + // the Message by calling History prior to the stream reader + // having processed it and updating listeningLastID. + continue + } + + return msg, nil + } + } +} + +func (i *listenMsgIterator) Close() error { + i.cleanup() + return nil +} + +func (r *room) Listen( + ctx context.Context, sinceID string, +) ( + MessageIterator, error, +) { + + var sinceEntryID radix.StreamEntryID + + if sinceID != "" { + var err error + if sinceEntryID, err = parseStreamEntryID(sinceID); err != nil { + return nil, fmt.Errorf("parsing sinceID: %w", err) + } + } + + ch := make(chan Message, 32) + + r.listeningL.Lock() + lastEntryID := r.listeningLastID + r.listening[ch] = struct{}{} + r.listeningL.Unlock() + + cleanup := func() { + r.listeningL.Lock() + defer r.listeningL.Unlock() + delete(r.listening, ch) + } + + key := r.params.streamKey() + start := sinceEntryID.Next().String() + end := "+" + if lastEntryID != (radix.StreamEntryID{}) { + end = lastEntryID.String() + } + + var streamEntries []radix.StreamEntry + + err := r.params.Redis.Do(ctx, radix.Cmd( + &streamEntries, + "XRANGE", key, start, end, + )) + + if err != nil { + cleanup() + return nil, fmt.Errorf("retrieving missed stream entries: %w", err) + } + + missedMsgs := make([]Message, len(streamEntries)) + + for i := range streamEntries { + + msg, err := msgFromStreamEntry(streamEntries[i]) + if err != nil { + cleanup() + return nil, fmt.Errorf( + "parsing stream entry %q: %w", streamEntries[i].ID, err, + ) + } + + missedMsgs[i] = msg + } + + return &listenMsgIterator{ + ch: ch, + missedMsgs: missedMsgs, + sinceEntryID: sinceEntryID, + cleanup: cleanup, + }, nil +} + +func (r *room) Delete(ctx context.Context, id string) error { + return r.params.Redis.Do(ctx, radix.Cmd( + nil, "XDEL", r.params.streamKey(), id, + )) +} diff --git a/srv/chat/chat_test.go b/srv/chat/chat_test.go new file mode 100644 index 0000000..d37921c --- /dev/null +++ b/srv/chat/chat_test.go @@ -0,0 +1,200 @@ +package chat + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/google/uuid" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" + "github.com/mediocregopher/radix/v4" + "github.com/stretchr/testify/assert" +) + +const roomTestHarnessMaxMsgs = 10 + +type roomTestHarness struct { + ctx context.Context + room Room + allMsgs []Message +} + +func (h *roomTestHarness) newMsg(t *testing.T) Message { + msg, err := h.room.Append(h.ctx, Message{ + UserID: UserID{ + Name: uuid.New().String(), + Hash: "0000", + }, + Body: uuid.New().String(), + }) + assert.NoError(t, err) + + t.Logf("appended message %s", msg.ID) + + h.allMsgs = append([]Message{msg}, h.allMsgs...) + + if len(h.allMsgs) > roomTestHarnessMaxMsgs { + h.allMsgs = h.allMsgs[:roomTestHarnessMaxMsgs] + } + + return msg +} + +func newRoomTestHarness(t *testing.T) *roomTestHarness { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + redis, err := radix.Dial(ctx, "tcp", "127.0.0.1:6379") + assert.NoError(t, err) + t.Cleanup(func() { redis.Close() }) + + roomParams := RoomParams{ + Logger: mlog.NewLogger(nil), + Redis: redis, + ID: uuid.New().String(), + MaxMessages: roomTestHarnessMaxMsgs, + } + + t.Logf("creating test Room %q", roomParams.ID) + room, err := NewRoom(ctx, roomParams) + assert.NoError(t, err) + + t.Cleanup(func() { + err := redis.Do(context.Background(), radix.Cmd( + nil, "DEL", roomParams.streamKey(), + )) + assert.NoError(t, err) + }) + + return &roomTestHarness{ctx: ctx, room: room} +} + +func TestRoom(t *testing.T) { + t.Run("history", func(t *testing.T) { + + tests := []struct { + numMsgs int + limit int + }{ + {numMsgs: 0, limit: 1}, + {numMsgs: 1, limit: 1}, + {numMsgs: 2, limit: 1}, + {numMsgs: 2, limit: 10}, + {numMsgs: 9, limit: 2}, + {numMsgs: 9, limit: 3}, + {numMsgs: 9, limit: 4}, + {numMsgs: 15, limit: 3}, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Logf("test: %+v", test) + + h := newRoomTestHarness(t) + + for j := 0; j < test.numMsgs; j++ { + h.newMsg(t) + } + + var gotMsgs []Message + var cursor string + + for { + + var msgs []Message + var err error + cursor, msgs, err = h.room.History(h.ctx, HistoryOpts{ + Cursor: cursor, + Limit: test.limit, + }) + + assert.NoError(t, err) + assert.NotEmpty(t, cursor) + + if len(msgs) == 0 { + break + } + + gotMsgs = append(gotMsgs, msgs...) + } + + assert.Equal(t, h.allMsgs, gotMsgs) + }) + } + }) + + assertNextMsg := func( + t *testing.T, expMsg Message, + ctx context.Context, it MessageIterator, + ) { + t.Helper() + gotMsg, err := it.Next(ctx) + assert.NoError(t, err) + assert.Equal(t, expMsg, gotMsg) + } + + t.Run("listen/already_populated", func(t *testing.T) { + h := newRoomTestHarness(t) + + msgA, msgB, msgC := h.newMsg(t), h.newMsg(t), h.newMsg(t) + _ = msgA + _ = msgB + + itFoo, err := h.room.Listen(h.ctx, msgC.ID) + assert.NoError(t, err) + defer itFoo.Close() + + itBar, err := h.room.Listen(h.ctx, msgA.ID) + assert.NoError(t, err) + defer itBar.Close() + + msgD := h.newMsg(t) + + // itBar should get msgB and msgC before anything else. + assertNextMsg(t, msgB, h.ctx, itBar) + assertNextMsg(t, msgC, h.ctx, itBar) + + // now both iterators should give msgD + assertNextMsg(t, msgD, h.ctx, itFoo) + assertNextMsg(t, msgD, h.ctx, itBar) + + // timeout should be honored + { + timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second) + _, errFoo := itFoo.Next(timeoutCtx) + _, errBar := itBar.Next(timeoutCtx) + timeoutCancel() + + assert.ErrorIs(t, errFoo, context.DeadlineExceeded) + assert.ErrorIs(t, errBar, context.DeadlineExceeded) + } + + // new message should work + { + expMsg := h.newMsg(t) + + timeoutCtx, timeoutCancel := context.WithTimeout(h.ctx, 1*time.Second) + gotFooMsg, errFoo := itFoo.Next(timeoutCtx) + gotBarMsg, errBar := itBar.Next(timeoutCtx) + timeoutCancel() + + assert.Equal(t, expMsg, gotFooMsg) + assert.NoError(t, errFoo) + assert.Equal(t, expMsg, gotBarMsg) + assert.NoError(t, errBar) + } + + }) + + t.Run("listen/empty", func(t *testing.T) { + h := newRoomTestHarness(t) + + it, err := h.room.Listen(h.ctx, "") + assert.NoError(t, err) + defer it.Close() + + msg := h.newMsg(t) + assertNextMsg(t, msg, h.ctx, it) + }) +} diff --git a/srv/chat/util.go b/srv/chat/util.go new file mode 100644 index 0000000..05f4830 --- /dev/null +++ b/srv/chat/util.go @@ -0,0 +1,28 @@ +package chat + +import ( + "strconv" + "strings" + + "github.com/mediocregopher/radix/v4" +) + +func parseStreamEntryID(str string) (radix.StreamEntryID, error) { + + split := strings.SplitN(str, "-", 2) + if len(split) != 2 { + return radix.StreamEntryID{}, errInvalidMessageID + } + + time, err := strconv.ParseUint(split[0], 10, 64) + if err != nil { + return radix.StreamEntryID{}, errInvalidMessageID + } + + seq, err := strconv.ParseUint(split[1], 10, 64) + if err != nil { + return radix.StreamEntryID{}, errInvalidMessageID + } + + return radix.StreamEntryID{Time: time, Seq: seq}, nil +} |