mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: encrypt secrets for oauth providers in SSO settings API service (#79081)
encrypt secrets for oauth providers
This commit is contained in:
parent
c088d003f2
commit
d7641b0ecb
175
pkg/services/secrets/fakes/mock_service.go
Normal file
175
pkg/services/secrets/fakes/mock_service.go
Normal file
@ -0,0 +1,175 @@
|
||||
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||
|
||||
package fakes
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
secrets "github.com/grafana/grafana/pkg/services/secrets"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockService is an autogenerated mock type for the Service type
|
||||
type MockService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Decrypt provides a mock function with given fields: ctx, payload
|
||||
func (_m *MockService) Decrypt(ctx context.Context, payload []byte) ([]byte, error) {
|
||||
ret := _m.Called(ctx, payload)
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []byte) ([]byte, error)); ok {
|
||||
return rf(ctx, payload)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []byte) []byte); ok {
|
||||
r0 = rf(ctx, payload)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, []byte) error); ok {
|
||||
r1 = rf(ctx, payload)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// DecryptJsonData provides a mock function with given fields: ctx, sjd
|
||||
func (_m *MockService) DecryptJsonData(ctx context.Context, sjd map[string][]byte) (map[string]string, error) {
|
||||
ret := _m.Called(ctx, sjd)
|
||||
|
||||
var r0 map[string]string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) (map[string]string, error)); ok {
|
||||
return rf(ctx, sjd)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) map[string]string); ok {
|
||||
r0 = rf(ctx, sjd)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, map[string][]byte) error); ok {
|
||||
r1 = rf(ctx, sjd)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Encrypt provides a mock function with given fields: ctx, payload, opt
|
||||
func (_m *MockService) Encrypt(ctx context.Context, payload []byte, opt secrets.EncryptionOptions) ([]byte, error) {
|
||||
ret := _m.Called(ctx, payload, opt)
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) ([]byte, error)); ok {
|
||||
return rf(ctx, payload, opt)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) []byte); ok {
|
||||
r0 = rf(ctx, payload, opt)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, []byte, secrets.EncryptionOptions) error); ok {
|
||||
r1 = rf(ctx, payload, opt)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// EncryptJsonData provides a mock function with given fields: ctx, kv, opt
|
||||
func (_m *MockService) EncryptJsonData(ctx context.Context, kv map[string]string, opt secrets.EncryptionOptions) (map[string][]byte, error) {
|
||||
ret := _m.Called(ctx, kv, opt)
|
||||
|
||||
var r0 map[string][]byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) (map[string][]byte, error)); ok {
|
||||
return rf(ctx, kv, opt)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) map[string][]byte); ok {
|
||||
r0 = rf(ctx, kv, opt)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string][]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, map[string]string, secrets.EncryptionOptions) error); ok {
|
||||
r1 = rf(ctx, kv, opt)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetDecryptedValue provides a mock function with given fields: ctx, sjd, key, fallback
|
||||
func (_m *MockService) GetDecryptedValue(ctx context.Context, sjd map[string][]byte, key string, fallback string) string {
|
||||
ret := _m.Called(ctx, sjd, key, fallback)
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte, string, string) string); ok {
|
||||
r0 = rf(ctx, sjd, key, fallback)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ReEncryptDataKeys provides a mock function with given fields: ctx
|
||||
func (_m *MockService) ReEncryptDataKeys(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// RotateDataKeys provides a mock function with given fields: ctx
|
||||
func (_m *MockService) RotateDataKeys(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockService {
|
||||
mock := &MockService{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
@ -9,6 +9,8 @@ import (
|
||||
|
||||
// Service is an envelope encryption service in charge of encrypting/decrypting secrets.
|
||||
// It is a replacement for encryption.Service
|
||||
//
|
||||
//go:generate mockery --name Service --structname MockService --outpkg fakes --filename mock_service.go --output ./fakes/
|
||||
type Service interface {
|
||||
// Encrypt MUST NOT be used within database transactions, it may cause database locks.
|
||||
// For those specific use cases where the encryption operation cannot be moved outside
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/secrets"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/api"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/database"
|
||||
@ -27,10 +28,12 @@ type SSOSettingsService struct {
|
||||
store ssosettings.Store
|
||||
ac ac.AccessControl
|
||||
fbStrategies []ssosettings.FallbackStrategy
|
||||
secrets secrets.Service
|
||||
}
|
||||
|
||||
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
||||
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager) *SSOSettingsService {
|
||||
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager,
|
||||
secrets secrets.Service) *SSOSettingsService {
|
||||
strategies := []ssosettings.FallbackStrategy{
|
||||
strategies.NewOAuthStrategy(cfg),
|
||||
// register other strategies here, for example SAML
|
||||
@ -44,6 +47,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
||||
store: store,
|
||||
ac: ac,
|
||||
fbStrategies: strategies,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
||||
@ -114,6 +118,15 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
|
||||
|
||||
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
// TODO: validation (configurable provider? Contains the required fields? etc)
|
||||
|
||||
if isOAuthProvider(settings.Provider) {
|
||||
encryptedClientSecret, err := s.secrets.Encrypt(ctx, []byte(settings.OAuthSettings.ClientSecret), secrets.WithoutScope())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settings.OAuthSettings.ClientSecret = string(encryptedClientSecret)
|
||||
}
|
||||
|
||||
err := s.store.Upsert(ctx, settings)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -182,3 +195,13 @@ func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosetting
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func isOAuthProvider(provider string) bool {
|
||||
for _, oAuthProvider := range ssosettings.AllOAuthProviders {
|
||||
if oAuthProvider == provider {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -6,17 +6,20 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||
secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSOSettingsService_GetForProvider(t *testing.T) {
|
||||
@ -308,6 +311,66 @@ func TestSSOSettingsService_List(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
t.Run("successfully upsert SSO settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
ClientId: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
Enabled: true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
ClientId: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
Enabled: true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
OAuthSettings: &social.OAuthInfo{
|
||||
ClientId: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
Enabled: true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
env.store.ExpectedError = errors.New("upsert failed")
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSOSettingsService_Delete(t *testing.T) {
|
||||
t.Run("successfully delete SSO settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
@ -345,19 +408,23 @@ func TestSSOSettingsService_Delete(t *testing.T) {
|
||||
func setupTestEnv(t *testing.T) testEnv {
|
||||
store := ssosettingstests.NewFakeStore()
|
||||
fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy()
|
||||
|
||||
secrets := secretsFakes.NewMockService(t)
|
||||
accessControl := acimpl.ProvideAccessControl(setting.NewCfg())
|
||||
|
||||
svc := &SSOSettingsService{
|
||||
log: log.NewNopLogger(),
|
||||
store: store,
|
||||
ac: accessControl,
|
||||
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
|
||||
secrets: secrets,
|
||||
}
|
||||
|
||||
return testEnv{
|
||||
service: svc,
|
||||
store: store,
|
||||
ac: accessControl,
|
||||
fallbackStrategy: fallbackStrategy,
|
||||
secrets: secrets,
|
||||
}
|
||||
}
|
||||
|
||||
@ -366,4 +433,5 @@ type testEnv struct {
|
||||
store *ssosettingstests.FakeStore
|
||||
ac accesscontrol.AccessControl
|
||||
fallbackStrategy *ssosettingstests.FakeFallbackStrategy
|
||||
secrets *secretsFakes.MockService
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user