mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: Port oauth token service to identity requester (#73988)
* port oauth token service to identity requester * fix broken test * no need to check for render
This commit is contained in:
parent
7c98678188
commit
fe1563882a
@ -27,6 +27,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
acmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
|
||||
@ -1118,7 +1119,7 @@ type mockOAuthTokenService struct {
|
||||
oAuthEnabled bool
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *user.SignedInUser) *oauth2.Token {
|
||||
func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user identity.Requester) *oauth2.Token {
|
||||
return m.token
|
||||
}
|
||||
|
||||
@ -1126,7 +1127,7 @@ func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
|
||||
return m.oAuthEnabled
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) {
|
||||
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
|
@ -9,18 +9,19 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
@ -117,7 +118,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
)
|
||||
|
||||
service := &oauthtokentest.MockOauthTokenService{
|
||||
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
|
||||
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
||||
hasEntryCalled = true
|
||||
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||
},
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
@ -35,9 +36,9 @@ type Service struct {
|
||||
}
|
||||
|
||||
type OAuthTokenService interface {
|
||||
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
|
||||
GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token
|
||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||
HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error)
|
||||
HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error)
|
||||
TryTokenRefresh(context.Context, *login.UserAuth) error
|
||||
InvalidateOAuthTokens(context.Context, *login.UserAuth) error
|
||||
}
|
||||
@ -52,20 +53,32 @@ func ProvideService(socialService social.Service, authInfoService login.AuthInfo
|
||||
}
|
||||
|
||||
// GetCurrentOAuthToken returns the OAuth token, if any, for the authenticated user. Will try to refresh the token if it has expired.
|
||||
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
|
||||
if usr == nil {
|
||||
func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||
if usr == nil || usr.IsNil() {
|
||||
// No user, therefore no token
|
||||
return nil
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
|
||||
namespace, id := usr.GetNamespacedID()
|
||||
if namespace != identity.NamespaceUser {
|
||||
// Not a user, therefore no token.
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
logger.Error("failed to convert user id to int", "namespace", namespace, "userId", id, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
|
||||
logger.Debug("no oauth token for user found", "userId", userID, "username", usr.GetLogin())
|
||||
} else {
|
||||
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
logger.Error("failed to get oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -88,20 +101,31 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
}
|
||||
|
||||
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
|
||||
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
|
||||
if usr == nil {
|
||||
func (o *Service) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
||||
if usr == nil || usr.IsNil() {
|
||||
// No user, therefore no token
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: usr.UserID}
|
||||
namespace, id := usr.GetNamespacedID()
|
||||
if namespace != identity.NamespaceUser {
|
||||
// Not a user, therefore no token.
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
userID, err := identity.IntIdentifier(namespace, id)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
authInfoQuery := &login.GetAuthInfoQuery{UserId: userID}
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
return nil, false, nil
|
||||
}
|
||||
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
logger.Error("failed to fetch oauth token for user", "userId", userID, "username", usr.GetLogin(), "error", err)
|
||||
return nil, false, err
|
||||
}
|
||||
if !strings.Contains(authInfo.AuthModule, "oauth") {
|
||||
|
@ -40,7 +40,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "returns false and an error in case GetAuthInfo returns an error",
|
||||
user: &user.SignedInUser{},
|
||||
user: &user.SignedInUser{UserID: 1},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: true,
|
||||
@ -48,7 +48,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "returns false without an error in case auth entry is not found",
|
||||
user: &user.SignedInUser{},
|
||||
user: &user.SignedInUser{UserID: 1},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: false,
|
||||
@ -56,7 +56,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "returns false without an error in case the auth entry is not oauth",
|
||||
user: &user.SignedInUser{},
|
||||
user: &user.SignedInUser{UserID: 1},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: false,
|
||||
@ -64,7 +64,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "returns true when the auth entry is found",
|
||||
user: &user.SignedInUser{},
|
||||
user: &user.SignedInUser{UserID: 1},
|
||||
want: &login.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||
wantExist: true,
|
||||
wantErr: false,
|
||||
@ -72,6 +72,7 @@ func TestService_HasOAuthEntry(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv, authInfoStore, _ := setupOAuthTokenService(t)
|
||||
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser
|
||||
|
@ -5,20 +5,20 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
type MockOauthTokenService struct {
|
||||
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token
|
||||
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester) *oauth2.Token
|
||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error)
|
||||
HasOAuthEntryFunc func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error)
|
||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr *login.UserAuth) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr *login.UserAuth) error
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
|
||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester) *oauth2.Token {
|
||||
if m.GetCurrentOauthTokenFunc != nil {
|
||||
return m.GetCurrentOauthTokenFunc(ctx, usr)
|
||||
}
|
||||
@ -32,7 +32,7 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*login.UserAuth, bool, error) {
|
||||
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
||||
if m.HasOAuthEntryFunc != nil {
|
||||
return m.HasOAuthEntryFunc(ctx, usr)
|
||||
}
|
||||
|
@ -5,10 +5,10 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
// Service an OAuth token service suitable for tests.
|
||||
@ -21,7 +21,7 @@ func ProvideService() *Service {
|
||||
return &Service{}
|
||||
}
|
||||
|
||||
func (s *Service) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
|
||||
func (s *Service) GetCurrentOAuthToken(context.Context, identity.Requester) *oauth2.Token {
|
||||
return s.Token
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
return oauthtoken.IsOAuthPassThruEnabled(ds)
|
||||
}
|
||||
|
||||
func (s *Service) HasOAuthEntry(context.Context, *user.SignedInUser) (*login.UserAuth, bool, error) {
|
||||
func (s *Service) HasOAuthEntry(context.Context, identity.Requester) (*login.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,11 @@ func (u *SignedInUser) GetNamespacedID() (string, string) {
|
||||
case u.IsAnonymous:
|
||||
return identity.NamespaceAnonymous, ""
|
||||
case u.AuthenticatedBy == "render": //import cycle render
|
||||
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID)
|
||||
if u.UserID == 0 {
|
||||
return identity.NamespaceRenderService, fmt.Sprintf("%d", u.UserID)
|
||||
} else { // this should never happen as u.UserID > 0 already catches this
|
||||
return identity.NamespaceUser, fmt.Sprintf("%d", u.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// backwards compatibility
|
||||
|
Loading…
Reference in New Issue
Block a user