提交 808c2e09 编写于 作者: M Mark Haines 提交者: GitHub

Make txn *sql.Tx arguments optional everywhere using a utility function (#191)

* Make txn *sql.Tx arguments optional everywhere using a utility function

* Clarify that if the txn is nil the stmt will run outside a transaction
上级 57b70973
......@@ -55,3 +55,14 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
succeeded = true
return
}
// TxStmt wraps an SQL stmt inside an optional transaction.
// If the transaction is nil then it returns the original statement that will
// run outside of a transaction.
// Otherwise returns a copy of the statement that will run inside the transaction.
func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
if transaction != nil {
statement = transaction.Stmt(statement)
}
return statement
}
......@@ -18,6 +18,7 @@ import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/gomatrixserverlib"
)
......@@ -79,18 +80,18 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
func (s *joinedHostsStatements) insertJoinedHosts(
txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
) error {
_, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
_, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
return err
}
func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID)
rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
if err != nil {
return nil, err
}
......
......@@ -16,6 +16,8 @@ package storage
import (
"database/sql"
"github.com/matrix-org/dendrite/common"
)
const roomSchema = `
......@@ -65,7 +67,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
_, err := txn.Stmt(s.insertRoomStmt).Exec(roomID)
_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
return err
}
......@@ -74,7 +76,7 @@ func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
// exists by calling insertRoom first.
func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
var lastEventID string
err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
......@@ -84,6 +86,6 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
_, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID)
_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
return err
}
......@@ -18,6 +18,7 @@ import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
......@@ -92,21 +93,13 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := s.insertEventStateKeyNIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
stmt := s.selectEventStateKeyNIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
}
......@@ -131,11 +124,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
var eventStateKey string
stmt := s.selectEventStateKeyStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(eventStateKeyNID).Scan(&eventStateKey)
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err
}
......
......@@ -19,6 +19,7 @@ import (
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
......@@ -253,22 +254,22 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ
}
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
return
}
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
_, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID))
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
return err
}
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
......
......@@ -17,6 +17,7 @@ package storage
import (
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
......@@ -94,7 +95,7 @@ func (s *inviteStatements) insertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
result, err := txn.Stmt(s.insertInviteEventStmt).Exec(
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
if err != nil {
......@@ -110,7 +111,7 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) {
rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
if err != nil {
return nil, err
}
......
......@@ -17,6 +17,7 @@ package storage
import (
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
......@@ -115,14 +116,14 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := txn.Stmt(s.insertMembershipStmt).Exec(roomNID, targetUserNID)
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
return err
}
func (s *membershipStatements) selectMembershipForUpdate(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) {
err = txn.Stmt(s.selectMembershipForUpdateStmt).QueryRow(
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
roomNID, targetUserNID,
).Scan(&membership)
return
......@@ -179,7 +180,7 @@ func (s *membershipStatements) updateMembership(
senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID,
) error {
_, err := txn.Stmt(s.updateMembershipStmt).Exec(
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
roomNID, targetUserNID, senderUserNID, membership, eventNID,
)
return err
......
......@@ -17,6 +17,7 @@ package storage
import (
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
......@@ -73,7 +74,7 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
}
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
_, err := txn.Stmt(s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
return err
}
......@@ -81,5 +82,5 @@ func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEvent
// Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
var ok int64
return txn.Stmt(s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
}
......@@ -18,6 +18,7 @@ import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
......@@ -82,21 +83,13 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
var roomNID int64
stmt := s.insertRoomNIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(roomID).Scan(&roomNID)
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
var roomNID int64
stmt := s.selectRoomNIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
err := stmt.QueryRow(roomID).Scan(&roomNID)
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
......@@ -120,7 +113,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
var nids pq.Int64Array
var lastEventSentNID int64
var stateSnapshotNID int64
err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
if err != nil {
return nil, 0, 0, err
}
......@@ -135,7 +128,7 @@ func (s *roomStatements) updateLatestEventNIDs(
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
_, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
)
return err
......
......@@ -18,6 +18,7 @@ import (
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
......@@ -136,7 +137,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership)
if err != nil {
return nil, err
}
......@@ -155,7 +156,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us
// CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) {
rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID)
rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID)
if err != nil {
return nil, err
}
......@@ -165,21 +166,21 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri
}
func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error {
_, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID)
_, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID)
return err
}
func (s *currentRoomStateStatements) upsertRoomState(
txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64,
) error {
_, err := txn.Stmt(s.upsertRoomStateStmt).Exec(
_, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec(
event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt,
)
return err
}
func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
rows, err := txn.Stmt(s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
......
......@@ -19,6 +19,7 @@ import (
log "github.com/Sirupsen/logrus"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
......@@ -105,7 +106,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
func (s *outputRoomEventsStatements) selectStateInRange(
txn *sql.Tx, oldPos, newPos types.StreamPosition,
) (map[string]map[string]bool, map[string]streamEvent, error) {
rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos)
rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos)
if err != nil {
return nil, nil, err
}
......@@ -167,12 +168,8 @@ func (s *outputRoomEventsStatements) selectStateInRange(
// then this function should only ever be used at startup, as it will race with inserting events if it is
// done afterwards. If there are no inserted events, 0 is returned.
func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) {
stmt := s.selectMaxIDStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
var nullableID sql.NullInt64
err = stmt.QueryRow().Scan(&nullableID)
err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
......@@ -182,7 +179,7 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err
// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
// of the inserted event.
func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) {
err = txn.Stmt(s.insertEventStmt).QueryRow(
err = common.TxStmt(txn, s.insertEventStmt).QueryRow(
event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState),
).Scan(&streamPos)
return
......@@ -209,11 +206,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
// from the database.
func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) {
stmt := s.selectEventsStmt
if txn != nil {
stmt = txn.Stmt(stmt)
}
rows, err := stmt.Query(pq.StringArray(eventIDs))
rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册