Auth: encrypt secrets for oauth providers in SSO settings API service (#79081)

encrypt secrets for oauth providers
This commit is contained in:
Mihai Doarna 2023-12-06 14:37:10 +02:00 committed by GitHub
parent c088d003f2
commit d7641b0ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 271 additions and 3 deletions

View File

@ -0,0 +1,175 @@
// Code generated by mockery v2.37.1. DO NOT EDIT.
package fakes
import (
context "context"
secrets "github.com/grafana/grafana/pkg/services/secrets"
mock "github.com/stretchr/testify/mock"
)
// MockService is an autogenerated mock type for the Service type
type MockService struct {
mock.Mock
}
// Decrypt provides a mock function with given fields: ctx, payload
func (_m *MockService) Decrypt(ctx context.Context, payload []byte) ([]byte, error) {
ret := _m.Called(ctx, payload)
var r0 []byte
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, []byte) ([]byte, error)); ok {
return rf(ctx, payload)
}
if rf, ok := ret.Get(0).(func(context.Context, []byte) []byte); ok {
r0 = rf(ctx, payload)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if rf, ok := ret.Get(1).(func(context.Context, []byte) error); ok {
r1 = rf(ctx, payload)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DecryptJsonData provides a mock function with given fields: ctx, sjd
func (_m *MockService) DecryptJsonData(ctx context.Context, sjd map[string][]byte) (map[string]string, error) {
ret := _m.Called(ctx, sjd)
var r0 map[string]string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) (map[string]string, error)); ok {
return rf(ctx, sjd)
}
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) map[string]string); ok {
r0 = rf(ctx, sjd)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, map[string][]byte) error); ok {
r1 = rf(ctx, sjd)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Encrypt provides a mock function with given fields: ctx, payload, opt
func (_m *MockService) Encrypt(ctx context.Context, payload []byte, opt secrets.EncryptionOptions) ([]byte, error) {
ret := _m.Called(ctx, payload, opt)
var r0 []byte
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) ([]byte, error)); ok {
return rf(ctx, payload, opt)
}
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) []byte); ok {
r0 = rf(ctx, payload, opt)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if rf, ok := ret.Get(1).(func(context.Context, []byte, secrets.EncryptionOptions) error); ok {
r1 = rf(ctx, payload, opt)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// EncryptJsonData provides a mock function with given fields: ctx, kv, opt
func (_m *MockService) EncryptJsonData(ctx context.Context, kv map[string]string, opt secrets.EncryptionOptions) (map[string][]byte, error) {
ret := _m.Called(ctx, kv, opt)
var r0 map[string][]byte
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) (map[string][]byte, error)); ok {
return rf(ctx, kv, opt)
}
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) map[string][]byte); ok {
r0 = rf(ctx, kv, opt)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]byte)
}
}
if rf, ok := ret.Get(1).(func(context.Context, map[string]string, secrets.EncryptionOptions) error); ok {
r1 = rf(ctx, kv, opt)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetDecryptedValue provides a mock function with given fields: ctx, sjd, key, fallback
func (_m *MockService) GetDecryptedValue(ctx context.Context, sjd map[string][]byte, key string, fallback string) string {
ret := _m.Called(ctx, sjd, key, fallback)
var r0 string
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte, string, string) string); ok {
r0 = rf(ctx, sjd, key, fallback)
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// ReEncryptDataKeys provides a mock function with given fields: ctx
func (_m *MockService) ReEncryptDataKeys(ctx context.Context) error {
ret := _m.Called(ctx)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// RotateDataKeys provides a mock function with given fields: ctx
func (_m *MockService) RotateDataKeys(ctx context.Context) error {
ret := _m.Called(ctx)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockService(t interface {
mock.TestingT
Cleanup(func())
}) *MockService {
mock := &MockService{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -9,6 +9,8 @@ import (
// Service is an envelope encryption service in charge of encrypting/decrypting secrets.
// It is a replacement for encryption.Service
//
//go:generate mockery --name Service --structname MockService --outpkg fakes --filename mock_service.go --output ./fakes/
type Service interface {
// Encrypt MUST NOT be used within database transactions, it may cause database locks.
// For those specific use cases where the encryption operation cannot be moved outside

View File

@ -11,6 +11,7 @@ import (
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/auth/identity"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/ssosettings"
"github.com/grafana/grafana/pkg/services/ssosettings/api"
"github.com/grafana/grafana/pkg/services/ssosettings/database"
@ -27,10 +28,12 @@ type SSOSettingsService struct {
store ssosettings.Store
ac ac.AccessControl
fbStrategies []ssosettings.FallbackStrategy
secrets secrets.Service
}
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager) *SSOSettingsService {
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager,
secrets secrets.Service) *SSOSettingsService {
strategies := []ssosettings.FallbackStrategy{
strategies.NewOAuthStrategy(cfg),
// register other strategies here, for example SAML
@ -44,6 +47,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
store: store,
ac: ac,
fbStrategies: strategies,
secrets: secrets,
}
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
@ -114,6 +118,15 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
// TODO: validation (configurable provider? Contains the required fields? etc)
if isOAuthProvider(settings.Provider) {
encryptedClientSecret, err := s.secrets.Encrypt(ctx, []byte(settings.OAuthSettings.ClientSecret), secrets.WithoutScope())
if err != nil {
return err
}
settings.OAuthSettings.ClientSecret = string(encryptedClientSecret)
}
err := s.store.Upsert(ctx, settings)
if err != nil {
return err
@ -182,3 +195,13 @@ func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosetting
}
return nil, false
}
func isOAuthProvider(provider string) bool {
for _, oAuthProvider := range ssosettings.AllOAuthProviders {
if oAuthProvider == provider {
return true
}
}
return false
}

View File

@ -6,17 +6,20 @@ import (
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
"github.com/grafana/grafana/pkg/services/auth/identity"
secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/ssosettings"
"github.com/grafana/grafana/pkg/services/ssosettings/models"
"github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
func TestSSOSettingsService_GetForProvider(t *testing.T) {
@ -308,6 +311,66 @@ func TestSSOSettingsService_List(t *testing.T) {
}
}
func TestSSOSettingsService_Upsert(t *testing.T) {
t.Run("successfully upsert SSO settings", func(t *testing.T) {
env := setupTestEnv(t)
settings := models.SSOSettings{
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Enabled: true,
},
IsDeleted: false,
}
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
err := env.service.Upsert(context.Background(), settings)
require.NoError(t, err)
})
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
env := setupTestEnv(t)
settings := models.SSOSettings{
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Enabled: true,
},
IsDeleted: false,
}
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
err := env.service.Upsert(context.Background(), settings)
require.Error(t, err)
})
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
env := setupTestEnv(t)
settings := models.SSOSettings{
Provider: "azuread",
OAuthSettings: &social.OAuthInfo{
ClientId: "client-id",
ClientSecret: "client-secret",
Enabled: true,
},
IsDeleted: false,
}
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
env.store.ExpectedError = errors.New("upsert failed")
err := env.service.Upsert(context.Background(), settings)
require.Error(t, err)
})
}
func TestSSOSettingsService_Delete(t *testing.T) {
t.Run("successfully delete SSO settings", func(t *testing.T) {
env := setupTestEnv(t)
@ -345,19 +408,23 @@ func TestSSOSettingsService_Delete(t *testing.T) {
func setupTestEnv(t *testing.T) testEnv {
store := ssosettingstests.NewFakeStore()
fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy()
secrets := secretsFakes.NewMockService(t)
accessControl := acimpl.ProvideAccessControl(setting.NewCfg())
svc := &SSOSettingsService{
log: log.NewNopLogger(),
store: store,
ac: accessControl,
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
secrets: secrets,
}
return testEnv{
service: svc,
store: store,
ac: accessControl,
fallbackStrategy: fallbackStrategy,
secrets: secrets,
}
}
@ -366,4 +433,5 @@ type testEnv struct {
store *ssosettingstests.FakeStore
ac accesscontrol.AccessControl
fallbackStrategy *ssosettingstests.FakeFallbackStrategy
secrets *secretsFakes.MockService
}