提交 636848c3 编写于 作者: M Mark Haines 提交者: GitHub

Add invites to the sync API (#244)

* Add table for storing invites in the syncapi

* Use the invite table to list the active invites for a user

* Update the invites table from the roomserver stream

* Include the invites table when determining the maxInviteID
上级 7a30f208
...@@ -123,6 +123,8 @@ type OutputNewInviteEvent struct { ...@@ -123,6 +123,8 @@ type OutputNewInviteEvent struct {
type OutputRetireInviteEvent struct { type OutputRetireInviteEvent struct {
// The ID of the "m.room.member" invite event. // The ID of the "m.room.member" invite event.
EventID string EventID string
// The target user ID of the "m.room.member" invite event that was retired.
TargetUserID string
// Optional event ID of the event that replaced the invite. // Optional event ID of the event that replaced the invite.
// This can be empty if the invite was rejected locally and we were unable // This can be empty if the invite was rejected locally and we were unable
// to reach the server that originally sent the invite. // to reach the server that originally sent the invite.
......
...@@ -182,11 +182,10 @@ func updateToJoinMembership( ...@@ -182,11 +182,10 @@ func updateToJoinMembership(
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ orie := api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: join, Membership: join,
} RetiredByEventID: add.EventID(),
if add != nil { TargetUserID: *add.StateKey(),
orie.RetiredByEventID = add.EventID()
} }
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
...@@ -215,11 +214,10 @@ func updateToLeaveMembership( ...@@ -215,11 +214,10 @@ func updateToLeaveMembership(
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ orie := api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: newMembership, Membership: newMembership,
} RetiredByEventID: add.EventID(),
if add != nil { TargetUserID: *add.StateKey(),
orie.RetiredByEventID = add.EventID()
} }
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent, Type: api.OutputTypeRetireInviteEvent,
......
...@@ -86,26 +86,37 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { ...@@ -86,26 +86,37 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return nil return nil
} }
if output.Type != api.OutputTypeNewRoomEvent { switch output.Type {
case api.OutputTypeNewRoomEvent:
return s.onNewRoomEvent(context.TODO(), *output.NewRoomEvent)
case api.OutputTypeNewInviteEvent:
return s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
return s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent)
default:
log.WithField("type", output.Type).Debug( log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type", "roomserver output log: ignoring unknown output type",
) )
return nil return nil
} }
}
ev := output.NewRoomEvent.Event func (s *OutputRoomEvent) onNewRoomEvent(
ctx context.Context, msg api.OutputNewRoomEvent,
) error {
ev := msg.Event
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event_id": ev.EventID(),
"room_id": ev.RoomID(), "room_id": ev.RoomID(),
}).Info("received event from roomserver") }).Info("received event from roomserver")
addsStateEvents, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev) addsStateEvents, err := s.lookupStateEvents(msg.AddsStateEventIDs, ev)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event": string(ev.JSON()), "event": string(ev.JSON()),
log.ErrorKey: err, log.ErrorKey: err,
"add": output.NewRoomEvent.AddsStateEventIDs, "add": msg.AddsStateEventIDs,
"del": output.NewRoomEvent.RemovesStateEventIDs, "del": msg.RemovesStateEventIDs,
}).Panicf("roomserver output log: state event lookup failure") }).Panicf("roomserver output log: state event lookup failure")
} }
...@@ -122,20 +133,23 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { ...@@ -122,20 +133,23 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
} }
syncStreamPos, err := s.db.WriteEvent( syncStreamPos, err := s.db.WriteEvent(
context.TODO(), ctx,
&ev, &ev,
addsStateEvents, addsStateEvents,
output.NewRoomEvent.AddsStateEventIDs, msg.AddsStateEventIDs,
output.NewRoomEvent.RemovesStateEventIDs, msg.RemovesStateEventIDs,
) )
if err != nil {
return err
}
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event": string(ev.JSON()), "event": string(ev.JSON()),
log.ErrorKey: err, log.ErrorKey: err,
"add": output.NewRoomEvent.AddsStateEventIDs, "add": msg.AddsStateEventIDs,
"del": output.NewRoomEvent.RemovesStateEventIDs, "del": msg.RemovesStateEventIDs,
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
...@@ -144,6 +158,39 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { ...@@ -144,6 +158,39 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return nil return nil
} }
func (s *OutputRoomEvent) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent,
) error {
syncStreamPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event": string(msg.Event.JSON()),
log.ErrorKey: err,
}).Panicf("roomserver output log: write invite failure")
return nil
}
s.notifier.OnNewEvent(&msg.Event, "", syncStreamPos)
return nil
}
func (s *OutputRoomEvent) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent,
) error {
err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event_id": msg.EventID,
log.ErrorKey: err,
}).Panicf("roomserver output log: remove invite failure")
return nil
}
// TODO: Notify any active sync requests that the invite has been retired.
// s.notifier.OnNewEvent(nil, msg.TargetUserID, syncStreamPos)
return nil
}
// lookupStateEvents looks up the state events that are added by a new event. // lookupStateEvents looks up the state events that are added by a new event.
func (s *OutputRoomEvent) lookupStateEvents( func (s *OutputRoomEvent) lookupStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event, addsStateEventIDs []string, event gomatrixserverlib.Event,
......
...@@ -140,7 +140,10 @@ func (s *currentRoomStateStatements) selectJoinedUsers( ...@@ -140,7 +140,10 @@ func (s *currentRoomStateStatements) selectJoinedUsers(
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) selectRoomIDsWithMembership( func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
ctx context.Context, txn *sql.Tx, userID, membership string, ctx context.Context,
txn *sql.Tx,
userID string,
membership string, // nolint: unparam
) ([]string, error) { ) ([]string, error) {
stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, userID, membership) rows, err := stmt.QueryContext(ctx, userID, membership)
......
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
const inviteEventsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_invite_events (
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
event_json TEXT NOT NULL
);
-- For looking up the invites for a given user.
CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx
ON syncapi_invite_events (target_user_id, id);
-- For deleting old invites
CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx
ON syncapi_invite_events(target_user_id, id);
`
const insertInviteEventSQL = "" +
"INSERT INTO syncapi_invite_events (" +
" room_id, event_id, target_user_id, event_json" +
") VALUES ($1, $2, $3, $4) RETURNING id"
const deleteInviteEventSQL = "" +
"DELETE FROM syncapi_invite_events WHERE event_id = $1"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, event_json FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events"
type inviteEventsStatements struct {
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
}
func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(inviteEventsSchema)
if err != nil {
return
}
if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
return
}
if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
return
}
if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
return
}
if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
return
}
return
}
func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (streamPos int64, err error) {
err = s.insertInviteEventStmt.QueryRowContext(
ctx,
inviteEvent.RoomID(),
inviteEvent.EventID(),
*inviteEvent.StateKey(),
inviteEvent.JSON(),
).Scan(&streamPos)
return
}
func (s *inviteEventsStatements) deleteInviteEvent(
ctx context.Context, inviteEventID string,
) error {
_, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID)
return err
}
// selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) selectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64,
) (map[string]gomatrixserverlib.Event, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
result := map[string]gomatrixserverlib.Event{}
for rows.Next() {
var (
roomID string
eventJSON []byte
)
if err = rows.Scan(&roomID, &eventJSON); err != nil {
return nil, err
}
event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false)
if err != nil {
return nil, err
}
result[roomID] = event
}
return result, nil
}
func (s *inviteEventsStatements) selectMaxInviteID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := common.TxStmt(txn, s.selectMaxInviteIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}
...@@ -47,32 +47,32 @@ type SyncServerDatabase struct { ...@@ -47,32 +47,32 @@ type SyncServerDatabase struct {
accountData accountDataStatements accountData accountDataStatements
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
invites inviteEventsStatements
} }
// NewSyncServerDatabase creates a new sync server database // NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
var db *sql.DB var d SyncServerDatabase
var err error var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil { if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err return nil, err
} }
partitions := common.PartitionOffsetStatements{} if err = d.partitions.Prepare(d.db, "syncapi"); err != nil {
if err = partitions.Prepare(db, "syncapi"); err != nil {
return nil, err return nil, err
} }
accountData := accountDataStatements{} if err = d.accountData.prepare(d.db); err != nil {
if err = accountData.prepare(db); err != nil {
return nil, err return nil, err
} }
events := outputRoomEventsStatements{} if err = d.events.prepare(d.db); err != nil {
if err = events.prepare(db); err != nil {
return nil, err return nil, err
} }
state := currentRoomStateStatements{} if err := d.roomstate.prepare(d.db); err != nil {
if err := state.prepare(db); err != nil {
return nil, err return nil, err
} }
return &SyncServerDatabase{db, partitions, accountData, events, state}, nil if err := d.invites.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
} }
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
...@@ -191,6 +191,13 @@ func (d *SyncServerDatabase) syncStreamPositionTx( ...@@ -191,6 +191,13 @@ func (d *SyncServerDatabase) syncStreamPositionTx(
if maxAccountDataID > maxID { if maxAccountDataID > maxID {
maxID = maxAccountDataID maxID = maxAccountDataID
} }
maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
if err != nil {
return 0, err
}
if maxInviteID > maxID {
maxID = maxInviteID
}
return types.StreamPosition(maxID), nil return types.StreamPosition(maxID), nil
} }
...@@ -260,7 +267,7 @@ func (d *SyncServerDatabase) IncrementalSync( ...@@ -260,7 +267,7 @@ func (d *SyncServerDatabase) IncrementalSync(
} }
// TODO: This should be done in getStateDeltas // TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, userID, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil {
return nil, err return nil, err
} }
...@@ -322,7 +329,7 @@ func (d *SyncServerDatabase) CompleteSync( ...@@ -322,7 +329,7 @@ func (d *SyncServerDatabase) CompleteSync(
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, pos, res); err != nil {
return nil, err return nil, err
} }
...@@ -364,16 +371,45 @@ func (d *SyncServerDatabase) UpsertAccountData( ...@@ -364,16 +371,45 @@ func (d *SyncServerDatabase) UpsertAccountData(
return types.StreamPosition(pos), err return types.StreamPosition(pos), err
} }
// AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (types.StreamPosition, error) {
pos, err := d.invites.insertInviteEvent(ctx, inviteEvent)
return types.StreamPosition(pos), err
}
// RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) RetireInviteEvent(
ctx context.Context, inviteEventID string,
) error {
// TODO: Record that invite has been retired in a stream so that we can
// notify the user in an incremental sync.
err := d.invites.deleteInviteEvent(ctx, inviteEventID)
return err
}
func (d *SyncServerDatabase) addInvitesToResponse( func (d *SyncServerDatabase) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, userID string, res *types.Response) error { ctx context.Context, txn *sql.Tx,
// Add invites - TODO: This will break over federation as they won't be in the current state table according to Mark. userID string,
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "invite") fromPos, toPos types.StreamPosition,
res *types.Response,
) error {
invites, err := d.invites.selectInviteEventsInRange(
ctx, txn, userID, int64(fromPos), int64(toPos),
)
if err != nil { if err != nil {
return err return err
} }
for _, roomID := range roomIDs { for roomID, inviteEvent := range invites {
ir := types.NewInviteResponse() ir := types.NewInviteResponse()
// TODO: invite_state. The state won't be in the current state table in cases where you get invited over federation ir.InviteState.Events = gomatrixserverlib.ToClientEvents(
[]gomatrixserverlib.Event{inviteEvent}, gomatrixserverlib.FormatSync,
)
// TODO: add the invite state from the invite event.
res.Rooms.Invite[roomID] = *ir res.Rooms.Invite[roomID] = *ir
} }
return nil return nil
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册