grafana/pkg/models/sigv4.go
Will Browne eba046d3cb
Auth: Enable more complete credential chain for SigV4 default SDK auth option (#29065)
* Force more complete credential chain for default auth option

* simplify

* allow assume role for default
2020-11-13 10:27:35 +01:00

154 lines
3.4 KiB
Go

package models
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
type AuthType string
const (
Default AuthType = "default"
Keys AuthType = "keys"
Credentials AuthType = "credentials"
)
type SigV4Middleware struct {
Config *Config
Next http.RoundTripper
}
type Config struct {
AuthType string
Profile string
DatasourceType string
AccessKey string
SecretKey string
AssumeRoleARN string
ExternalID string
Region string
}
func (m *SigV4Middleware) RoundTrip(req *http.Request) (*http.Response, error) {
_, err := m.signRequest(req)
if err != nil {
return nil, err
}
if m.Next == nil {
return http.DefaultTransport.RoundTrip(req)
}
return m.Next.RoundTrip(req)
}
func (m *SigV4Middleware) signRequest(req *http.Request) (http.Header, error) {
signer, err := m.signer()
if err != nil {
return nil, err
}
body, err := replaceBody(req)
if err != nil {
return nil, err
}
if strings.Contains(req.URL.RawPath, "%2C") {
req.URL.RawPath = rest.EscapePath(req.URL.RawPath, false)
}
// if X-Forwarded-For header is present, omit during signing step as it breaks AWS request verification
forwardHeader := req.Header.Get("X-Forwarded-For")
if forwardHeader != "" {
req.Header.Del("X-Forwarded-For")
header, err := signer.Sign(req, bytes.NewReader(body), awsServiceNamespace(m.Config.DatasourceType), m.Config.Region, time.Now().UTC())
// reset pre-existing X-Forwarded-For header value
req.Header.Set("X-Forwarded-For", forwardHeader)
return header, err
}
return signer.Sign(req, bytes.NewReader(body), awsServiceNamespace(m.Config.DatasourceType), m.Config.Region, time.Now().UTC())
}
func (m *SigV4Middleware) signer() (*v4.Signer, error) {
authType := AuthType(m.Config.AuthType)
var c *credentials.Credentials
switch authType {
case Keys:
c = credentials.NewStaticCredentials(m.Config.AccessKey, m.Config.SecretKey, "")
case Credentials:
c = credentials.NewSharedCredentials("", m.Config.Profile)
}
// passing nil credentials will force AWS to allow a more complete credential chain vs the explicit default
if c == nil {
s, err := session.NewSession(&aws.Config{
Region: aws.String(m.Config.Region),
})
if err != nil {
return nil, err
}
if m.Config.AssumeRoleARN != "" {
return v4.NewSigner(stscreds.NewCredentials(s, m.Config.AssumeRoleARN)), nil
}
return v4.NewSigner(s.Config.Credentials), nil
}
if m.Config.AssumeRoleARN != "" {
s, err := session.NewSession(&aws.Config{
Region: aws.String(m.Config.Region),
Credentials: c},
)
if err != nil {
return nil, err
}
return v4.NewSigner(stscreds.NewCredentials(s, m.Config.AssumeRoleARN)), nil
}
return v4.NewSigner(c), nil
}
func replaceBody(req *http.Request) ([]byte, error) {
if req.Body == nil {
return []byte{}, nil
}
payload, err := ioutil.ReadAll(req.Body)
if err != nil {
return nil, err
}
req.Body = ioutil.NopCloser(bytes.NewReader(payload))
return payload, nil
}
func awsServiceNamespace(dsType string) string {
switch dsType {
case DS_ES:
return "es"
case DS_PROMETHEUS:
return "aps"
default:
panic(fmt.Sprintf("Unsupported datasource %s", dsType))
}
}