diff --git a/server/channels/store/sqlstore/oauth_store.go b/server/channels/store/sqlstore/oauth_store.go index eac9fc4e5e..871845adc3 100644 --- a/server/channels/store/sqlstore/oauth_store.go +++ b/server/channels/store/sqlstore/oauth_store.go @@ -11,14 +11,34 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/v8/channels/store" + sq "github.com/mattermost/squirrel" ) type SqlOAuthStore struct { *SqlStore + oAuthAppsSelectQuery sq.SelectBuilder + oAuthAccessDataQuery sq.SelectBuilder + oAuthAuthDataQuery sq.SelectBuilder } func newSqlOAuthStore(sqlStore *SqlStore) store.OAuthStore { - return &SqlOAuthStore{sqlStore} + s := SqlOAuthStore{ + SqlStore: sqlStore, + } + + s.oAuthAppsSelectQuery = s.getQueryBuilder(). + Select("o.Id", "o.CreatorId", "o.CreateAt", "o.UpdateAt", "o.ClientSecret", "o.Name", "o.Description", "o.IconURL", "o.CallbackUrls", "o.Homepage", "o.IsTrusted", "o.MattermostAppID"). + From("OAuthApps o") + + s.oAuthAccessDataQuery = s.getQueryBuilder(). + Select("ClientId", "UserId", "Token", "RefreshToken", "RedirectUri", "ExpiresAt", "Scope"). + From("OAuthAccessData") + + s.oAuthAuthDataQuery = s.getQueryBuilder(). + Select("ClientId", "UserId", "Code", "ExpiresIn", "CreateAt", "RedirectUri", "State", "Scope"). + From("OAuthAuthData") + + return &s } func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) (*model.OAuthApp, error) { @@ -48,8 +68,9 @@ func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) (*model.OAuthApp, error) } var oldApp model.OAuthApp - err := as.GetMaster().Get(&oldApp, `SELECT * FROM OAuthApps - WHERE id=?`, app.Id) + query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.Id": app.Id}) + + err := as.GetReplica().GetBuilder(&oldApp, query) if err != nil { return nil, errors.Wrapf(err, "failed to get OAuthApp with id=%s", app.Id) } @@ -80,7 +101,9 @@ func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) (*model.OAuthApp, error) func (as SqlOAuthStore) GetApp(id string) (*model.OAuthApp, error) { var app model.OAuthApp - if err := as.GetReplica().Get(&app, `SELECT * FROM OAuthApps WHERE Id=?`, id); err != nil { + query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.Id": id}) + + if err := as.GetReplica().GetBuilder(&app, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("OAuthApp", id) } @@ -95,7 +118,9 @@ func (as SqlOAuthStore) GetApp(id string) (*model.OAuthApp, error) { func (as SqlOAuthStore) GetAppByUser(userId string, offset, limit int) ([]*model.OAuthApp, error) { apps := []*model.OAuthApp{} - if err := as.GetReplica().Select(&apps, "SELECT * FROM OAuthApps WHERE CreatorId = ? LIMIT ? OFFSET ?", userId, limit, offset); err != nil { + query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.CreatorId": userId}).Limit(uint64(limit)).Offset(uint64(offset)) + + if err := as.GetReplica().SelectBuilder(&apps, query); err != nil { return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId) } @@ -105,7 +130,9 @@ func (as SqlOAuthStore) GetAppByUser(userId string, offset, limit int) ([]*model func (as SqlOAuthStore) GetApps(offset, limit int) ([]*model.OAuthApp, error) { apps := []*model.OAuthApp{} - if err := as.GetReplica().Select(&apps, "SELECT * FROM OAuthApps LIMIT ? OFFSET ?", limit, offset); err != nil { + query := as.oAuthAppsSelectQuery.Limit(uint64(limit)).Offset(uint64(offset)) + + if err := as.GetReplica().SelectBuilder(&apps, query); err != nil { return nil, errors.Wrap(err, "failed to find OAuthApps") } @@ -115,9 +142,12 @@ func (as SqlOAuthStore) GetApps(offset, limit int) ([]*model.OAuthApp, error) { func (as SqlOAuthStore) GetAuthorizedApps(userId string, offset, limit int) ([]*model.OAuthApp, error) { apps := []*model.OAuthApp{} - if err := as.GetReplica().Select(&apps, - `SELECT o.* FROM OAuthApps AS o INNER JOIN - Preferences AS p ON p.Name=o.Id AND p.UserId=? LIMIT ? OFFSET ?`, userId, limit, offset); err != nil { + query := as.oAuthAppsSelectQuery. + InnerJoin("Preferences AS p ON p.Name = o.Id AND p.UserId = ?", userId). + Limit(uint64(limit)). + Offset(uint64(offset)) + + if err := as.GetReplica().SelectBuilder(&apps, query); err != nil { return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId) } @@ -160,7 +190,9 @@ func (as SqlOAuthStore) SaveAccessData(accessData *model.AccessData) (*model.Acc func (as SqlOAuthStore) GetAccessData(token string) (*model.AccessData, error) { accessData := model.AccessData{} - if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE Token = ?", token); err != nil { + query := as.oAuthAccessDataQuery.Where(sq.Eq{"Token": token}) + + if err := as.GetReplica().GetBuilder(&accessData, query); err != nil { return nil, errors.Wrapf(err, "failed to get OAuthAccessData with token=%s", token) } return &accessData, nil @@ -169,8 +201,9 @@ func (as SqlOAuthStore) GetAccessData(token string) (*model.AccessData, error) { func (as SqlOAuthStore) GetAccessDataByUserForApp(userID, clientID string) ([]*model.AccessData, error) { accessData := []*model.AccessData{} - if err := as.GetReplica().Select(&accessData, - "SELECT * FROM OAuthAccessData WHERE UserId = ? AND ClientId = ?", userID, clientID); err != nil { + query := as.oAuthAccessDataQuery.Where(sq.Eq{"UserId": userID, "ClientId": clientID}) + + if err := as.GetReplica().SelectBuilder(&accessData, query); err != nil { return nil, errors.Wrapf(err, "failed to delete OAuthAccessData with userId=%s and clientId=%s", userID, clientID) } return accessData, nil @@ -179,7 +212,9 @@ func (as SqlOAuthStore) GetAccessDataByUserForApp(userID, clientID string) ([]*m func (as SqlOAuthStore) GetAccessDataByRefreshToken(token string) (*model.AccessData, error) { accessData := model.AccessData{} - if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE RefreshToken = ?", token); err != nil { + query := as.oAuthAccessDataQuery.Where(sq.Eq{"RefreshToken": token}) + + if err := as.GetReplica().GetBuilder(&accessData, query); err != nil { return nil, errors.Wrapf(err, "failed to find OAuthAccessData with refreshToken=%s", token) } return &accessData, nil @@ -188,13 +223,16 @@ func (as SqlOAuthStore) GetAccessDataByRefreshToken(token string) (*model.Access func (as SqlOAuthStore) GetPreviousAccessData(userID, clientID string) (*model.AccessData, error) { accessData := model.AccessData{} - if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE ClientId = ? AND UserId = ?", clientID, userID); err != nil { + query := as.oAuthAccessDataQuery.Where(sq.Eq{"UserId": userID, "ClientId": clientID}). + OrderBy("ExpiresAt DESC").Limit(1) + + if err := as.GetReplica().GetBuilder(&accessData, query); err != nil { if err == sql.ErrNoRows { return nil, nil } - - return nil, errors.Wrapf(err, "failed to get AccessData with clientId=%s and userId=%s", clientID, userID) + return nil, errors.Wrapf(err, "failed to find OAuthAccessData with userId=%s and clientId=%s", userID, clientID) } + return &accessData, nil } @@ -240,16 +278,20 @@ func (as SqlOAuthStore) SaveAuthData(authData *model.AuthData) (*model.AuthData, func (as SqlOAuthStore) GetAuthData(code string) (*model.AuthData, error) { var authData model.AuthData - err := as.GetReplica().Get(&authData, `SELECT * FROM OAuthAuthData WHERE Code=?`, code) - if err != nil { + + query := as.oAuthAuthDataQuery.Where(sq.Eq{"Code": code}) + + if err := as.GetReplica().GetBuilder(&authData, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code)) } return nil, errors.Wrapf(err, "failed to get AuthData with code=%s", code) } + if authData.Code == "" { return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code)) } + return &authData, nil }