Chore: Port oauth token service to identity requester (#73988)

* port oauth token service to identity requester

* fix broken test

* no need to check for render
This commit is contained in:
Jo 2023-08-29 11:55:58 +02:00 committed by GitHub
parent 7c98678188
commit fe1563882a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 30 deletions

View File

@ -27,6 +27,7 @@ import (
"github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins"
acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock"
"github.com/grafana/grafana/pkg/services/auth/identity"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service" datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
@ -1118,7 +1119,7 @@ type mockOAuthTokenService struct {
oAuthEnabled bool oAuthEnabled bool
} }
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *user.SignedInUser) *oauth2.Token { func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user identity.Requester) *oauth2.Token {
return m.token return m.token
} }
@ -1126,7 +1127,7 @@ func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
return m.oAuthEnabled return m.oAuthEnabled
} }
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) { func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
return nil, false, nil return nil, false, nil
} }

View File

@ -9,18 +9,19 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/socialtest" "github.com/grafana/grafana/pkg/login/socialtest"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
@ -117,7 +118,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
) )
service := &oauthtokentest.MockOauthTokenService{ service := &oauthtokentest.MockOauthTokenService{
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) { HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
hasEntryCalled = true hasEntryCalled = true
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
}, },

View File

@ -12,6 +12,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
@ -35,9 +36,9 @@ type Service struct {
} }
type OAuthTokenService interface { type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
TryTokenRefresh(context.Context, *login.UserAuth) error TryTokenRefresh(context.Context, *login.UserAuth) error
InvalidateOAuthTokens(context.Context, *login.UserAuth) error InvalidateOAuthTokens(context.Context, *login.UserAuth) error
} }
@ -52,20 +53,32 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
} }
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired. // GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token { func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
if usr == nil { if usr == nil || usr.IsNil() {
// No user, therefore no token // No user, therefore no token
return nil return nil
} }
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID} namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
return nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
logger.Error("failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
return nil
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil { if err != nil {
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way. // Not necessarily an error. User may be logged in another way.
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login) logger.Debug("no oauth token for user found", "userId", userID, "username", usr.GetLogin())
} else { } else {
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err) logger.Error("failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
} }
return nil return nil
} }
@ -88,20 +101,31 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
} }
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User // HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) { func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
if usr == nil { if usr == nil || usr.IsNil() {
// No user, therefore no token // No user, therefore no token
return nil, false, nil return nil, false, nil
} }
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID} namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
return nil, false, nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
return nil, false, err
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil { if err != nil {
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way. // Not necessarily an error. User may be logged in another way.
return nil, false, nil return nil, false, nil
} }
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err) logger.Error("failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
return nil, false, err return nil, false, err
} }
if !strings.Contains(authInfo.AuthModule, "oauth") { if !strings.Contains(authInfo.AuthModule, "oauth") {

View File

@ -40,7 +40,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
}, },
{ {
name: "returns false and an error in case GetAuthInfo returns an error", name: "returns false and an error in case GetAuthInfo returns an error",
user: &user.SignedInUser{}, user: &user.SignedInUser{UserID: 1},
want: nil, want: nil,
wantExist: false, wantExist: false,
wantErr: true, wantErr: true,
@ -48,7 +48,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
}, },
{ {
name: "returns false without an error in case auth entry is not found", name: "returns false without an error in case auth entry is not found",
user: &user.SignedInUser{}, user: &user.SignedInUser{UserID: 1},
want: nil, want: nil,
wantExist: false, wantExist: false,
wantErr: false, wantErr: false,
@ -56,7 +56,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
}, },
{ {
name: "returns false without an error in case the auth entry is not oauth", name: "returns false without an error in case the auth entry is not oauth",
user: &user.SignedInUser{}, user: &user.SignedInUser{UserID: 1},
want: nil, want: nil,
wantExist: false, wantExist: false,
wantErr: false, wantErr: false,
@ -64,7 +64,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
}, },
{ {
name: "returns true when the auth entry is found", name: "returns true when the auth entry is found",
user: &user.SignedInUser{}, user: &user.SignedInUser{UserID: 1},
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"}, want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
wantExist: true, wantExist: true,
wantErr: false, wantErr: false,
@ -72,6 +72,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
srv, authInfoStore, _ := setupOAuthTokenService(t) srv, authInfoStore, _ := setupOAuthTokenService(t)
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser

View File

@ -5,20 +5,20 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
) )
type MockOauthTokenService struct { type MockOauthTokenService struct {
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
} }
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token { func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
if m.GetCurrentOauthTokenFunc != nil { if m.GetCurrentOauthTokenFunc != nil {
return m.GetCurrentOauthTokenFunc(ctx, usr) return m.GetCurrentOauthTokenFunc(ctx, usr)
} }
@ -32,7 +32,7 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
return false return false
} }
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) { func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
if m.HasOAuthEntryFunc != nil { if m.HasOAuthEntryFunc != nil {
return m.HasOAuthEntryFunc(ctx, usr) return m.HasOAuthEntryFunc(ctx, usr)
} }

View File

@ -5,10 +5,10 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
) )
// Service an OAuth token service suitable for tests. // Service an OAuth token service suitable for tests.
@ -21,7 +21,7 @@ func ProvideService() *Service {
return &Service{} return &Service{}
} }
func (s *Service) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { func (s *Service) GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token {
return s.Token return s.Token
} }
@ -29,7 +29,7 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return oauthtoken.IsOAuthPassThruEnabled(ds) return oauthtoken.IsOAuthPassThruEnabled(ds)
} }
func (s *Service) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) { func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
return nil, false, nil return nil, false, nil
} }

View File

@ -156,7 +156,11 @@ func (u *SignedInUser) GetNamespacedID() (string, string) {
case u.IsAnonymous: case u.IsAnonymous:
return identity.NamespaceAnonymous, "" return identity.NamespaceAnonymous, ""
case u.AuthenticatedBy == "render": //import cycle render case u.AuthenticatedBy == "render": //import cycle render
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID) if u.UserID == 0 {
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID)
} else { // this should never happen as u.UserID > 0 already catches this
return identity.NamespaceUser, fmt.Sprintf("%d", u.UserID)
}
} }
// backwards compatibility // backwards compatibility