mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Use SigV4 lib from grafana-aws-sdk (#30713)
* replace with lib * remove test + apply feedback
This commit is contained in:
@@ -10,6 +10,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-aws-sdk/pkg/sigv4"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/metrics/metricutil"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -191,19 +193,19 @@ func (ds *DataSource) GetHttpTransport() (*dataSourceTransport, error) {
|
||||
func (ds *DataSource) sigV4Middleware(next http.RoundTripper) http.RoundTripper {
|
||||
decrypted := ds.DecryptedValues()
|
||||
|
||||
return &SigV4Middleware{
|
||||
Config: &Config{
|
||||
DatasourceType: ds.Type,
|
||||
AccessKey: decrypted["sigV4AccessKey"],
|
||||
SecretKey: decrypted["sigV4SecretKey"],
|
||||
Region: ds.JsonData.Get("sigV4Region").MustString(),
|
||||
AssumeRoleARN: ds.JsonData.Get("sigV4AssumeRoleArn").MustString(),
|
||||
AuthType: ds.JsonData.Get("sigV4AuthType").MustString(),
|
||||
ExternalID: ds.JsonData.Get("sigV4ExternalId").MustString(),
|
||||
Profile: ds.JsonData.Get("sigV4Profile").MustString(),
|
||||
return sigv4.New(
|
||||
&sigv4.Config{
|
||||
Service: awsServiceNamespace(ds.Type),
|
||||
AccessKey: decrypted["sigV4AccessKey"],
|
||||
SecretKey: decrypted["sigV4SecretKey"],
|
||||
Region: ds.JsonData.Get("sigV4Region").MustString(),
|
||||
AssumeRoleARN: ds.JsonData.Get("sigV4AssumeRoleArn").MustString(),
|
||||
AuthType: ds.JsonData.Get("sigV4AuthType").MustString(),
|
||||
ExternalID: ds.JsonData.Get("sigV4ExternalId").MustString(),
|
||||
Profile: ds.JsonData.Get("sigV4Profile").MustString(),
|
||||
},
|
||||
Next: next,
|
||||
}
|
||||
next,
|
||||
)
|
||||
}
|
||||
|
||||
func (ds *DataSource) GetTLSConfig() (*tls.Config, error) {
|
||||
@@ -319,3 +321,14 @@ func ClearDSDecryptionCache() {
|
||||
|
||||
dsDecryptionCache.cache = make(map[int64]cachedDecryptedJSON)
|
||||
}
|
||||
|
||||
func awsServiceNamespace(dsType string) string {
|
||||
switch dsType {
|
||||
case DS_ES, DS_ES_OPEN_DISTRO:
|
||||
return "es"
|
||||
case DS_PROMETHEUS:
|
||||
return "aps"
|
||||
default:
|
||||
panic(fmt.Sprintf("Unsupported datasource %q", dsType))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,30 +214,6 @@ func TestDataSource_GetHttpTransport(t *testing.T) {
|
||||
assert.Equal(t, "Ok", bodyStr)
|
||||
})
|
||||
|
||||
t.Run("Should use SigV4 in middleware chain if configured in JsonData", func(t *testing.T) {
|
||||
clearDSProxyCache(t)
|
||||
|
||||
origEnabled := setting.SigV4AuthEnabled
|
||||
setting.SigV4AuthEnabled = true
|
||||
t.Cleanup(func() { setting.SigV4AuthEnabled = origEnabled })
|
||||
|
||||
json, err := simplejson.NewJson([]byte(`{ "sigV4Auth": true }`))
|
||||
require.NoError(t, err)
|
||||
|
||||
ds := DataSource{
|
||||
JsonData: json,
|
||||
}
|
||||
|
||||
tr, err := ds.GetHttpTransport()
|
||||
require.NoError(t, err)
|
||||
|
||||
m1, ok := tr.next.(*SigV4Middleware)
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = m1.Next.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Should not include SigV4 middleware if not configured in JsonData", func(t *testing.T) {
|
||||
clearDSProxyCache(t)
|
||||
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/aws/aws-sdk-go/private/protocol/rest"
|
||||
)
|
||||
|
||||
type AuthType string
|
||||
|
||||
const (
|
||||
Default AuthType = "default"
|
||||
Keys AuthType = "keys"
|
||||
Credentials AuthType = "credentials"
|
||||
)
|
||||
|
||||
// Host header is likely not necessary here
|
||||
// (see https://github.com/golang/go/blob/cad6d1fef5147d31e94ee83934c8609d3ad150b7/src/net/http/request.go#L92)
|
||||
// but adding for completeness
|
||||
var permittedHeaders = map[string]struct{}{
|
||||
"Host": {},
|
||||
"Uber-Trace-Id": {},
|
||||
"User-Agent": {},
|
||||
"Accept": {},
|
||||
"Accept-Encoding": {},
|
||||
"Content-Type": {},
|
||||
"Content-Length": {},
|
||||
"securitytenant": {},
|
||||
"sgtenant": {},
|
||||
"kbn-xsrf": {},
|
||||
}
|
||||
|
||||
type SigV4Middleware struct {
|
||||
Config *Config
|
||||
Next http.RoundTripper
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AuthType string
|
||||
|
||||
Profile string
|
||||
|
||||
DatasourceType string
|
||||
|
||||
AccessKey string
|
||||
SecretKey string
|
||||
|
||||
AssumeRoleARN string
|
||||
ExternalID string
|
||||
Region string
|
||||
}
|
||||
|
||||
func (m *SigV4Middleware) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
_, err := m.signRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.Next == nil {
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
return m.Next.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (m *SigV4Middleware) signRequest(req *http.Request) (http.Header, error) {
|
||||
signer, err := m.signer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := replaceBody(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.Contains(req.URL.RawPath, "%2C") {
|
||||
req.URL.RawPath = rest.EscapePath(req.URL.RawPath, false)
|
||||
}
|
||||
|
||||
stripHeaders(req)
|
||||
|
||||
return signer.Sign(req, bytes.NewReader(body), awsServiceNamespace(m.Config.DatasourceType), m.Config.Region, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (m *SigV4Middleware) signer() (*v4.Signer, error) {
|
||||
authType := AuthType(m.Config.AuthType)
|
||||
|
||||
var c *credentials.Credentials
|
||||
switch authType {
|
||||
case Keys:
|
||||
c = credentials.NewStaticCredentials(m.Config.AccessKey, m.Config.SecretKey, "")
|
||||
case Credentials:
|
||||
c = credentials.NewSharedCredentials("", m.Config.Profile)
|
||||
case Default:
|
||||
// passing nil credentials will force AWS to allow a more complete credential chain vs the explicit default
|
||||
s, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String(m.Config.Region),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.Config.AssumeRoleARN != "" {
|
||||
return v4.NewSigner(stscreds.NewCredentials(s, m.Config.AssumeRoleARN)), nil
|
||||
}
|
||||
|
||||
return v4.NewSigner(s.Config.Credentials), nil
|
||||
case "":
|
||||
return nil, fmt.Errorf("invalid SigV4 auth type")
|
||||
}
|
||||
|
||||
if m.Config.AssumeRoleARN != "" {
|
||||
s, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String(m.Config.Region),
|
||||
Credentials: c},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v4.NewSigner(stscreds.NewCredentials(s, m.Config.AssumeRoleARN)), nil
|
||||
}
|
||||
|
||||
return v4.NewSigner(c), nil
|
||||
}
|
||||
|
||||
func replaceBody(req *http.Request) ([]byte, error) {
|
||||
if req.Body == nil {
|
||||
return []byte{}, nil
|
||||
}
|
||||
payload, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(payload))
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func awsServiceNamespace(dsType string) string {
|
||||
switch dsType {
|
||||
case DS_ES, DS_ES_OPEN_DISTRO:
|
||||
return "es"
|
||||
case DS_PROMETHEUS:
|
||||
return "aps"
|
||||
default:
|
||||
panic(fmt.Sprintf("Unsupported datasource %s", dsType))
|
||||
}
|
||||
}
|
||||
|
||||
func stripHeaders(req *http.Request) {
|
||||
for h := range req.Header {
|
||||
if _, exists := permittedHeaders[h]; !exists {
|
||||
req.Header.Del(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user