diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go index a614eb541dd1af3e5822c2a91a22708ce6e43034..903471afe48ebc4f55dff9cf286d8e77237f8bec 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go @@ -40,7 +40,9 @@ CREATE TABLE IF NOT EXISTS device_devices ( -- migration to different domain names easier. localpart TEXT NOT NULL, -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). - created_ts BIGINT NOT NULL + created_ts BIGINT NOT NULL, + -- The display name, human friendlier than device_id and updatable + display_name TEXT -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) ); @@ -49,16 +51,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts) VALUES ($1, $2, $3, $4)" + "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" const selectDeviceByTokenSQL = "" + - "SELECT device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT device_id, localpart, display_name FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT created_ts FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + +const updateDeviceNameSQL = "" + + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" @@ -66,13 +71,12 @@ const deleteDeviceSQL = "" + const deleteDevicesByLocalpartSQL = "" + "DELETE FROM device_devices WHERE localpart = $1" -// TODO: List devices? - type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt + updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt serverName gomatrixserverlib.ServerName @@ -95,6 +99,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { return } + if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { + return + } if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { return } @@ -110,10 +117,11 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN // Returns the device on success. func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, ) (*authtypes.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := common.TxStmt(txn, s.insertDeviceStmt) - if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil { + if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil { return nil, err } return &authtypes.Device{ @@ -139,6 +147,14 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } +func (s *devicesStatements) updateDeviceName( + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, +) error { + stmt := common.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err +} + func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, ) (*authtypes.Device, error) { diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go index dd98bb60985727c31dc19ed1f579713a635a1734..6ac475a66c2d3721ece7e245725b018e34b9f96f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go @@ -75,6 +75,7 @@ func (d *Database) GetDevicesByLocalpart( // Returns the device on success. func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ) (dev *authtypes.Device, returnErr error) { if deviceID != nil { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { @@ -84,7 +85,7 @@ func (d *Database) CreateDevice( return err } - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken) + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) return err }) } else { @@ -99,7 +100,7 @@ func (d *Database) CreateDevice( returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken) + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err }) if returnErr == nil { @@ -110,6 +111,16 @@ func (d *Database) CreateDevice( return } +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + // RemoveDevice revokes a device by deleting the entry in the database // matching with the given device ID and user ID localpart // If the device doesn't exist, it will not return an error diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/device.go b/src/github.com/matrix-org/dendrite/clientapi/routing/device.go index 9cb63bac95fad2994ed6b45286c41a5b12fc806d..86e393be1ef616ccae742b6b274062f5f328738c 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/device.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/device.go @@ -16,6 +16,7 @@ package routing import ( "database/sql" + "encoding/json" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -35,6 +36,10 @@ type devicesJSON struct { Devices []deviceJSON `json:"devices"` } +type deviceUpdateJSON struct { + DisplayName *string `json:"display_name"` +} + // GetDeviceByID handles /device/{deviceID} func GetDeviceByID( req *http.Request, deviceDB *devices.Database, device *authtypes.Device, @@ -95,3 +100,56 @@ func GetDevicesByLocalpart( JSON: res, } } + +// UpdateDeviceByID handles PUT on /devices/{deviceID} +func UpdateDeviceByID( + req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + deviceID string, +) util.JSONResponse { + if req.Method != "PUT" { + return util.JSONResponse{ + Code: 405, + JSON: jsonerror.NotFound("Bad Method"), + } + } + + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + + ctx := req.Context() + dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("Unknown device"), + } + } else if err != nil { + return httputil.LogThenError(req, err) + } + + if dev.UserID != device.UserID { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } + + defer req.Body.Close() // nolint: errcheck + + payload := deviceUpdateJSON{} + + if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { + return httputil.LogThenError(req, err) + } + + if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: 200, + JSON: struct{}{}, + } +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go index 2fa44d6d64ecbaf1f4273b504bea4f51b44c4001..56c67b77d9ce41ca31f8ace03df39174a8c632f2 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go @@ -38,8 +38,9 @@ type flow struct { } type passwordRequest struct { - User string `json:"user"` - Password string `json:"password"` + User string `json:"user"` + Password string `json:"password"` + InitialDisplayName *string `json:"initial_device_display_name"` } type loginResponse struct { @@ -119,7 +120,7 @@ func Login( // TODO: Use the device ID in the request dev, err := deviceDB.CreateDevice( - req.Context(), acc.Localpart, nil, token, + req.Context(), acc.Localpart, nil, token, r.InitialDisplayName, ) if err != nil { return util.JSONResponse{ diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index 29e25764b0b68094c42277c25df6f3ebbdcc6fc8..875ceb04928137b0cfe72a14f76c35b6a68937e7 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -60,6 +60,8 @@ type registerRequest struct { Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` + + InitialDisplayName *string `json:"initial_device_display_name"` } type authDict struct { @@ -210,10 +212,10 @@ func Register( return util.MessageResponse(403, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName) default: return util.JSONResponse{ Code: 501, @@ -270,10 +272,10 @@ func LegacyRegister( return util.MessageResponse(403, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil) default: return util.JSONResponse{ Code: 501, @@ -287,6 +289,7 @@ func completeRegistration( accountDB *accounts.Database, deviceDB *devices.Database, username, password string, + displayName *string, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -318,7 +321,7 @@ func completeRegistration( } // // TODO: Use the device ID in the request. - dev, err := deviceDB.CreateDevice(ctx, username, nil, token) + dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName) if err != nil { return util.JSONResponse{ Code: 500, diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 87f52ad09990ef11db4c6eb701159dca4a465eff..9e9fea5a62577592807be5cdb7f2deea4036351d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -364,6 +364,13 @@ func Setup( }), ).Methods("GET") + r0mux.Handle("/devices/{deviceID}", + common.MakeAuthAPI("device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars := mux.Vars(req) + return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"]) + }), + ).Methods("PUT", "OPTIONS") + // Stub implementations for sytest r0mux.Handle("/events", common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { diff --git a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go index 3d5c35878a0129884354ecae8c261419fbdf1a35..7914a6266abb06e1044a1926fc756d24353bdf90 100644 --- a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go @@ -87,7 +87,7 @@ func main() { } device, err := deviceDB.CreateDevice( - context.Background(), *username, nil, *accessToken, + context.Background(), *username, nil, *accessToken, nil, ) if err != nil { fmt.Println(err.Error())