diff --git a/builtin/providers/aws/auth_helpers.go b/builtin/providers/aws/auth_helpers.go index 914c7e9717..552a4234f1 100644 --- a/builtin/providers/aws/auth_helpers.go +++ b/builtin/providers/aws/auth_helpers.go @@ -14,10 +14,11 @@ import ( "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" "github.com/hashicorp/go-cleanhttp" ) -func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) { +func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { // If we have creds from instance profile, we can use metadata API if authProviderName == ec2rolecreds.ProviderName { log.Println("[DEBUG] Trying to get account ID via AWS Metadata API") @@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) { return parseAccountIdFromArn(*outUser.User.Arn) } - // Then try IAM ListRoles awsErr, ok := err.(awserr.Error) // AccessDenied and ValidationError can be raised // if credentials belong to federated profile, so we ignore these if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") { return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) } - log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) - log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead") + + // Then try STS GetCallerIdentity + log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity") + outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err == nil { + return *outCallerIdentity.Account, nil + } + log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err) + + // Then try IAM ListRoles + log.Println("[DEBUG] Trying to get account ID via iam:ListRoles") outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{ MaxItems: aws.Int64(int64(1)), }) diff --git a/builtin/providers/aws/auth_helpers_test.go b/builtin/providers/aws/auth_helpers_test.go index a5fcf8f163..a9de0fcc6c 100644 --- a/builtin/providers/aws/auth_helpers_test.go +++ b/builtin/providers/aws/auth_helpers_test.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sts" ) func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { @@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) { defer awsTs() iamEndpoints := []*iamEndpoint{} - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) + id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } @@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) { Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName) + id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } @@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) { Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") + if err != nil { + t.Fatalf("Getting account ID via GetUser failed: %s", err) + } + + expectedAccountId := "123456789012" + if id != expectedAccountId { + t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) + } +} + +func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) { + iamEndpoints := []*iamEndpoint{ + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + }, + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, + }, + } + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) + defer ts() + + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via GetUser failed: %s", err) } @@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) { Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, }, + &iamEndpoint{ + Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, &iamEndpoint{ Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } @@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) { Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } @@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) { Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, }, } - ts, iamConn := getMockedAwsIamApi(iamEndpoints) + ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints) defer ts() - id, err := GetAccountId(iamConn, "") + id, err := GetAccountId(iamConn, stsConn, "") if err == nil { t.Fatal("Expected error when getting account ID") } @@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() { return ts.Close } -// getMockedAwsIamApi establishes a httptest server to simulate behaviour -// of a real AWS' IAM server -func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) { +// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour +// of a real AWS' IAM & STS server +func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { buf := new(bytes.Buffer) buf.ReadFrom(r.Body) requestBody := buf.String() - log.Printf("[DEBUG] Received IAM API %q request to %q: %s", + log.Printf("[DEBUG] Received API %q request to %q: %s", r.Method, r.RequestURI, requestBody) for _, e := range endpoints { @@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) { CredentialsChainVerboseErrors: aws.Bool(true), }) iamConn := iam.New(sess) - - return ts.Close, iamConn + stsConn := sts.New(sess) + return ts.Close, iamConn, stsConn } func getEnv() *currentEnv { @@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = ` + + arn:aws:iam::123456789012:user/Alice + AKIAI44QH8DHBEXAMPLE + 123456789012 + + + 01234567-89ab-cdef-0123-456789abcdef + +` + +const stsResponse_GetCallerIdentity_unauthorized = ` + + Sender + AccessDenied + User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity + + 01234567-89ab-cdef-0123-456789abcdef +` + const iamResponse_GetUser_federatedFailure = ` Sender diff --git a/builtin/providers/aws/config.go b/builtin/providers/aws/config.go index 82a82e016f..0db0ec0f9a 100644 --- a/builtin/providers/aws/config.go +++ b/builtin/providers/aws/config.go @@ -50,6 +50,7 @@ import ( "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sts" ) type Config struct { @@ -92,8 +93,10 @@ type AWSClient struct { s3conn *s3.S3 sqsconn *sqs.SQS snsconn *sns.SNS + stsconn *sts.STS redshiftconn *redshift.Redshift r53conn *route53.Route53 + accountid string region string rdsconn *rds.RDS iamconn *iam.IAM @@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) { awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)}) client.iamconn = iam.New(awsIamSess) + log.Println("[INFO] Initializing STS connection") + client.stsconn = sts.New(sess) + err = c.ValidateCredentials(client.iamconn) if err != nil { errs = append(errs, err) @@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) { // http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")}) + accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName) + if err == nil { + client.accountid = accountId + } + log.Println("[INFO] Initializing DynamoDB connection") dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)}) client.dynamodbconn = dynamodb.New(dynamoSess) @@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) { log.Println("[INFO] Initializing Elastic Beanstalk Connection") client.elasticbeanstalkconn = elasticbeanstalk.New(sess) - authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName) + authErr := c.ValidateAccountId(client.accountid) if authErr != nil { errs = append(errs, authErr) } @@ -338,20 +349,16 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error { // ValidateAccountId returns a context-specific error if the configured account // id is explicitly forbidden or not authorised; and nil if it is authorised. -func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error { +func (c *Config) ValidateAccountId(accountId string) error { if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil { return nil } log.Printf("[INFO] Validating account ID") - account_id, err := GetAccountId(iamconn, authProviderName) - if err != nil { - return err - } if c.ForbiddenAccountIds != nil { for _, id := range c.ForbiddenAccountIds { - if id == account_id { + if id == accountId { return fmt.Errorf("Forbidden account ID (%s)", id) } } @@ -359,11 +366,11 @@ func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) er if c.AllowedAccountIds != nil { for _, id := range c.AllowedAccountIds { - if id == account_id { + if id == accountId { return nil } } - return fmt.Errorf("Account ID not allowed (%s)", account_id) + return fmt.Errorf("Account ID not allowed (%s)", accountId) } return nil