diff options
Diffstat (limited to 'srv/chat/chat.go')
-rw-r--r-- | srv/chat/chat.go | 475 |
1 files changed, 475 insertions, 0 deletions
diff --git a/srv/chat/chat.go b/srv/chat/chat.go new file mode 100644 index 0000000..44449cd --- /dev/null +++ b/srv/chat/chat.go @@ -0,0 +1,475 @@ +// Package chat implements a simple chatroom system. +package chat + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "sync" + "time" + + "github.com/mediocregopher/mediocre-go-lib/v2/mctx" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" + "github.com/mediocregopher/radix/v4" +) + +// ErrInvalidArg is returned from methods in this package when a call fails due +// to invalid input. +type ErrInvalidArg struct { + Err error +} + +func (e ErrInvalidArg) Error() string { + return fmt.Sprintf("invalid argument: %v", e.Err) +} + +var ( + errInvalidMessageID = ErrInvalidArg{Err: errors.New("invalid Message ID")} +) + +// UserID uniquely identifies an individual user who has posted a message in a +// Room. +type UserID struct { + + // Name will be the user's chosen display name. + Name string `json:"name"` + + // Hash will be a hex string generated from a secret only the user knows. + Hash string `json:"id"` +} + +// Message describes a message which has been posted to a Room. +type Message struct { + ID string `json:"id"` + UserID UserID `json:"userID"` + Body string `json:"body"` +} + +func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) { + + // NOTE this should probably be a shortcut in radix + var bodyStr string + for _, field := range entry.Fields { + if field[0] == "json" { + bodyStr = field[1] + break + } + } + + if bodyStr == "" { + return Message{}, errors.New("no 'json' field") + } + + var msg Message + if err := json.Unmarshal([]byte(bodyStr), &msg); err != nil { + return Message{}, fmt.Errorf( + "json unmarshaling body %q: %w", bodyStr, err, + ) + } + + msg.ID = entry.ID.String() + return msg, nil +} + +// MessageIterator returns a sequence of Messages which may or may not be +// unbounded. +type MessageIterator interface { + + // Next blocks until it returns the next Message in the sequence, or the + // context error if the context is cancelled, or io.EOF if the sequence has + // been exhausted. + Next(context.Context) (Message, error) + + // Close should always be called once Next has returned an error or the + // MessageIterator will no longer be used. + Close() error +} + +// HistoryOpts are passed into Room's History method in order to affect its +// result. All fields are optional. +type HistoryOpts struct { + Limit int // defaults to, and is capped at, 100. + Cursor string // If not given then the most recent Messages are returned. +} + +func (o HistoryOpts) sanitize() (HistoryOpts, error) { + if o.Limit < 0 || o.Limit > 100 { + o.Limit = 100 + } + + if o.Cursor != "" { + id, err := parseStreamEntryID(o.Cursor) + if err != nil { + return HistoryOpts{}, fmt.Errorf("parsing Cursor: %w", err) + } + o.Cursor = id.String() + } + + return o, nil +} + +// Room implements functionality related to a single, unique chat room. +type Room interface { + + // Append accepts a new Message and stores it at the end of the room's + // history. The original Message is returned with any relevant fields (e.g. + // ID) updated. + Append(context.Context, Message) (Message, error) + + // Returns a cursor and the list of historical Messages in time descending + // order. The cursor can be passed into the next call to History to receive + // the next set of Messages. + History(context.Context, HistoryOpts) (string, []Message, error) + + // Listen returns a MessageIterator which will return all Messages appended + // to the Room since the given ID. Once all existing messages are iterated + // through then the MessageIterator will begin blocking until a new Message + // is posted. + Listen(ctx context.Context, sinceID string) (MessageIterator, error) + + // Delete deletes a Message from the Room. + Delete(ctx context.Context, id string) error + + // Close is used to clean up all resources created by the Room. + Close() error +} + +// RoomParams are used to instantiate a new Room. All fields are required unless +// otherwise noted. +type RoomParams struct { + Logger *mlog.Logger + Redis radix.Client + ID string + MaxMessages int +} + +func (p RoomParams) streamKey() string { + return fmt.Sprintf("chat:{%s}:stream", p.ID) +} + +type room struct { + params RoomParams + + closeCtx context.Context + closeCancel context.CancelFunc + wg sync.WaitGroup + + listeningL sync.Mutex + listening map[chan Message]struct{} + listeningLastID radix.StreamEntryID +} + +// NewRoom initializes and returns a new Room instance. +func NewRoom(ctx context.Context, params RoomParams) (Room, error) { + + params.Logger = params.Logger.WithNamespace("chat-room") + + r := &room{ + params: params, + listening: map[chan Message]struct{}{}, + } + + r.closeCtx, r.closeCancel = context.WithCancel(context.Background()) + + // figure out the most recent message, if any. + lastEntryID, err := r.mostRecentMsgID(ctx) + if err != nil { + return nil, fmt.Errorf("discovering most recent entry ID in stream: %w", err) + } + r.listeningLastID = lastEntryID + + r.wg.Add(1) + go func() { + defer r.wg.Done() + r.readStreamLoop(r.closeCtx) + }() + + return r, nil +} + +func (r *room) Close() error { + r.closeCancel() + r.wg.Wait() + return nil +} + +func (r *room) mostRecentMsgID(ctx context.Context) (radix.StreamEntryID, error) { + + var entries []radix.StreamEntry + err := r.params.Redis.Do(ctx, radix.Cmd( + &entries, + "XREVRANGE", r.params.streamKey(), "+", "-", "COUNT", "1", + )) + + if err != nil || len(entries) == 0 { + return radix.StreamEntryID{}, err + } + + return entries[0].ID, nil +} + +func (r *room) Append(ctx context.Context, msg Message) (Message, error) { + msg.ID = "" // just in case + + b, err := json.Marshal(msg) + if err != nil { + return Message{}, fmt.Errorf("json marshaling Message: %w", err) + } + + key := r.params.streamKey() + maxLen := strconv.Itoa(r.params.MaxMessages) + body := string(b) + + var id string + + err = r.params.Redis.Do(ctx, radix.Cmd( + &id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body, + )) + + if err != nil { + return Message{}, fmt.Errorf("posting message to redis: %w", err) + } + + msg.ID = id + return msg, nil +} + +const zeroCursor = "0-0" + +func (r *room) History(ctx context.Context, opts HistoryOpts) (string, []Message, error) { + opts, err := opts.sanitize() + if err != nil { + return "", nil, err + } + + key := r.params.streamKey() + end := opts.Cursor + if end == "" { + end = "+" + } + start := "-" + count := strconv.Itoa(opts.Limit) + + msgs := make([]Message, 0, opts.Limit) + streamEntries := make([]radix.StreamEntry, 0, opts.Limit) + + err = r.params.Redis.Do(ctx, radix.Cmd( + &streamEntries, + "XREVRANGE", key, end, start, "COUNT", count, + )) + + if err != nil { + return "", nil, fmt.Errorf("calling XREVRANGE: %w", err) + } + + var oldestEntryID radix.StreamEntryID + + for _, entry := range streamEntries { + oldestEntryID = entry.ID + + msg, err := msgFromStreamEntry(entry) + if err != nil { + return "", nil, fmt.Errorf( + "parsing stream entry %q: %w", entry.ID, err, + ) + } + msgs = append(msgs, msg) + } + + if len(msgs) < opts.Limit { + return zeroCursor, msgs, nil + } + + cursor := oldestEntryID.Prev() + return cursor.String(), msgs, nil +} + +func (r *room) readStream(ctx context.Context) error { + + r.listeningL.Lock() + lastEntryID := r.listeningLastID + r.listeningL.Unlock() + + redisAddr := r.params.Redis.Addr() + redisConn, err := radix.Dial(ctx, redisAddr.Network(), redisAddr.String()) + if err != nil { + return fmt.Errorf("creating redis connection: %w", err) + } + defer redisConn.Close() + + streamReader := (radix.StreamReaderConfig{}).New( + redisConn, + map[string]radix.StreamConfig{ + r.params.streamKey(): {After: lastEntryID}, + }, + ) + + for { + dlCtx, dlCtxCancel := context.WithTimeout(ctx, 10*time.Second) + _, streamEntry, err := streamReader.Next(dlCtx) + dlCtxCancel() + + if errors.Is(err, radix.ErrNoStreamEntries) { + continue + } else if err != nil { + return fmt.Errorf("fetching next entry from stream: %w", err) + } + + msg, err := msgFromStreamEntry(streamEntry) + if err != nil { + return fmt.Errorf("parsing stream entry %q: %w", streamEntry, err) + } + + r.listeningL.Lock() + + var dropped int + for ch := range r.listening { + select { + case ch <- msg: + default: + dropped++ + } + } + + if dropped > 0 { + ctx := mctx.Annotate(ctx, "msgID", msg.ID, "dropped", dropped) + r.params.Logger.WarnString(ctx, "some listening channels full, messages dropped") + } + + r.listeningLastID = streamEntry.ID + + r.listeningL.Unlock() + } +} + +func (r *room) readStreamLoop(ctx context.Context) { + for { + err := r.readStream(ctx) + if errors.Is(err, context.Canceled) { + return + } else if err != nil { + r.params.Logger.Error(ctx, "reading from redis stream", err) + } + } +} + +type listenMsgIterator struct { + ch <-chan Message + missedMsgs []Message + sinceEntryID radix.StreamEntryID + cleanup func() +} + +func (i *listenMsgIterator) Next(ctx context.Context) (Message, error) { + + if len(i.missedMsgs) > 0 { + msg := i.missedMsgs[0] + i.missedMsgs = i.missedMsgs[1:] + return msg, nil + } + + for { + select { + case <-ctx.Done(): + return Message{}, ctx.Err() + case msg := <-i.ch: + + entryID, err := parseStreamEntryID(msg.ID) + if err != nil { + return Message{}, fmt.Errorf("parsing Message ID %q: %w", msg.ID, err) + + } else if !i.sinceEntryID.Before(entryID) { + // this can happen if someone Appends a Message at the same time + // as another calls Listen. The Listener might have already seen + // the Message by calling History prior to the stream reader + // having processed it and updating listeningLastID. + continue + } + + return msg, nil + } + } +} + +func (i *listenMsgIterator) Close() error { + i.cleanup() + return nil +} + +func (r *room) Listen( + ctx context.Context, sinceID string, +) ( + MessageIterator, error, +) { + + var sinceEntryID radix.StreamEntryID + + if sinceID != "" { + var err error + if sinceEntryID, err = parseStreamEntryID(sinceID); err != nil { + return nil, fmt.Errorf("parsing sinceID: %w", err) + } + } + + ch := make(chan Message, 32) + + r.listeningL.Lock() + lastEntryID := r.listeningLastID + r.listening[ch] = struct{}{} + r.listeningL.Unlock() + + cleanup := func() { + r.listeningL.Lock() + defer r.listeningL.Unlock() + delete(r.listening, ch) + } + + key := r.params.streamKey() + start := sinceEntryID.Next().String() + end := "+" + if lastEntryID != (radix.StreamEntryID{}) { + end = lastEntryID.String() + } + + var streamEntries []radix.StreamEntry + + err := r.params.Redis.Do(ctx, radix.Cmd( + &streamEntries, + "XRANGE", key, start, end, + )) + + if err != nil { + cleanup() + return nil, fmt.Errorf("retrieving missed stream entries: %w", err) + } + + missedMsgs := make([]Message, len(streamEntries)) + + for i := range streamEntries { + + msg, err := msgFromStreamEntry(streamEntries[i]) + if err != nil { + cleanup() + return nil, fmt.Errorf( + "parsing stream entry %q: %w", streamEntries[i].ID, err, + ) + } + + missedMsgs[i] = msg + } + + return &listenMsgIterator{ + ch: ch, + missedMsgs: missedMsgs, + sinceEntryID: sinceEntryID, + cleanup: cleanup, + }, nil +} + +func (r *room) Delete(ctx context.Context, id string) error { + return r.params.Redis.Do(ctx, radix.Cmd( + nil, "XDEL", r.params.streamKey(), id, + )) +} |