Chore: Port oauth token service to identity requester (#73988)

* port oauth token service to identity requester

* fix broken test

* no need to check for render
This commit is contained in:
Jo 2023-08-29 11:55:58 +02:00 committed by GitHub
parent 7c98678188
commit fe1563882a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 30 deletions

View File

@ -27,6 +27,7 @@ import (
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/plugins"
acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock"
"github.com/grafana/grafana/pkg/services/auth/identity"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/datasources"
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
@ -1118,7 +1119,7 @@ type mockOAuthTokenService struct {
oAuthEnabled bool
}
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *user.SignedInUser) *oauth2.Token {
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user identity.Requester) *oauth2.Token {
return m.token
}
@ -1126,7 +1127,7 @@ func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
return m.oAuthEnabled
}
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) {
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
return nil, false, nil
}

View File

@ -9,18 +9,19 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/socialtest"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
@ -117,7 +118,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
)
service := &oauthtokentest.MockOauthTokenService{
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
hasEntryCalled = true
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
},

View File

@ -12,6 +12,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
@ -35,9 +36,9 @@ type Service struct {
}
type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error)
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
TryTokenRefresh(context.Context, *login.UserAuth) error
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
}
@ -52,20 +53,32 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
}
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
if usr == nil {
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
if usr == nil || usr.IsNil() {
// No user, therefore no token
return nil
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
return nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
logger.Error("failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
return nil
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
logger.Debug("no oauth token for user found", "userId", userID, "username", usr.GetLogin())
} else {
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
logger.Error("failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
}
return nil
}
@ -88,20 +101,31 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
}
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
if usr == nil {
func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
if usr == nil || usr.IsNil() {
// No user, therefore no token
return nil, false, nil
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
namespace, id := usr.GetNamespacedID()
if namespace != identity.NamespaceUser {
// Not a user, therefore no token.
return nil, false, nil
}
userID, err := identity.IntIdentifier(namespace, id)
if err != nil {
return nil, false, err
}
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
return nil, false, nil
}
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
logger.Error("failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
return nil, false, err
}
if !strings.Contains(authInfo.AuthModule, "oauth") {

View File

@ -40,7 +40,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
},
{
name: "returns false and an error in case GetAuthInfo returns an error",
user: &user.SignedInUser{},
user: &user.SignedInUser{UserID: 1},
want: nil,
wantExist: false,
wantErr: true,
@ -48,7 +48,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
},
{
name: "returns false without an error in case auth entry is not found",
user: &user.SignedInUser{},
user: &user.SignedInUser{UserID: 1},
want: nil,
wantExist: false,
wantErr: false,
@ -56,7 +56,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
},
{
name: "returns false without an error in case the auth entry is not oauth",
user: &user.SignedInUser{},
user: &user.SignedInUser{UserID: 1},
want: nil,
wantExist: false,
wantErr: false,
@ -64,7 +64,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
},
{
name: "returns true when the auth entry is found",
user: &user.SignedInUser{},
user: &user.SignedInUser{UserID: 1},
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
wantExist: true,
wantErr: false,
@ -72,6 +72,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
srv, authInfoStore, _ := setupOAuthTokenService(t)
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser

View File

@ -5,20 +5,20 @@ import (
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
)
type MockOauthTokenService struct {
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error)
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
}
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
if m.GetCurrentOauthTokenFunc != nil {
return m.GetCurrentOauthTokenFunc(ctx, usr)
}
@ -32,7 +32,7 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
return false
}
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
if m.HasOAuthEntryFunc != nil {
return m.HasOAuthEntryFunc(ctx, usr)
}

View File

@ -5,10 +5,10 @@ import (
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
)
// Service an OAuth token service suitable for tests.
@ -21,7 +21,7 @@ func ProvideService() *Service {
return &Service{}
}
func (s *Service) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
func (s *Service) GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token {
return s.Token
}
@ -29,7 +29,7 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return oauthtoken.IsOAuthPassThruEnabled(ds)
}
func (s *Service) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) {
func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
return nil, false, nil
}

View File

@ -156,7 +156,11 @@ func (u *SignedInUser) GetNamespacedID() (string, string) {
case u.IsAnonymous:
return identity.NamespaceAnonymous, ""
case u.AuthenticatedBy == "render": //import cycle render
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID)
if u.UserID == 0 {
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID)
} else { // this should never happen as u.UserID > 0 already catches this
return identity.NamespaceUser, fmt.Sprintf("%d", u.UserID)
}
}
// backwards compatibility