[MM-62152] Avoid SELECT * in oauth_store.go (#30080)

* refractored select queries

* test fail fix

* linting issues

* use builder pattern

* simplify GetAuthorizedApps

* revert to base behaviour

---------

Co-authored-by: Jesse Hallam <jesse.hallam@gmail.com>
This commit is contained in:
Arya Khochare 2025-02-08 02:08:49 +05:30 committed by GitHub
parent 0bdae41603
commit b11d536774
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,14 +11,34 @@ import (
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/channels/store"
sq "github.com/mattermost/squirrel"
) )
type SqlOAuthStore struct { type SqlOAuthStore struct {
*SqlStore *SqlStore
oAuthAppsSelectQuery sq.SelectBuilder
oAuthAccessDataQuery sq.SelectBuilder
oAuthAuthDataQuery sq.SelectBuilder
} }
func newSqlOAuthStore(sqlStore *SqlStore) store.OAuthStore { func newSqlOAuthStore(sqlStore *SqlStore) store.OAuthStore {
return &SqlOAuthStore{sqlStore} s := SqlOAuthStore{
SqlStore: sqlStore,
}
s.oAuthAppsSelectQuery = s.getQueryBuilder().
Select("o.Id", "o.CreatorId", "o.CreateAt", "o.UpdateAt", "o.ClientSecret", "o.Name", "o.Description", "o.IconURL", "o.CallbackUrls", "o.Homepage", "o.IsTrusted", "o.MattermostAppID").
From("OAuthApps o")
s.oAuthAccessDataQuery = s.getQueryBuilder().
Select("ClientId", "UserId", "Token", "RefreshToken", "RedirectUri", "ExpiresAt", "Scope").
From("OAuthAccessData")
s.oAuthAuthDataQuery = s.getQueryBuilder().
Select("ClientId", "UserId", "Code", "ExpiresIn", "CreateAt", "RedirectUri", "State", "Scope").
From("OAuthAuthData")
return &s
} }
func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) (*model.OAuthApp, error) { func (as SqlOAuthStore) SaveApp(app *model.OAuthApp) (*model.OAuthApp, error) {
@ -48,8 +68,9 @@ func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) (*model.OAuthApp, error)
} }
var oldApp model.OAuthApp var oldApp model.OAuthApp
err := as.GetMaster().Get(&oldApp, `SELECT * FROM OAuthApps query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.Id": app.Id})
WHERE id=?`, app.Id)
err := as.GetReplica().GetBuilder(&oldApp, query)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "failed to get OAuthApp with id=%s", app.Id) return nil, errors.Wrapf(err, "failed to get OAuthApp with id=%s", app.Id)
} }
@ -80,7 +101,9 @@ func (as SqlOAuthStore) UpdateApp(app *model.OAuthApp) (*model.OAuthApp, error)
func (as SqlOAuthStore) GetApp(id string) (*model.OAuthApp, error) { func (as SqlOAuthStore) GetApp(id string) (*model.OAuthApp, error) {
var app model.OAuthApp var app model.OAuthApp
if err := as.GetReplica().Get(&app, `SELECT * FROM OAuthApps WHERE Id=?`, id); err != nil { query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.Id": id})
if err := as.GetReplica().GetBuilder(&app, query); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, store.NewErrNotFound("OAuthApp", id) return nil, store.NewErrNotFound("OAuthApp", id)
} }
@ -95,7 +118,9 @@ func (as SqlOAuthStore) GetApp(id string) (*model.OAuthApp, error) {
func (as SqlOAuthStore) GetAppByUser(userId string, offset, limit int) ([]*model.OAuthApp, error) { func (as SqlOAuthStore) GetAppByUser(userId string, offset, limit int) ([]*model.OAuthApp, error) {
apps := []*model.OAuthApp{} apps := []*model.OAuthApp{}
if err := as.GetReplica().Select(&apps, "SELECT * FROM OAuthApps WHERE CreatorId = ? LIMIT ? OFFSET ?", userId, limit, offset); err != nil { query := as.oAuthAppsSelectQuery.Where(sq.Eq{"o.CreatorId": userId}).Limit(uint64(limit)).Offset(uint64(offset))
if err := as.GetReplica().SelectBuilder(&apps, query); err != nil {
return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId) return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId)
} }
@ -105,7 +130,9 @@ func (as SqlOAuthStore) GetAppByUser(userId string, offset, limit int) ([]*model
func (as SqlOAuthStore) GetApps(offset, limit int) ([]*model.OAuthApp, error) { func (as SqlOAuthStore) GetApps(offset, limit int) ([]*model.OAuthApp, error) {
apps := []*model.OAuthApp{} apps := []*model.OAuthApp{}
if err := as.GetReplica().Select(&apps, "SELECT * FROM OAuthApps LIMIT ? OFFSET ?", limit, offset); err != nil { query := as.oAuthAppsSelectQuery.Limit(uint64(limit)).Offset(uint64(offset))
if err := as.GetReplica().SelectBuilder(&apps, query); err != nil {
return nil, errors.Wrap(err, "failed to find OAuthApps") return nil, errors.Wrap(err, "failed to find OAuthApps")
} }
@ -115,9 +142,12 @@ func (as SqlOAuthStore) GetApps(offset, limit int) ([]*model.OAuthApp, error) {
func (as SqlOAuthStore) GetAuthorizedApps(userId string, offset, limit int) ([]*model.OAuthApp, error) { func (as SqlOAuthStore) GetAuthorizedApps(userId string, offset, limit int) ([]*model.OAuthApp, error) {
apps := []*model.OAuthApp{} apps := []*model.OAuthApp{}
if err := as.GetReplica().Select(&apps, query := as.oAuthAppsSelectQuery.
`SELECT o.* FROM OAuthApps AS o INNER JOIN InnerJoin("Preferences AS p ON p.Name = o.Id AND p.UserId = ?", userId).
Preferences AS p ON p.Name=o.Id AND p.UserId=? LIMIT ? OFFSET ?`, userId, limit, offset); err != nil { Limit(uint64(limit)).
Offset(uint64(offset))
if err := as.GetReplica().SelectBuilder(&apps, query); err != nil {
return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId) return nil, errors.Wrapf(err, "failed to find OAuthApps with userId=%s", userId)
} }
@ -160,7 +190,9 @@ func (as SqlOAuthStore) SaveAccessData(accessData *model.AccessData) (*model.Acc
func (as SqlOAuthStore) GetAccessData(token string) (*model.AccessData, error) { func (as SqlOAuthStore) GetAccessData(token string) (*model.AccessData, error) {
accessData := model.AccessData{} accessData := model.AccessData{}
if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE Token = ?", token); err != nil { query := as.oAuthAccessDataQuery.Where(sq.Eq{"Token": token})
if err := as.GetReplica().GetBuilder(&accessData, query); err != nil {
return nil, errors.Wrapf(err, "failed to get OAuthAccessData with token=%s", token) return nil, errors.Wrapf(err, "failed to get OAuthAccessData with token=%s", token)
} }
return &accessData, nil return &accessData, nil
@ -169,8 +201,9 @@ func (as SqlOAuthStore) GetAccessData(token string) (*model.AccessData, error) {
func (as SqlOAuthStore) GetAccessDataByUserForApp(userID, clientID string) ([]*model.AccessData, error) { func (as SqlOAuthStore) GetAccessDataByUserForApp(userID, clientID string) ([]*model.AccessData, error) {
accessData := []*model.AccessData{} accessData := []*model.AccessData{}
if err := as.GetReplica().Select(&accessData, query := as.oAuthAccessDataQuery.Where(sq.Eq{"UserId": userID, "ClientId": clientID})
"SELECT * FROM OAuthAccessData WHERE UserId = ? AND ClientId = ?", userID, clientID); err != nil {
if err := as.GetReplica().SelectBuilder(&accessData, query); err != nil {
return nil, errors.Wrapf(err, "failed to delete OAuthAccessData with userId=%s and clientId=%s", userID, clientID) return nil, errors.Wrapf(err, "failed to delete OAuthAccessData with userId=%s and clientId=%s", userID, clientID)
} }
return accessData, nil return accessData, nil
@ -179,7 +212,9 @@ func (as SqlOAuthStore) GetAccessDataByUserForApp(userID, clientID string) ([]*m
func (as SqlOAuthStore) GetAccessDataByRefreshToken(token string) (*model.AccessData, error) { func (as SqlOAuthStore) GetAccessDataByRefreshToken(token string) (*model.AccessData, error) {
accessData := model.AccessData{} accessData := model.AccessData{}
if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE RefreshToken = ?", token); err != nil { query := as.oAuthAccessDataQuery.Where(sq.Eq{"RefreshToken": token})
if err := as.GetReplica().GetBuilder(&accessData, query); err != nil {
return nil, errors.Wrapf(err, "failed to find OAuthAccessData with refreshToken=%s", token) return nil, errors.Wrapf(err, "failed to find OAuthAccessData with refreshToken=%s", token)
} }
return &accessData, nil return &accessData, nil
@ -188,13 +223,16 @@ func (as SqlOAuthStore) GetAccessDataByRefreshToken(token string) (*model.Access
func (as SqlOAuthStore) GetPreviousAccessData(userID, clientID string) (*model.AccessData, error) { func (as SqlOAuthStore) GetPreviousAccessData(userID, clientID string) (*model.AccessData, error) {
accessData := model.AccessData{} accessData := model.AccessData{}
if err := as.GetReplica().Get(&accessData, "SELECT * FROM OAuthAccessData WHERE ClientId = ? AND UserId = ?", clientID, userID); err != nil { query := as.oAuthAccessDataQuery.Where(sq.Eq{"UserId": userID, "ClientId": clientID}).
OrderBy("ExpiresAt DESC").Limit(1)
if err := as.GetReplica().GetBuilder(&accessData, query); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, errors.Wrapf(err, "failed to find OAuthAccessData with userId=%s and clientId=%s", userID, clientID)
return nil, errors.Wrapf(err, "failed to get AccessData with clientId=%s and userId=%s", clientID, userID)
} }
return &accessData, nil return &accessData, nil
} }
@ -240,16 +278,20 @@ func (as SqlOAuthStore) SaveAuthData(authData *model.AuthData) (*model.AuthData,
func (as SqlOAuthStore) GetAuthData(code string) (*model.AuthData, error) { func (as SqlOAuthStore) GetAuthData(code string) (*model.AuthData, error) {
var authData model.AuthData var authData model.AuthData
err := as.GetReplica().Get(&authData, `SELECT * FROM OAuthAuthData WHERE Code=?`, code)
if err != nil { query := as.oAuthAuthDataQuery.Where(sq.Eq{"Code": code})
if err := as.GetReplica().GetBuilder(&authData, query); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code)) return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code))
} }
return nil, errors.Wrapf(err, "failed to get AuthData with code=%s", code) return nil, errors.Wrapf(err, "failed to get AuthData with code=%s", code)
} }
if authData.Code == "" { if authData.Code == "" {
return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code)) return nil, store.NewErrNotFound("AuthData", fmt.Sprintf("code=%s", code))
} }
return &authData, nil return &authData, nil
} }