mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: encrypt secrets for oauth providers in SSO settings API service (#79081)
encrypt secrets for oauth providers
This commit is contained in:
parent
c088d003f2
commit
d7641b0ecb
pkg/services
175
pkg/services/secrets/fakes/mock_service.go
Normal file
175
pkg/services/secrets/fakes/mock_service.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||||
|
|
||||||
|
package fakes
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
|
||||||
|
secrets "github.com/grafana/grafana/pkg/services/secrets"
|
||||||
|
mock "github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockService is an autogenerated mock type for the Service type
|
||||||
|
type MockService struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt provides a mock function with given fields: ctx, payload
|
||||||
|
func (_m *MockService) Decrypt(ctx context.Context, payload []byte) ([]byte, error) {
|
||||||
|
ret := _m.Called(ctx, payload)
|
||||||
|
|
||||||
|
var r0 []byte
|
||||||
|
var r1 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, []byte) ([]byte, error)); ok {
|
||||||
|
return rf(ctx, payload)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, []byte) []byte); ok {
|
||||||
|
r0 = rf(ctx, payload)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).([]byte)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, []byte) error); ok {
|
||||||
|
r1 = rf(ctx, payload)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptJsonData provides a mock function with given fields: ctx, sjd
|
||||||
|
func (_m *MockService) DecryptJsonData(ctx context.Context, sjd map[string][]byte) (map[string]string, error) {
|
||||||
|
ret := _m.Called(ctx, sjd)
|
||||||
|
|
||||||
|
var r0 map[string]string
|
||||||
|
var r1 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) (map[string]string, error)); ok {
|
||||||
|
return rf(ctx, sjd)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte) map[string]string); ok {
|
||||||
|
r0 = rf(ctx, sjd)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(map[string]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, map[string][]byte) error); ok {
|
||||||
|
r1 = rf(ctx, sjd)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt provides a mock function with given fields: ctx, payload, opt
|
||||||
|
func (_m *MockService) Encrypt(ctx context.Context, payload []byte, opt secrets.EncryptionOptions) ([]byte, error) {
|
||||||
|
ret := _m.Called(ctx, payload, opt)
|
||||||
|
|
||||||
|
var r0 []byte
|
||||||
|
var r1 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) ([]byte, error)); ok {
|
||||||
|
return rf(ctx, payload, opt)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, []byte, secrets.EncryptionOptions) []byte); ok {
|
||||||
|
r0 = rf(ctx, payload, opt)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).([]byte)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, []byte, secrets.EncryptionOptions) error); ok {
|
||||||
|
r1 = rf(ctx, payload, opt)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptJsonData provides a mock function with given fields: ctx, kv, opt
|
||||||
|
func (_m *MockService) EncryptJsonData(ctx context.Context, kv map[string]string, opt secrets.EncryptionOptions) (map[string][]byte, error) {
|
||||||
|
ret := _m.Called(ctx, kv, opt)
|
||||||
|
|
||||||
|
var r0 map[string][]byte
|
||||||
|
var r1 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) (map[string][]byte, error)); ok {
|
||||||
|
return rf(ctx, kv, opt)
|
||||||
|
}
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, map[string]string, secrets.EncryptionOptions) map[string][]byte); ok {
|
||||||
|
r0 = rf(ctx, kv, opt)
|
||||||
|
} else {
|
||||||
|
if ret.Get(0) != nil {
|
||||||
|
r0 = ret.Get(0).(map[string][]byte)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rf, ok := ret.Get(1).(func(context.Context, map[string]string, secrets.EncryptionOptions) error); ok {
|
||||||
|
r1 = rf(ctx, kv, opt)
|
||||||
|
} else {
|
||||||
|
r1 = ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDecryptedValue provides a mock function with given fields: ctx, sjd, key, fallback
|
||||||
|
func (_m *MockService) GetDecryptedValue(ctx context.Context, sjd map[string][]byte, key string, fallback string) string {
|
||||||
|
ret := _m.Called(ctx, sjd, key, fallback)
|
||||||
|
|
||||||
|
var r0 string
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context, map[string][]byte, string, string) string); ok {
|
||||||
|
r0 = rf(ctx, sjd, key, fallback)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReEncryptDataKeys provides a mock function with given fields: ctx
|
||||||
|
func (_m *MockService) ReEncryptDataKeys(ctx context.Context) error {
|
||||||
|
ret := _m.Called(ctx)
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||||
|
r0 = rf(ctx)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateDataKeys provides a mock function with given fields: ctx
|
||||||
|
func (_m *MockService) RotateDataKeys(ctx context.Context) error {
|
||||||
|
ret := _m.Called(ctx)
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||||
|
r0 = rf(ctx)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockService creates a new instance of MockService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
// The first argument is typically a *testing.T value.
|
||||||
|
func NewMockService(t interface {
|
||||||
|
mock.TestingT
|
||||||
|
Cleanup(func())
|
||||||
|
}) *MockService {
|
||||||
|
mock := &MockService{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
// Service is an envelope encryption service in charge of encrypting/decrypting secrets.
|
// Service is an envelope encryption service in charge of encrypting/decrypting secrets.
|
||||||
// It is a replacement for encryption.Service
|
// It is a replacement for encryption.Service
|
||||||
|
//
|
||||||
|
//go:generate mockery --name Service --structname MockService --outpkg fakes --filename mock_service.go --output ./fakes/
|
||||||
type Service interface {
|
type Service interface {
|
||||||
// Encrypt MUST NOT be used within database transactions, it may cause database locks.
|
// Encrypt MUST NOT be used within database transactions, it may cause database locks.
|
||||||
// For those specific use cases where the encryption operation cannot be moved outside
|
// For those specific use cases where the encryption operation cannot be moved outside
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
|
ac "github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
|
"github.com/grafana/grafana/pkg/services/secrets"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings/api"
|
"github.com/grafana/grafana/pkg/services/ssosettings/api"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings/database"
|
"github.com/grafana/grafana/pkg/services/ssosettings/database"
|
||||||
@ -27,10 +28,12 @@ type SSOSettingsService struct {
|
|||||||
store ssosettings.Store
|
store ssosettings.Store
|
||||||
ac ac.AccessControl
|
ac ac.AccessControl
|
||||||
fbStrategies []ssosettings.FallbackStrategy
|
fbStrategies []ssosettings.FallbackStrategy
|
||||||
|
secrets secrets.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
||||||
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager) *SSOSettingsService {
|
routeRegister routing.RouteRegister, features *featuremgmt.FeatureManager,
|
||||||
|
secrets secrets.Service) *SSOSettingsService {
|
||||||
strategies := []ssosettings.FallbackStrategy{
|
strategies := []ssosettings.FallbackStrategy{
|
||||||
strategies.NewOAuthStrategy(cfg),
|
strategies.NewOAuthStrategy(cfg),
|
||||||
// register other strategies here, for example SAML
|
// register other strategies here, for example SAML
|
||||||
@ -44,6 +47,7 @@ func ProvideService(cfg *setting.Cfg, sqlStore db.DB, ac ac.AccessControl,
|
|||||||
store: store,
|
store: store,
|
||||||
ac: ac,
|
ac: ac,
|
||||||
fbStrategies: strategies,
|
fbStrategies: strategies,
|
||||||
|
secrets: secrets,
|
||||||
}
|
}
|
||||||
|
|
||||||
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
if features.IsEnabledGlobally(featuremgmt.FlagSsoSettingsApi) {
|
||||||
@ -114,6 +118,15 @@ func (s *SSOSettingsService) List(ctx context.Context, requester identity.Reques
|
|||||||
|
|
||||||
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||||
// TODO: validation (configurable provider? Contains the required fields? etc)
|
// TODO: validation (configurable provider? Contains the required fields? etc)
|
||||||
|
|
||||||
|
if isOAuthProvider(settings.Provider) {
|
||||||
|
encryptedClientSecret, err := s.secrets.Encrypt(ctx, []byte(settings.OAuthSettings.ClientSecret), secrets.WithoutScope())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
settings.OAuthSettings.ClientSecret = string(encryptedClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
err := s.store.Upsert(ctx, settings)
|
err := s.store.Upsert(ctx, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -182,3 +195,13 @@ func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosetting
|
|||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOAuthProvider(provider string) bool {
|
||||||
|
for _, oAuthProvider := range ssosettings.AllOAuthProviders {
|
||||||
|
if oAuthProvider == provider {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -6,17 +6,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/identity"
|
"github.com/grafana/grafana/pkg/services/auth/identity"
|
||||||
|
secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings/models"
|
"github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests"
|
"github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSSOSettingsService_GetForProvider(t *testing.T) {
|
func TestSSOSettingsService_GetForProvider(t *testing.T) {
|
||||||
@ -308,6 +311,66 @@ func TestSSOSettingsService_List(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||||
|
t.Run("successfully upsert SSO settings", func(t *testing.T) {
|
||||||
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
|
settings := models.SSOSettings{
|
||||||
|
Provider: "azuread",
|
||||||
|
OAuthSettings: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
IsDeleted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||||
|
|
||||||
|
err := env.service.Upsert(context.Background(), settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
|
||||||
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
|
settings := models.SSOSettings{
|
||||||
|
Provider: "azuread",
|
||||||
|
OAuthSettings: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
IsDeleted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
|
||||||
|
|
||||||
|
err := env.service.Upsert(context.Background(), settings)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
|
||||||
|
env := setupTestEnv(t)
|
||||||
|
|
||||||
|
settings := models.SSOSettings{
|
||||||
|
Provider: "azuread",
|
||||||
|
OAuthSettings: &social.OAuthInfo{
|
||||||
|
ClientId: "client-id",
|
||||||
|
ClientSecret: "client-secret",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
IsDeleted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
env.secrets.On("Encrypt", mock.Anything, []byte(settings.OAuthSettings.ClientSecret), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||||
|
env.store.ExpectedError = errors.New("upsert failed")
|
||||||
|
|
||||||
|
err := env.service.Upsert(context.Background(), settings)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSSOSettingsService_Delete(t *testing.T) {
|
func TestSSOSettingsService_Delete(t *testing.T) {
|
||||||
t.Run("successfully delete SSO settings", func(t *testing.T) {
|
t.Run("successfully delete SSO settings", func(t *testing.T) {
|
||||||
env := setupTestEnv(t)
|
env := setupTestEnv(t)
|
||||||
@ -345,19 +408,23 @@ func TestSSOSettingsService_Delete(t *testing.T) {
|
|||||||
func setupTestEnv(t *testing.T) testEnv {
|
func setupTestEnv(t *testing.T) testEnv {
|
||||||
store := ssosettingstests.NewFakeStore()
|
store := ssosettingstests.NewFakeStore()
|
||||||
fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy()
|
fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy()
|
||||||
|
secrets := secretsFakes.NewMockService(t)
|
||||||
accessControl := acimpl.ProvideAccessControl(setting.NewCfg())
|
accessControl := acimpl.ProvideAccessControl(setting.NewCfg())
|
||||||
|
|
||||||
svc := &SSOSettingsService{
|
svc := &SSOSettingsService{
|
||||||
log: log.NewNopLogger(),
|
log: log.NewNopLogger(),
|
||||||
store: store,
|
store: store,
|
||||||
ac: accessControl,
|
ac: accessControl,
|
||||||
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
|
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
|
||||||
|
secrets: secrets,
|
||||||
}
|
}
|
||||||
|
|
||||||
return testEnv{
|
return testEnv{
|
||||||
service: svc,
|
service: svc,
|
||||||
store: store,
|
store: store,
|
||||||
ac: accessControl,
|
ac: accessControl,
|
||||||
fallbackStrategy: fallbackStrategy,
|
fallbackStrategy: fallbackStrategy,
|
||||||
|
secrets: secrets,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,4 +433,5 @@ type testEnv struct {
|
|||||||
store *ssosettingstests.FakeStore
|
store *ssosettingstests.FakeStore
|
||||||
ac accesscontrol.AccessControl
|
ac accesscontrol.AccessControl
|
||||||
fallbackStrategy *ssosettingstests.FakeFallbackStrategy
|
fallbackStrategy *ssosettingstests.FakeFallbackStrategy
|
||||||
|
secrets *secretsFakes.MockService
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user