fix(build/aws): updated aws dependency and fixed minor build issue, fixes #3026

This commit is contained in:
Torkel Ödegaard 2015-10-26 12:48:30 +01:00
parent 052c9aca15
commit da7ae2b0ab
77 changed files with 15465 additions and 8840 deletions

50
Godeps/Godeps.json generated
View File

@ -1,6 +1,6 @@
{
"ImportPath": "github.com/grafana/grafana",
"GoVersion": "go1.4.2",
"GoVersion": "go1.5",
"Packages": [
"./pkg/..."
],
@ -20,48 +20,48 @@
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/endpoints",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/ec2query",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4",
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/cloudwatch",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/ec2",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.9.16-3-g4944a94",
"Rev": "4944a94a3b092d1dad8a964415700a70f55e580a"
},
{
"ImportPath": "github.com/davecgh/go-spew/spew",

View File

@ -113,7 +113,7 @@ func newRequestError(err Error, statusCode int, requestID string) *requestError
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]",
extra := fmt.Sprintf("status code: %d, request id: %s",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}

View File

@ -2,48 +2,20 @@ package aws
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
// DefaultRetries is the default number of retries for a service. The value of
// -1 indicates that the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(DefaultRetries).
WithLogger(NewDefaultLogger()).
WithLogLevel(LogOff)
// A Config provides service configuration for service clients. By default,
// all clients will use the {DefaultConfig} structure.
// all clients will use the {defaults.DefaultConfig} structure.
type Config struct {
// The credentials object to use when signing requests. Defaults to
// {DefaultChainCredentials}.
// {defaults.DefaultChainCredentials}.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
@ -102,6 +74,8 @@ type Config struct {
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
SleepDelay func(time.Duration)
}
// NewConfig returns a new Config pointer that can be chained with builder methods to
@ -190,6 +164,13 @@ func (c *Config) WithS3ForcePathStyle(force bool) *Config {
return c
}
// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
c.SleepDelay = fn
return c
}
// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
@ -244,6 +225,10 @@ func (c Config) Merge(other *Config) *Config {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
return &dst
}

View File

@ -4,18 +4,11 @@ import (
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var testCredentials = credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
var copyTestConfig = Config{
Credentials: testCredentials,

View File

@ -1,10 +1,9 @@
package aws_test
package aws
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
@ -18,20 +17,20 @@ func TestStringSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.StringSlice(in)
out := StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueSlice(out)
out2 := StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesStringValueSlice = [][]*string{
{aws.String("a"), aws.String("b"), nil, aws.String("c")},
{String("a"), String("b"), nil, String("c")},
}
func TestStringValueSlice(t *testing.T) {
@ -39,7 +38,7 @@ func TestStringValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.StringValueSlice(in)
out := StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -49,7 +48,7 @@ func TestStringValueSlice(t *testing.T) {
}
}
out2 := aws.StringSlice(out)
out2 := StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -70,13 +69,13 @@ func TestStringMap(t *testing.T) {
if in == nil {
continue
}
out := aws.StringMap(in)
out := StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueMap(out)
out2 := StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -91,13 +90,13 @@ func TestBoolSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolSlice(in)
out := BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueSlice(out)
out2 := BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -110,7 +109,7 @@ func TestBoolValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolValueSlice(in)
out := BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -120,7 +119,7 @@ func TestBoolValueSlice(t *testing.T) {
}
}
out2 := aws.BoolSlice(out)
out2 := BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -141,13 +140,13 @@ func TestBoolMap(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolMap(in)
out := BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueMap(out)
out2 := BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -162,13 +161,13 @@ func TestIntSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.IntSlice(in)
out := IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueSlice(out)
out2 := IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -181,7 +180,7 @@ func TestIntValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.IntValueSlice(in)
out := IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -191,7 +190,7 @@ func TestIntValueSlice(t *testing.T) {
}
}
out2 := aws.IntSlice(out)
out2 := IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -212,13 +211,13 @@ func TestIntMap(t *testing.T) {
if in == nil {
continue
}
out := aws.IntMap(in)
out := IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueMap(out)
out2 := IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -233,13 +232,13 @@ func TestInt64Slice(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64Slice(in)
out := Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueSlice(out)
out2 := Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -252,7 +251,7 @@ func TestInt64ValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64ValueSlice(in)
out := Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -262,7 +261,7 @@ func TestInt64ValueSlice(t *testing.T) {
}
}
out2 := aws.Int64Slice(out)
out2 := Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -283,13 +282,13 @@ func TestInt64Map(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64Map(in)
out := Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueMap(out)
out2 := Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -304,13 +303,13 @@ func TestFloat64Slice(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64Slice(in)
out := Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueSlice(out)
out2 := Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -323,7 +322,7 @@ func TestFloat64ValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64ValueSlice(in)
out := Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -333,7 +332,7 @@ func TestFloat64ValueSlice(t *testing.T) {
}
}
out2 := aws.Float64Slice(out)
out2 := Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -354,13 +353,13 @@ func TestFloat64Map(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64Map(in)
out := Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueMap(out)
out2 := Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -375,13 +374,13 @@ func TestTimeSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeSlice(in)
out := TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueSlice(out)
out2 := TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@ -394,7 +393,7 @@ func TestTimeValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeValueSlice(in)
out := TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@ -404,7 +403,7 @@ func TestTimeValueSlice(t *testing.T) {
}
}
out2 := aws.TimeSlice(out)
out2 := TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@ -425,13 +424,13 @@ func TestTimeMap(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeMap(in)
out := TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueMap(out)
out2 := TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}

View File

@ -1,4 +1,4 @@
package aws
package corehandlers
import (
"bytes"
@ -9,24 +9,21 @@ import (
"net/url"
"regexp"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
}
// BuildContentLength builds the content length of a request based on the body,
// BuildContentLengthHandler builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
var BuildContentLengthHandler = request.NamedHandler{"core.BuildContentLengthHandler", func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
@ -40,27 +37,27 @@ func BuildContentLength(r *Request) {
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
}}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var UserAgentHandler = request.NamedHandler{"core.UserAgentHandler", func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", aws.SDKName+"/"+aws.SDKVersion)
}}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
var SendHandler = request.NamedHandler{"core.SendHandler", func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
@ -68,8 +65,8 @@ func SendHandler(r *Request) {
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil {
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
@ -79,7 +76,7 @@ func SendHandler(r *Request) {
return
}
}
if r.HTTPRequest == nil {
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
@ -90,68 +87,50 @@ func SendHandler(r *Request) {
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = Bool(true) // network errors are retryable
r.Retryable = aws.Bool(true) // network errors are retryable
}
}
}}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
var ValidateResponseHandler = request.NamedHandler{"core.ValidateResponseHandler", func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
}}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
var AfterRetryHandler = request.NamedHandler{"core.AfterRetryHandler", func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
r.Retryable = Bool(r.Service.ShouldRetry(r))
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
r.RetryDelay = r.RetryRules(r)
r.Service.Config.SleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
if r.IsErrorExpired() {
r.Service.Config.Credentials.Expire()
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
}}
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && StringValue(r.Service.Config.Region) == "" {
r.Error = ErrMissingRegion
var ValidateEndpointHandler = request.NamedHandler{"core.ValidateEndpointHandler", func(r *request.Request) {
if r.Service.SigningRegion == "" && aws.StringValue(r.Service.Config.Region) == "" {
r.Error = aws.ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
r.Error = aws.ErrMissingEndpoint
}
}
}}

View File

@ -0,0 +1,107 @@
package corehandlers_test
import (
"fmt"
"net/http"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := service.New(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := service.New(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, aws.ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := service.New(&aws.Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: aws.Int(1)})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}
type testSendHandlerTransport struct{}
func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("mock error")
}
func TestSendHandlerError(t *testing.T) {
svc := service.New(&aws.Config{
HTTPClient: &http.Client{
Transport: &testSendHandlerTransport{},
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
r.Send()
assert.Error(t, r.Error)
assert.NotNil(t, r.HTTPResponse)
}

View File

@ -0,0 +1,144 @@
package corehandlers
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// ValidateParametersHandler is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
var ValidateParametersHandler = request.NamedHandler{"core.ValidateParametersHandler", func(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
err := validateField(f, fvalue, validateFieldRequired, validateFieldMin)
if err != nil {
v.errors = append(v.errors, fmt.Sprintf("%s: %s", err.Error(), path+prefix+f.Name))
continue
}
v.validateAny(fvalue, path+prefix+f.Name)
}
}
type validatorFunc func(f reflect.StructField, fvalue reflect.Value) error
func validateField(f reflect.StructField, fvalue reflect.Value, funcs ...validatorFunc) error {
for _, fn := range funcs {
if err := fn(f, fvalue); err != nil {
return err
}
}
return nil
}
// Validates that a field has a valid value provided for required fields.
func validateFieldRequired(f reflect.StructField, fvalue reflect.Value) error {
if f.Tag.Get("required") == "" {
return nil
}
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return fmt.Errorf("missing required parameter")
}
default:
if !fvalue.IsValid() {
return fmt.Errorf("missing required parameter")
}
}
return nil
}
// Validates that if a value is provided for a field, that value must be at
// least a minimum length.
func validateFieldMin(f reflect.StructField, fvalue reflect.Value) error {
minStr := f.Tag.Get("min")
if minStr == "" {
return nil
}
min, _ := strconv.ParseInt(minStr, 10, 64)
kind := fvalue.Kind()
if kind == reflect.Ptr {
if fvalue.IsNil() {
return nil
}
fvalue = fvalue.Elem()
}
switch fvalue.Kind() {
case reflect.String:
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
case reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return nil
}
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
// TODO min can also apply to number minimum value.
}
return nil
}

View File

@ -0,0 +1,134 @@
package corehandlers_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/stretchr/testify/require"
)
var testSvc = func() *service.Service {
s := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
},
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}
type testInput struct {
StringField string `min:"5"`
PtrStrField *string `min:"2"`
ListField []string `min:"3"`
MapField map[string]string `min:"4"`
}
var testsFieldMin = []struct {
err awserr.Error
in testInput
}{
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 5: StringField", nil),
in: testInput{StringField: "abcd"},
},
{
err: awserr.New("InvalidParameter", "2 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}},
},
{
err: awserr.New("InvalidParameter", "3 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField\n- field too short, minimum length 4: MapField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 2: PtrStrField", nil),
in: testInput{StringField: "abcde", PtrStrField: aws.String("v")},
},
{
err: nil,
in: testInput{StringField: "abcde", PtrStrField: aws.String("value"),
ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
},
}
func TestValidateFieldMinParameter(t *testing.T) {
for i, c := range testsFieldMin {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Equal(t, c.err, req.Error, "%d case failed", i)
}
}

View File

@ -53,8 +53,8 @@ import (
"time"
)
// Create an empty Credential object that can be used as dummy placeholder
// credentials for requests that do not need signed.
// AnonymousCredentials is an empty Credential object that can be used as
// dummy placeholder credentials for requests that do not need signed.
//
// This Credentials can be used to configure a service to not sign requests
// when making service API calls. For example, when accessing public

View File

@ -1,108 +0,0 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@ -1,24 +1,25 @@
package credentials
package ec2rolecreds
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// p := &ec2rolecreds.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
@ -32,13 +33,10 @@ const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam
// ExpiryWindow: 0,
// }
type EC2RoleProvider struct {
Expiry
credentials.Expiry
// Endpoint must be fully quantified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// EC2Metadata client to use when connecting to EC2 metadata service
Client *ec2metadata.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
@ -52,7 +50,7 @@ type EC2RoleProvider struct {
ExpiryWindow time.Duration
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// NewCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
@ -64,9 +62,8 @@ type EC2RoleProvider struct {
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
func NewCredentials(client *ec2metadata.Client, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&EC2RoleProvider{
Client: client,
ExpiryWindow: window,
})
@ -75,32 +72,29 @@ func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Dur
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
m.Client = ec2metadata.New(nil)
}
credsList, err := requestCredList(m.Client, m.Endpoint)
credsList, err := requestCredList(m.Client)
if err != nil {
return Value{}, err
return credentials.Value{}, err
}
if len(credsList) == 0 {
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
return credentials.Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
roleCreds, err := requestCred(m.Client, credsName)
if err != nil {
return Value{}, err
return credentials.Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
return credentials.Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
@ -110,29 +104,35 @@ func (m *EC2RoleProvider) Retrieve() (Value, error) {
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
// Error state
Code string
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
func requestCredList(client *ec2metadata.Client) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
return nil, awserr.New("EC2RoleRequestError", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
s := bufio.NewScanner(strings.NewReader(resp))
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
return nil, awserr.New("SerializationError", "failed to read list of EC2 Roles", err)
}
return credsList, nil
@ -142,20 +142,26 @@ func requestCredList(client *http.Client, endpoint string) ([]string, error) {
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
func requestCred(client *ec2metadata.Client, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
if err != nil {
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
if respCreds.Code != "Success" {
// If an error code was returned something failed requesting the role.
return ec2RoleCredRespBody{}, awserr.New(respCreds.Code, respCreds.Message, nil)
}
return respCreds, nil

View File

@ -0,0 +1,161 @@
package ec2rolecreds_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
)
const credsRespTmpl = `{
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s",
"LastUpdated" : "2009-11-23T0:00:00Z"
}`
const credsFailRespTmpl = `{
"Code": "ErrorCode",
"Message": "ErrorMsg",
"LastUpdated": "2009-11-23T0:00:00Z"
}`
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
fmt.Fprintf(w, credsFailRespTmpl)
} else {
fmt.Fprintf(w, credsRespTmpl, expireOn)
}
} else {
http.Error(w, "bad request", http.StatusBadRequest)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", true)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}),
ExpiryWindow: time.Hour * 1,
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@ -22,8 +22,12 @@ var (
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {
// Path to the shared credentials file. If empty will default to current user's
// home directory.
// Path to the shared credentials file.
//
// If empty will look for "AWS_SHARED_CREDENTIALS_FILE" env variable. If the
// env value is empty will default to current user's home directory.
// Linux/OSX: "$HOME/.aws/credentials"
// Windows: "%USERPROFILE%\.aws\credentials"
Filename string
// AWS Profile to extract credentials from the shared credentials file. If empty
@ -106,6 +110,10 @@ func loadProfile(filename, profile string) (Value, error) {
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" {
return p.Filename, nil
}
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")

View File

@ -31,6 +31,19 @@ func TestSharedCredentialsProviderIsExpired(t *testing.T) {
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "example.ini")
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
@ -66,12 +79,10 @@ func BenchmarkSharedCredentialsProvider(b *testing.B) {
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
})
}
}

View File

@ -6,10 +6,11 @@ package stscreds
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
@ -53,6 +54,9 @@ type AssumeRoleProvider struct {
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// Optional ExternalID to pass along, defaults to nil if not set.
ExternalID *string
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
@ -101,8 +105,9 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
})
if err != nil {
@ -113,7 +118,7 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
AccessKeyID: *roleOutput.Credentials.AccessKeyId,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil

View File

@ -1,11 +1,12 @@
package stscreds
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
@ -16,7 +17,7 @@ func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput,
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
AccessKeyId: input.RoleArn,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,

View File

@ -0,0 +1,39 @@
package defaults
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = aws.NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(aws.DefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)

View File

@ -0,0 +1,43 @@
package ec2metadata
import (
"path"
"github.com/aws/aws-sdk-go/aws/request"
)
// GetMetadata uses the path provided to request
func (c *Client) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
}
output := &metadataOutput{}
req := request.New(c.Service.ServiceInfo, c.Service.Handlers, c.Service.Retryer, op, nil, output)
return output.Content, req.Send()
}
// Region returns the region the instance is running in.
func (c *Client) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
if err != nil {
return "", err
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}
// Available returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *Client) Available() bool {
if _, err := c.GetMetadata("instance-id"); err != nil {
return false
}
return true
}

View File

@ -0,0 +1,100 @@
package ec2metadata_test
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
func initTestServer(path string, resp string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != path {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.Write([]byte(resp))
}))
}
func TestEndpoint(t *testing.T) {
c := ec2metadata.New(&ec2metadata.Config{})
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", "testpath"),
}
req := c.Service.NewRequest(op, nil, nil)
assert.Equal(t, "http://169.254.169.254/latest", req.Service.Endpoint)
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
}
func TestGetMetadata(t *testing.T) {
server := initTestServer(
"/latest/meta-data/some/path",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetMetadata("some/path")
assert.NoError(t, err)
assert.Equal(t, "success", resp)
}
func TestGetRegion(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"us-west-2a", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
assert.NoError(t, err)
assert.Equal(t, "us-west-2", region)
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",
"instance-id",
)
defer server.Close()
c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")})
available := c.Available()
assert.True(t, available)
}
func TestMetadataNotAvailable(t *testing.T) {
c := ec2metadata.New(nil)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
r.Error = awserr.New("RequestError", "send request failed", nil)
r.Retryable = aws.Bool(true) // network errors are retryable
})
available := c.Available()
assert.False(t, available)
}

View File

@ -0,0 +1,149 @@
package ec2metadata
import (
"io/ioutil"
"net"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
)
// DefaultRetries states the default number of times the service client will
// attempt to retry a failed request before failing.
const DefaultRetries = 3
// A Config provides the configuration for the EC2 Metadata service.
type Config struct {
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default service endpoint for a client. Set this
// to nil, or `""` to use the default service endpoint.
Endpoint *string
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
Logger aws.Logger
// The logger writer interface to write logging messages to. Defaults to
// standard out.
LogLevel *aws.LogLevelType
// The maximum number of times that a request will be retried for failures.
// Defaults to DefaultRetries for the number of retries to be performed
// per request.
MaxRetries *int
}
// A Client is an EC2 Metadata service Client.
type Client struct {
*service.Service
}
// New creates a new instance of the EC2 Metadata service client.
//
// In the general use case the configuration for this service client should not
// be needed and `nil` can be provided. Configuration is only needed if the
// `ec2metadata.Config` defaults need to be overridden. Eg. Setting LogLevel.
//
// @note This configuration will NOT be merged with the default AWS service
// client configuration `defaults.DefaultConfig`. Due to circular dependencies
// with the defaults package and credentials EC2 Role Provider.
func New(config *Config) *Client {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: copyConfig(config),
ServiceName: "Client",
Endpoint: "http://169.254.169.254/latest",
APIVersion: "latest",
},
}
service.Initialize()
service.Handlers.Unmarshal.PushBack(unmarshalHandler)
service.Handlers.UnmarshalError.PushBack(unmarshalError)
service.Handlers.Validate.Clear()
service.Handlers.Validate.PushBack(validateEndpointHandler)
return &Client{service}
}
func copyConfig(config *Config) *aws.Config {
if config == nil {
config = &Config{}
}
c := &aws.Config{
Credentials: credentials.AnonymousCredentials,
Endpoint: config.Endpoint,
HTTPClient: config.HTTPClient,
Logger: config.Logger,
LogLevel: config.LogLevel,
MaxRetries: config.MaxRetries,
}
if c.HTTPClient == nil {
c.HTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
},
}
}
if c.Logger == nil {
c.Logger = aws.NewDefaultLogger()
}
if c.LogLevel == nil {
c.LogLevel = aws.LogLevel(aws.LogOff)
}
if c.MaxRetries == nil {
c.MaxRetries = aws.Int(DefaultRetries)
}
return c
}
type metadataOutput struct {
Content string
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
}
data := r.Data.(*metadataOutput)
data.Content = string(b)
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
_, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
}
// TODO extract the error...
}
func validateEndpointHandler(r *request.Request) {
if r.Service.Endpoint == "" {
r.Error = aws.ErrMissingEndpoint
}
}

View File

@ -0,0 +1,17 @@
package aws
import "github.com/aws/aws-sdk-go/aws/awserr"
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

View File

@ -1,81 +0,0 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: Int(1)})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}

View File

@ -1,31 +0,0 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

View File

@ -62,6 +62,15 @@ const (
// see the body content of requests and responses made while using the SDK
// Will also enable LogDebug.
LogDebugWithHTTPBody
// LogDebugWithRequestRetries states the SDK should log when service requests will
// be retried. This should be used to log when you want to log when service
// requests are being retried. Will also enable LogDebug.
LogDebugWithRequestRetries
// LogDebugWithRequestErrors states the SDK should log when service requests fail
// to build, send, validate, or unmarshal.
LogDebugWithRequestErrors
)
// A Logger is a minimalistic interface for the SDK to log messages to. Should

View File

@ -1,89 +0,0 @@
package aws
import (
"fmt"
"reflect"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}

View File

@ -1,84 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}

View File

@ -1,4 +1,4 @@
package aws
package request
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
@ -15,8 +15,8 @@ type Handlers struct {
AfterRetry HandlerList
}
// copy returns of this handler's lists.
func (h *Handlers) copy() Handlers {
// Copy returns of this handler's lists.
func (h *Handlers) Copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
@ -47,19 +47,25 @@ func (h *Handlers) Clear() {
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
list []NamedHandler
}
// A NamedHandler is a struct that contains a name and function callback.
type NamedHandler struct {
Name string
Fn func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
n.list = append([]NamedHandler{}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
l.list = []NamedHandler{}
}
// Len returns the number of handlers in the list.
@ -67,19 +73,40 @@ func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
}
// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
l.list = append(l.list, n)
}
// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
}
// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
}
}
l.list = newlist
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
f.Fn(r)
}
}

View File

@ -0,0 +1,47 @@
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) { r.Data = nil })
l.PushFront(func(r *request.Request) { r.Data = aws.Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}
func TestNamedHandlers(t *testing.T) {
l := request.HandlerList{}
named := request.NamedHandler{"Name", func(r *request.Request) {}}
named2 := request.NamedHandler{"NotName", func(r *request.Request) {}}
l.PushBackNamed(named)
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
assert.Equal(t, 4, l.Len())
l.Remove(named)
assert.Equal(t, 2, l.Len())
}

View File

@ -1,7 +1,8 @@
package aws
package request
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
@ -10,12 +11,15 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
)
// A Request is the service request to be made.
type Request struct {
*Service
Retryer
Service serviceinfo.ServiceInfo
Handlers Handlers
Time time.Time
ExpireTime time.Duration
@ -23,7 +27,7 @@ type Request struct {
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
@ -51,13 +55,13 @@ type Paginator struct {
TruncationToken string
}
// NewRequest returns a new Request pointer for the service API
// New returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
func New(service serviceinfo.ServiceInfo, handlers Handlers, retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
@ -71,8 +75,9 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
httpReq.URL, _ = url.Parse(service.Endpoint + p)
r := &Request{
Retryer: retryer,
Service: service,
Handlers: service.Handlers.copy(),
Handlers: handlers.Copy(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
@ -89,7 +94,7 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.Service.MaxRetries()
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
@ -134,6 +139,20 @@ func (r *Request) Presign(expireTime time.Duration) (string, error) {
return r.HTTPRequest.URL.String(), nil
}
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
if !r.Service.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Service.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.Service.ServiceName, r.Operation.Name, retryStr, err))
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
@ -149,6 +168,7 @@ func (r *Request) Build() error {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
@ -165,6 +185,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
return r.Error
}
@ -183,42 +204,57 @@ func (r *Request) Send() error {
return r.Error
}
if BoolValue(r.Retryable) {
if aws.BoolValue(r.Retryable) {
if r.Service.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Service.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.Service.ServiceName, r.Operation.Name, r.RetryCount))
}
// Re-seek the body back to the original point in for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, r.Error)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Validate Response", true, err)
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
@ -279,7 +315,7 @@ func (r *Request) NextPage() *Request {
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data)
nr := New(r.Service, r.Handlers, r.Retryer, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i])
}

View File

@ -1,13 +1,15 @@
package aws_test
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
@ -28,7 +30,7 @@ func TestPagination(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
@ -36,7 +38,7 @@ func TestPagination(t *testing.T) {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@ -80,7 +82,7 @@ func TestPaginationEachPage(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
@ -88,7 +90,7 @@ func TestPaginationEachPage(t *testing.T) {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@ -133,7 +135,7 @@ func TestPaginationEarlyExit(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@ -164,7 +166,7 @@ func TestSkipPagination(t *testing.T) {
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.HeadBucketOutput{}
})
@ -199,7 +201,7 @@ func TestPaginationTruncation(t *testing.T) {
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[*reqNum]
*reqNum++
})
@ -260,7 +262,7 @@ var benchDb = func() *dynamodb.DynamoDB {
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
@ -289,7 +291,7 @@ func BenchmarkCodegenIterator(b *testing.B) {
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})

View File

@ -1,4 +1,4 @@
package aws
package request_test
import (
"bytes"
@ -10,8 +10,11 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/stretchr/testify/assert"
)
@ -23,7 +26,7 @@ func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
func unmarshal(req *request.Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
@ -31,7 +34,7 @@ func unmarshal(req *Request) {
return
}
func unmarshalError(req *Request) {
func unmarshalError(req *request.Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
@ -71,17 +74,17 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s := service.New(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
@ -97,17 +100,17 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s := service.New(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
@ -116,16 +119,16 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(NewConfig().WithMaxRetries(10))
s := service.New(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
@ -140,7 +143,7 @@ func TestRequest4xxUnretryable(t *testing.T) {
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
sleepDelay := func(delay time.Duration) {
delays = append(delays, delay)
}
@ -152,16 +155,16 @@ func TestRequestExhaustRetries(t *testing.T) {
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(NewConfig().WithMaxRetries(DefaultRetries))
s := service.New(aws.NewConfig().WithMaxRetries(aws.DefaultRetries).WithSleepDelay(sleepDelay))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
@ -190,7 +193,7 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s := service.New(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
@ -198,21 +201,21 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
credExpiredAfterRetry = r.Service.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
r.Config.Credentials.Get()
s.Handlers.Sign.PushBack(func(r *request.Request) {
r.Service.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)

View File

@ -0,0 +1,73 @@
package request
import (
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Retryer is an interface to control retry logic for a given service.
// The default implementation used by most services is the service.DefaultRetryer
// structure, which contains basic retry logic using exponential backoff.
type Retryer interface {
RetryRules(*Request) time.Duration
ShouldRetry(*Request) bool
MaxRetries() uint
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
"LimitExceededException": {}, // Deleting 10+ DynamoDb tables at once
"TooManyRequestsException": {}, // Lambda functions
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if the request has no Error set.
func (r *Request) IsErrorRetryable() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
func (r *Request) IsErrorExpired() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeExpiredCreds(err.Code())
}
}
return false
}

View File

@ -1,194 +0,0 @@
package aws
import (
"fmt"
"math"
"math/rand"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new Server object initialized.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *Request) {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ServiceName, r.Operation.Name, msg))
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (s *Service) MaxRetries() uint {
if IntValue(s.Config.MaxRetries) < 0 {
return s.DefaultMaxRetries
}
return uint(IntValue(s.Config.MaxRetries))
}
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}

View File

@ -0,0 +1,49 @@
package service
import (
"math"
"math/rand"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
//
// type retryer struct {
// service.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() uint { return 100 }
type DefaultRetryer struct {
*Service
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() uint {
if aws.IntValue(d.Service.Config.MaxRetries) < 0 {
return d.DefaultMaxRetries
}
return uint(aws.IntValue(d.Service.Config.MaxRetries))
}
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (rand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// ShouldRetry returns if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
return r.IsErrorRetryable()
}

View File

@ -0,0 +1,133 @@
package service
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/private/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
serviceinfo.ServiceInfo
request.Retryer
DefaultMaxRetries uint
Handlers request.Handlers
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// New will return a pointer to a new Server object initialized.
func New(config *aws.Config) *Service {
svc := &Service{ServiceInfo: serviceinfo.ServiceInfo{Config: config}}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &aws.Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.Config.SleepDelay == nil {
s.Config.SleepDelay = time.Sleep
}
s.Retryer = DefaultRetryer{s}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
s.Handlers.Build.PushBackNamed(corehandlers.UserAgentHandler)
s.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
s.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
s.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
s.Handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)
if !aws.BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateParametersHandler)
}
s.AddDebugHandlers()
s.buildEndpoint()
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
func (s *Service) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
return request.New(s.ServiceInfo, s.Handlers, s.Retryer, operation, params, data)
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if aws.StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else if s.Endpoint == "" {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, aws.StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if aws.BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) {
logBody := r.Service.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Service.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.Service.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Service.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Service.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.Service.ServiceName, r.Operation.Name, msg))
}

View File

@ -0,0 +1,15 @@
package serviceinfo
import "github.com/aws/aws-sdk-go/aws"
// ServiceInfo wraps immutable data from the service.Service structure.
type ServiceInfo struct {
Config *aws.Config
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
}

View File

@ -2,6 +2,7 @@ package aws
import (
"io"
"sync"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeakerCloser
@ -53,3 +54,35 @@ func (r ReaderSeekerCloser) Close() error {
}
return nil
}
// A WriteAtBuffer provides a in memory buffer supporting the io.WriterAt interface
// Can be used with the s3manager.Downloader to download content to a buffer
// in memory. Safe to use concurrently.
type WriteAtBuffer struct {
buf []byte
m sync.Mutex
}
// WriteAt writes a slice of bytes to a buffer starting at the position provided
// The number of bytes written will be returned, or error. Can overwrite previous
// written slices if the write ats overlap.
func (b *WriteAtBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
b.m.Lock()
defer b.m.Unlock()
expLen := pos + int64(len(p))
if int64(len(b.buf)) < expLen {
newBuf := make([]byte, expLen)
copy(newBuf, b.buf)
b.buf = newBuf
}
copy(b.buf[pos:], p)
return len(p), nil
}
// Bytes returns a slice of bytes written to the buffer.
func (b *WriteAtBuffer) Bytes() []byte {
b.m.Lock()
defer b.m.Unlock()
return b.buf[:len(b.buf):len(b.buf)]
}

View File

@ -0,0 +1,56 @@
package aws
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteAtBuffer(t *testing.T) {
b := &WriteAtBuffer{}
n, err := b.WriteAt([]byte{1}, 0)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
assert.NoError(t, err)
assert.Equal(t, 3, n)
n, err = b.WriteAt([]byte{2}, 1)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{3}, 2)
assert.NoError(t, err)
assert.Equal(t, 1, n)
assert.Equal(t, []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
}
func BenchmarkWriteAtBuffer(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
}
func BenchmarkWriteAtBufferParallel(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
})
}

View File

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.7.3"
const SDKVersion = "0.9.16"

View File

@ -0,0 +1,31 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import "strings"
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// if the service and region pair are not found endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string) (endpoint, signingRegion string) {
derivedKeys := []string{
region + "/" + svcName,
region + "/*",
"*/" + svcName,
"*/*",
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
return
}
}
return
}

View File

@ -0,0 +1,81 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/data.iot": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"us-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-west-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-northeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"sa-east-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com",
"signatureVersion": "v4"
}
}
}

View File

@ -0,0 +1,93 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/data.iot": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
},
}

View File

@ -0,0 +1,28 @@
package endpoints
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts"}
for _, name := range svcs {
ep, sr := EndpointForRegion(name, region)
assert.Equal(t, name+".amazonaws.com", ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3"}
for _, name := range svcs {
ep, _ := EndpointForRegion(name, region)
assert.Equal(t, name+"."+region+".amazonaws.com.cn", ep)
}
}

View File

@ -0,0 +1,32 @@
// Package ec2query provides serialisation of AWS EC2 requests and responses.
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
)
// Build builds a request for the EC2 protocol.
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.Service.APIVersion},
}
if err := queryutil.Parse(body, r.Params, true); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err)
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

View File

@ -0,0 +1,85 @@
// +build bench
package ec2query_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/service/ec2"
)
func BenchmarkEC2QueryBuild_Complex_ec2AuthorizeSecurityGroupEgress(b *testing.B) {
params := &ec2.AuthorizeSecurityGroupEgressInput{
GroupId: aws.String("String"), // Required
CidrIp: aws.String("String"),
DryRun: aws.Bool(true),
FromPort: aws.Int64(1),
IpPermissions: []*ec2.IpPermission{
{ // Required
FromPort: aws.Int64(1),
IpProtocol: aws.String("String"),
IpRanges: []*ec2.IpRange{
{ // Required
CidrIp: aws.String("String"),
},
// More values...
},
PrefixListIds: []*ec2.PrefixListId{
{ // Required
PrefixListId: aws.String("String"),
},
// More values...
},
ToPort: aws.Int64(1),
UserIdGroupPairs: []*ec2.UserIdGroupPair{
{ // Required
GroupId: aws.String("String"),
GroupName: aws.String("String"),
UserId: aws.String("String"),
},
// More values...
},
},
// More values...
},
IpProtocol: aws.String("String"),
SourceSecurityGroupName: aws.String("String"),
SourceSecurityGroupOwnerId: aws.String("String"),
ToPort: aws.Int64(1),
}
benchEC2QueryBuild(b, "AuthorizeSecurityGroupEgress", params)
}
func BenchmarkEC2QueryBuild_Simple_ec2AttachNetworkInterface(b *testing.B) {
params := &ec2.AttachNetworkInterfaceInput{
DeviceIndex: aws.Int64(1), // Required
InstanceId: aws.String("String"), // Required
NetworkInterfaceId: aws.String("String"), // Required
DryRun: aws.Bool(true),
}
benchEC2QueryBuild(b, "AttachNetworkInterface", params)
}
func benchEC2QueryBuild(b *testing.B, opName string, params interface{}) {
svc := service.New(nil)
svc.ServiceName = "ec2"
svc.APIVersion = "2015-04-15"
for i := 0; i < b.N; i++ {
r := svc.NewRequest(&request.Operation{
Name: opName,
HTTPMethod: "POST",
HTTPPath: "/",
}, params, nil)
ec2query.Build(r)
if r.Error != nil {
b.Fatal("Unexpected error", r.Error)
}
}
}

View File

@ -0,0 +1,882 @@
package ec2query_test
import (
"bytes"
"encoding/json"
"encoding/xml"
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/private/signer/v4"
"github.com/aws/aws-sdk-go/private/util"
"github.com/stretchr/testify/assert"
)
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = awstesting.GenerateAssertions
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
type InputService1ProtocolTest struct {
*service.Service
}
// New returns a new InputService1ProtocolTest client.
func NewInputService1ProtocolTest(config *aws.Config) *InputService1ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice1protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService1ProtocolTest{service}
}
// newRequest creates a new request for a InputService1ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService1ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService1TestCaseOperation1 = "OperationName"
// InputService1TestCaseOperation1Request generates a request for the InputService1TestCaseOperation1 operation.
func (c *InputService1ProtocolTest) InputService1TestCaseOperation1Request(input *InputService1TestShapeInputService1TestCaseOperation1Input) (req *request.Request, output *InputService1TestShapeInputService1TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService1TestCaseOperation1,
}
if input == nil {
input = &InputService1TestShapeInputService1TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService1TestShapeInputService1TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService1ProtocolTest) InputService1TestCaseOperation1(input *InputService1TestShapeInputService1TestCaseOperation1Input) (*InputService1TestShapeInputService1TestCaseOperation1Output, error) {
req, out := c.InputService1TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService1TestShapeInputService1TestCaseOperation1Input struct {
Bar *string `type:"string"`
Foo *string `type:"string"`
metadataInputService1TestShapeInputService1TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService1TestShapeInputService1TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService1TestShapeInputService1TestCaseOperation1Output struct {
metadataInputService1TestShapeInputService1TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService1TestShapeInputService1TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService2ProtocolTest struct {
*service.Service
}
// New returns a new InputService2ProtocolTest client.
func NewInputService2ProtocolTest(config *aws.Config) *InputService2ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice2protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService2ProtocolTest{service}
}
// newRequest creates a new request for a InputService2ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService2ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService2TestCaseOperation1 = "OperationName"
// InputService2TestCaseOperation1Request generates a request for the InputService2TestCaseOperation1 operation.
func (c *InputService2ProtocolTest) InputService2TestCaseOperation1Request(input *InputService2TestShapeInputService2TestCaseOperation1Input) (req *request.Request, output *InputService2TestShapeInputService2TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService2TestCaseOperation1,
}
if input == nil {
input = &InputService2TestShapeInputService2TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService2TestShapeInputService2TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService2ProtocolTest) InputService2TestCaseOperation1(input *InputService2TestShapeInputService2TestCaseOperation1Input) (*InputService2TestShapeInputService2TestCaseOperation1Output, error) {
req, out := c.InputService2TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService2TestShapeInputService2TestCaseOperation1Input struct {
Bar *string `locationName:"barLocationName" type:"string"`
Foo *string `type:"string"`
Yuck *string `locationName:"yuckLocationName" queryName:"yuckQueryName" type:"string"`
metadataInputService2TestShapeInputService2TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService2TestShapeInputService2TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService2TestShapeInputService2TestCaseOperation1Output struct {
metadataInputService2TestShapeInputService2TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService2TestShapeInputService2TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService3ProtocolTest struct {
*service.Service
}
// New returns a new InputService3ProtocolTest client.
func NewInputService3ProtocolTest(config *aws.Config) *InputService3ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice3protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService3ProtocolTest{service}
}
// newRequest creates a new request for a InputService3ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService3ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService3TestCaseOperation1 = "OperationName"
// InputService3TestCaseOperation1Request generates a request for the InputService3TestCaseOperation1 operation.
func (c *InputService3ProtocolTest) InputService3TestCaseOperation1Request(input *InputService3TestShapeInputService3TestCaseOperation1Input) (req *request.Request, output *InputService3TestShapeInputService3TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService3TestCaseOperation1,
}
if input == nil {
input = &InputService3TestShapeInputService3TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService3TestShapeInputService3TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService3ProtocolTest) InputService3TestCaseOperation1(input *InputService3TestShapeInputService3TestCaseOperation1Input) (*InputService3TestShapeInputService3TestCaseOperation1Output, error) {
req, out := c.InputService3TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService3TestShapeInputService3TestCaseOperation1Input struct {
StructArg *InputService3TestShapeStructType `locationName:"Struct" type:"structure"`
metadataInputService3TestShapeInputService3TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService3TestShapeInputService3TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService3TestShapeInputService3TestCaseOperation1Output struct {
metadataInputService3TestShapeInputService3TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService3TestShapeInputService3TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService3TestShapeStructType struct {
ScalarArg *string `locationName:"Scalar" type:"string"`
metadataInputService3TestShapeStructType `json:"-" xml:"-"`
}
type metadataInputService3TestShapeStructType struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService4ProtocolTest struct {
*service.Service
}
// New returns a new InputService4ProtocolTest client.
func NewInputService4ProtocolTest(config *aws.Config) *InputService4ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice4protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService4ProtocolTest{service}
}
// newRequest creates a new request for a InputService4ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService4ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService4TestCaseOperation1 = "OperationName"
// InputService4TestCaseOperation1Request generates a request for the InputService4TestCaseOperation1 operation.
func (c *InputService4ProtocolTest) InputService4TestCaseOperation1Request(input *InputService4TestShapeInputService4TestCaseOperation1Input) (req *request.Request, output *InputService4TestShapeInputService4TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService4TestCaseOperation1,
}
if input == nil {
input = &InputService4TestShapeInputService4TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService4TestShapeInputService4TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService4ProtocolTest) InputService4TestCaseOperation1(input *InputService4TestShapeInputService4TestCaseOperation1Input) (*InputService4TestShapeInputService4TestCaseOperation1Output, error) {
req, out := c.InputService4TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService4TestShapeInputService4TestCaseOperation1Input struct {
ListArg []*string `type:"list"`
metadataInputService4TestShapeInputService4TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService4TestShapeInputService4TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService4TestShapeInputService4TestCaseOperation1Output struct {
metadataInputService4TestShapeInputService4TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService4TestShapeInputService4TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService5ProtocolTest struct {
*service.Service
}
// New returns a new InputService5ProtocolTest client.
func NewInputService5ProtocolTest(config *aws.Config) *InputService5ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice5protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService5ProtocolTest{service}
}
// newRequest creates a new request for a InputService5ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService5ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService5TestCaseOperation1 = "OperationName"
// InputService5TestCaseOperation1Request generates a request for the InputService5TestCaseOperation1 operation.
func (c *InputService5ProtocolTest) InputService5TestCaseOperation1Request(input *InputService5TestShapeInputService5TestCaseOperation1Input) (req *request.Request, output *InputService5TestShapeInputService5TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService5TestCaseOperation1,
}
if input == nil {
input = &InputService5TestShapeInputService5TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService5TestShapeInputService5TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService5ProtocolTest) InputService5TestCaseOperation1(input *InputService5TestShapeInputService5TestCaseOperation1Input) (*InputService5TestShapeInputService5TestCaseOperation1Output, error) {
req, out := c.InputService5TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService5TestShapeInputService5TestCaseOperation1Input struct {
ListArg []*string `locationName:"ListMemberName" locationNameList:"item" type:"list"`
metadataInputService5TestShapeInputService5TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService5TestShapeInputService5TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService5TestShapeInputService5TestCaseOperation1Output struct {
metadataInputService5TestShapeInputService5TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService5TestShapeInputService5TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService6ProtocolTest struct {
*service.Service
}
// New returns a new InputService6ProtocolTest client.
func NewInputService6ProtocolTest(config *aws.Config) *InputService6ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice6protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService6ProtocolTest{service}
}
// newRequest creates a new request for a InputService6ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService6ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService6TestCaseOperation1 = "OperationName"
// InputService6TestCaseOperation1Request generates a request for the InputService6TestCaseOperation1 operation.
func (c *InputService6ProtocolTest) InputService6TestCaseOperation1Request(input *InputService6TestShapeInputService6TestCaseOperation1Input) (req *request.Request, output *InputService6TestShapeInputService6TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService6TestCaseOperation1,
}
if input == nil {
input = &InputService6TestShapeInputService6TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService6TestShapeInputService6TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService6ProtocolTest) InputService6TestCaseOperation1(input *InputService6TestShapeInputService6TestCaseOperation1Input) (*InputService6TestShapeInputService6TestCaseOperation1Output, error) {
req, out := c.InputService6TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService6TestShapeInputService6TestCaseOperation1Input struct {
ListArg []*string `locationName:"ListMemberName" queryName:"ListQueryName" locationNameList:"item" type:"list"`
metadataInputService6TestShapeInputService6TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService6TestShapeInputService6TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService6TestShapeInputService6TestCaseOperation1Output struct {
metadataInputService6TestShapeInputService6TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService6TestShapeInputService6TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService7ProtocolTest struct {
*service.Service
}
// New returns a new InputService7ProtocolTest client.
func NewInputService7ProtocolTest(config *aws.Config) *InputService7ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice7protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService7ProtocolTest{service}
}
// newRequest creates a new request for a InputService7ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService7ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService7TestCaseOperation1 = "OperationName"
// InputService7TestCaseOperation1Request generates a request for the InputService7TestCaseOperation1 operation.
func (c *InputService7ProtocolTest) InputService7TestCaseOperation1Request(input *InputService7TestShapeInputService7TestCaseOperation1Input) (req *request.Request, output *InputService7TestShapeInputService7TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService7TestCaseOperation1,
}
if input == nil {
input = &InputService7TestShapeInputService7TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService7TestShapeInputService7TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService7ProtocolTest) InputService7TestCaseOperation1(input *InputService7TestShapeInputService7TestCaseOperation1Input) (*InputService7TestShapeInputService7TestCaseOperation1Output, error) {
req, out := c.InputService7TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService7TestShapeInputService7TestCaseOperation1Input struct {
BlobArg []byte `type:"blob"`
metadataInputService7TestShapeInputService7TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService7TestShapeInputService7TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService7TestShapeInputService7TestCaseOperation1Output struct {
metadataInputService7TestShapeInputService7TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService7TestShapeInputService7TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService8ProtocolTest struct {
*service.Service
}
// New returns a new InputService8ProtocolTest client.
func NewInputService8ProtocolTest(config *aws.Config) *InputService8ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "inputservice8protocoltest",
APIVersion: "2014-01-01",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &InputService8ProtocolTest{service}
}
// newRequest creates a new request for a InputService8ProtocolTest operation and runs any
// custom request initialization.
func (c *InputService8ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opInputService8TestCaseOperation1 = "OperationName"
// InputService8TestCaseOperation1Request generates a request for the InputService8TestCaseOperation1 operation.
func (c *InputService8ProtocolTest) InputService8TestCaseOperation1Request(input *InputService8TestShapeInputService8TestCaseOperation1Input) (req *request.Request, output *InputService8TestShapeInputService8TestCaseOperation1Output) {
op := &request.Operation{
Name: opInputService8TestCaseOperation1,
}
if input == nil {
input = &InputService8TestShapeInputService8TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &InputService8TestShapeInputService8TestCaseOperation1Output{}
req.Data = output
return
}
func (c *InputService8ProtocolTest) InputService8TestCaseOperation1(input *InputService8TestShapeInputService8TestCaseOperation1Input) (*InputService8TestShapeInputService8TestCaseOperation1Output, error) {
req, out := c.InputService8TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type InputService8TestShapeInputService8TestCaseOperation1Input struct {
TimeArg *time.Time `type:"timestamp" timestampFormat:"iso8601"`
metadataInputService8TestShapeInputService8TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataInputService8TestShapeInputService8TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type InputService8TestShapeInputService8TestCaseOperation1Output struct {
metadataInputService8TestShapeInputService8TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataInputService8TestShapeInputService8TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
//
// Tests begin here
//
func TestInputService1ProtocolTestScalarMembersCase1(t *testing.T) {
svc := NewInputService1ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService1TestShapeInputService1TestCaseOperation1Input{
Bar: aws.String("val2"),
Foo: aws.String("val1"),
}
req, _ := svc.InputService1TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&Bar=val2&Foo=val1&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService2ProtocolTestStructureWithLocationNameAndQueryNameAppliedToMembersCase1(t *testing.T) {
svc := NewInputService2ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService2TestShapeInputService2TestCaseOperation1Input{
Bar: aws.String("val2"),
Foo: aws.String("val1"),
Yuck: aws.String("val3"),
}
req, _ := svc.InputService2TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&BarLocationName=val2&Foo=val1&Version=2014-01-01&yuckQueryName=val3`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService3ProtocolTestNestedStructureMembersCase1(t *testing.T) {
svc := NewInputService3ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService3TestShapeInputService3TestCaseOperation1Input{
StructArg: &InputService3TestShapeStructType{
ScalarArg: aws.String("foo"),
},
}
req, _ := svc.InputService3TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&Struct.Scalar=foo&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService4ProtocolTestListTypesCase1(t *testing.T) {
svc := NewInputService4ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService4TestShapeInputService4TestCaseOperation1Input{
ListArg: []*string{
aws.String("foo"),
aws.String("bar"),
aws.String("baz"),
},
}
req, _ := svc.InputService4TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&ListArg.1=foo&ListArg.2=bar&ListArg.3=baz&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService5ProtocolTestListWithLocationNameAppliedToMemberCase1(t *testing.T) {
svc := NewInputService5ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService5TestShapeInputService5TestCaseOperation1Input{
ListArg: []*string{
aws.String("a"),
aws.String("b"),
aws.String("c"),
},
}
req, _ := svc.InputService5TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&ListMemberName.1=a&ListMemberName.2=b&ListMemberName.3=c&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService6ProtocolTestListWithLocationNameAndQueryNameCase1(t *testing.T) {
svc := NewInputService6ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService6TestShapeInputService6TestCaseOperation1Input{
ListArg: []*string{
aws.String("a"),
aws.String("b"),
aws.String("c"),
},
}
req, _ := svc.InputService6TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&ListQueryName.1=a&ListQueryName.2=b&ListQueryName.3=c&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService7ProtocolTestBase64EncodedBlobsCase1(t *testing.T) {
svc := NewInputService7ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService7TestShapeInputService7TestCaseOperation1Input{
BlobArg: []byte("foo"),
}
req, _ := svc.InputService7TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&BlobArg=Zm9v&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}
func TestInputService8ProtocolTestTimestampValuesCase1(t *testing.T) {
svc := NewInputService8ProtocolTest(nil)
svc.Endpoint = "https://test"
input := &InputService8TestShapeInputService8TestCaseOperation1Input{
TimeArg: aws.Time(time.Unix(1422172800, 0)),
}
req, _ := svc.InputService8TestCaseOperation1Request(input)
r := req.HTTPRequest
// build request
ec2query.Build(req)
assert.NoError(t, req.Error)
// assert body
assert.NotNil(t, r.Body)
body, _ := ioutil.ReadAll(r.Body)
awstesting.AssertQuery(t, `Action=OperationName&TimeArg=2015-01-25T08%3A00%3A00Z&Version=2014-01-01`, util.Trim(string(body)))
// assert URL
awstesting.AssertURL(t, "https://test/", r.URL.String())
// assert headers
}

View File

@ -0,0 +1,54 @@
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response body for the EC2 protocol.
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals response headers for the EC2 protocol.
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Response"`
Code string `xml:"Errors>Error>Code"`
Message string `xml:"Errors>Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals a response error for the EC2 protocol.
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

View File

@ -0,0 +1,838 @@
package ec2query_test
import (
"bytes"
"encoding/json"
"encoding/xml"
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/private/signer/v4"
"github.com/aws/aws-sdk-go/private/util"
"github.com/stretchr/testify/assert"
)
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = awstesting.GenerateAssertions
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
type OutputService1ProtocolTest struct {
*service.Service
}
// New returns a new OutputService1ProtocolTest client.
func NewOutputService1ProtocolTest(config *aws.Config) *OutputService1ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice1protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService1ProtocolTest{service}
}
// newRequest creates a new request for a OutputService1ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService1ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService1TestCaseOperation1 = "OperationName"
// OutputService1TestCaseOperation1Request generates a request for the OutputService1TestCaseOperation1 operation.
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1Request(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (req *request.Request, output *OutputService1TestShapeOutputService1TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService1TestCaseOperation1,
}
if input == nil {
input = &OutputService1TestShapeOutputService1TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService1TestShapeOutputService1TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (*OutputService1TestShapeOutputService1TestCaseOperation1Output, error) {
req, out := c.OutputService1TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService1TestShapeOutputService1TestCaseOperation1Input struct {
metadataOutputService1TestShapeOutputService1TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService1TestShapeOutputService1TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService1TestShapeOutputService1TestCaseOperation1Output struct {
Char *string `type:"character"`
Double *float64 `type:"double"`
FalseBool *bool `type:"boolean"`
Float *float64 `type:"float"`
Long *int64 `type:"long"`
Num *int64 `locationName:"FooNum" type:"integer"`
Str *string `type:"string"`
TrueBool *bool `type:"boolean"`
metadataOutputService1TestShapeOutputService1TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService1TestShapeOutputService1TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService2ProtocolTest struct {
*service.Service
}
// New returns a new OutputService2ProtocolTest client.
func NewOutputService2ProtocolTest(config *aws.Config) *OutputService2ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice2protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService2ProtocolTest{service}
}
// newRequest creates a new request for a OutputService2ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService2ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService2TestCaseOperation1 = "OperationName"
// OutputService2TestCaseOperation1Request generates a request for the OutputService2TestCaseOperation1 operation.
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1Request(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (req *request.Request, output *OutputService2TestShapeOutputService2TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService2TestCaseOperation1,
}
if input == nil {
input = &OutputService2TestShapeOutputService2TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService2TestShapeOutputService2TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (*OutputService2TestShapeOutputService2TestCaseOperation1Output, error) {
req, out := c.OutputService2TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService2TestShapeOutputService2TestCaseOperation1Input struct {
metadataOutputService2TestShapeOutputService2TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService2TestShapeOutputService2TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService2TestShapeOutputService2TestCaseOperation1Output struct {
Blob []byte `type:"blob"`
metadataOutputService2TestShapeOutputService2TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService2TestShapeOutputService2TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService3ProtocolTest struct {
*service.Service
}
// New returns a new OutputService3ProtocolTest client.
func NewOutputService3ProtocolTest(config *aws.Config) *OutputService3ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice3protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService3ProtocolTest{service}
}
// newRequest creates a new request for a OutputService3ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService3ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService3TestCaseOperation1 = "OperationName"
// OutputService3TestCaseOperation1Request generates a request for the OutputService3TestCaseOperation1 operation.
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1Request(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (req *request.Request, output *OutputService3TestShapeOutputService3TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService3TestCaseOperation1,
}
if input == nil {
input = &OutputService3TestShapeOutputService3TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService3TestShapeOutputService3TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (*OutputService3TestShapeOutputService3TestCaseOperation1Output, error) {
req, out := c.OutputService3TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService3TestShapeOutputService3TestCaseOperation1Input struct {
metadataOutputService3TestShapeOutputService3TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService3TestShapeOutputService3TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService3TestShapeOutputService3TestCaseOperation1Output struct {
ListMember []*string `type:"list"`
metadataOutputService3TestShapeOutputService3TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService3TestShapeOutputService3TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService4ProtocolTest struct {
*service.Service
}
// New returns a new OutputService4ProtocolTest client.
func NewOutputService4ProtocolTest(config *aws.Config) *OutputService4ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice4protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService4ProtocolTest{service}
}
// newRequest creates a new request for a OutputService4ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService4ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService4TestCaseOperation1 = "OperationName"
// OutputService4TestCaseOperation1Request generates a request for the OutputService4TestCaseOperation1 operation.
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1Request(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (req *request.Request, output *OutputService4TestShapeOutputService4TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService4TestCaseOperation1,
}
if input == nil {
input = &OutputService4TestShapeOutputService4TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService4TestShapeOutputService4TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (*OutputService4TestShapeOutputService4TestCaseOperation1Output, error) {
req, out := c.OutputService4TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService4TestShapeOutputService4TestCaseOperation1Input struct {
metadataOutputService4TestShapeOutputService4TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService4TestShapeOutputService4TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService4TestShapeOutputService4TestCaseOperation1Output struct {
ListMember []*string `locationNameList:"item" type:"list"`
metadataOutputService4TestShapeOutputService4TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService4TestShapeOutputService4TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService5ProtocolTest struct {
*service.Service
}
// New returns a new OutputService5ProtocolTest client.
func NewOutputService5ProtocolTest(config *aws.Config) *OutputService5ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice5protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService5ProtocolTest{service}
}
// newRequest creates a new request for a OutputService5ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService5ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService5TestCaseOperation1 = "OperationName"
// OutputService5TestCaseOperation1Request generates a request for the OutputService5TestCaseOperation1 operation.
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1Request(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (req *request.Request, output *OutputService5TestShapeOutputService5TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService5TestCaseOperation1,
}
if input == nil {
input = &OutputService5TestShapeOutputService5TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService5TestShapeOutputService5TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (*OutputService5TestShapeOutputService5TestCaseOperation1Output, error) {
req, out := c.OutputService5TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService5TestShapeOutputService5TestCaseOperation1Input struct {
metadataOutputService5TestShapeOutputService5TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService5TestShapeOutputService5TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService5TestShapeOutputService5TestCaseOperation1Output struct {
ListMember []*string `type:"list" flattened:"true"`
metadataOutputService5TestShapeOutputService5TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService5TestShapeOutputService5TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService6ProtocolTest struct {
*service.Service
}
// New returns a new OutputService6ProtocolTest client.
func NewOutputService6ProtocolTest(config *aws.Config) *OutputService6ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice6protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService6ProtocolTest{service}
}
// newRequest creates a new request for a OutputService6ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService6ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService6TestCaseOperation1 = "OperationName"
// OutputService6TestCaseOperation1Request generates a request for the OutputService6TestCaseOperation1 operation.
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1Request(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (req *request.Request, output *OutputService6TestShapeOutputService6TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService6TestCaseOperation1,
}
if input == nil {
input = &OutputService6TestShapeOutputService6TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService6TestShapeOutputService6TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (*OutputService6TestShapeOutputService6TestCaseOperation1Output, error) {
req, out := c.OutputService6TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService6TestShapeOutputService6TestCaseOperation1Input struct {
metadataOutputService6TestShapeOutputService6TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeOutputService6TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService6TestShapeOutputService6TestCaseOperation1Output struct {
Map map[string]*OutputService6TestShapeStructureType `type:"map"`
metadataOutputService6TestShapeOutputService6TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeOutputService6TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService6TestShapeStructureType struct {
Foo *string `locationName:"foo" type:"string"`
metadataOutputService6TestShapeStructureType `json:"-" xml:"-"`
}
type metadataOutputService6TestShapeStructureType struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService7ProtocolTest struct {
*service.Service
}
// New returns a new OutputService7ProtocolTest client.
func NewOutputService7ProtocolTest(config *aws.Config) *OutputService7ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice7protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService7ProtocolTest{service}
}
// newRequest creates a new request for a OutputService7ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService7ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService7TestCaseOperation1 = "OperationName"
// OutputService7TestCaseOperation1Request generates a request for the OutputService7TestCaseOperation1 operation.
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1Request(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (req *request.Request, output *OutputService7TestShapeOutputService7TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService7TestCaseOperation1,
}
if input == nil {
input = &OutputService7TestShapeOutputService7TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService7TestShapeOutputService7TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (*OutputService7TestShapeOutputService7TestCaseOperation1Output, error) {
req, out := c.OutputService7TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService7TestShapeOutputService7TestCaseOperation1Input struct {
metadataOutputService7TestShapeOutputService7TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService7TestShapeOutputService7TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService7TestShapeOutputService7TestCaseOperation1Output struct {
Map map[string]*string `type:"map" flattened:"true"`
metadataOutputService7TestShapeOutputService7TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService7TestShapeOutputService7TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService8ProtocolTest struct {
*service.Service
}
// New returns a new OutputService8ProtocolTest client.
func NewOutputService8ProtocolTest(config *aws.Config) *OutputService8ProtocolTest {
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "outputservice8protocoltest",
APIVersion: "",
},
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
return &OutputService8ProtocolTest{service}
}
// newRequest creates a new request for a OutputService8ProtocolTest operation and runs any
// custom request initialization.
func (c *OutputService8ProtocolTest) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
return req
}
const opOutputService8TestCaseOperation1 = "OperationName"
// OutputService8TestCaseOperation1Request generates a request for the OutputService8TestCaseOperation1 operation.
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1Request(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (req *request.Request, output *OutputService8TestShapeOutputService8TestCaseOperation1Output) {
op := &request.Operation{
Name: opOutputService8TestCaseOperation1,
}
if input == nil {
input = &OutputService8TestShapeOutputService8TestCaseOperation1Input{}
}
req = c.newRequest(op, input, output)
output = &OutputService8TestShapeOutputService8TestCaseOperation1Output{}
req.Data = output
return
}
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (*OutputService8TestShapeOutputService8TestCaseOperation1Output, error) {
req, out := c.OutputService8TestCaseOperation1Request(input)
err := req.Send()
return out, err
}
type OutputService8TestShapeOutputService8TestCaseOperation1Input struct {
metadataOutputService8TestShapeOutputService8TestCaseOperation1Input `json:"-" xml:"-"`
}
type metadataOutputService8TestShapeOutputService8TestCaseOperation1Input struct {
SDKShapeTraits bool `type:"structure"`
}
type OutputService8TestShapeOutputService8TestCaseOperation1Output struct {
Map map[string]*string `locationNameKey:"foo" locationNameValue:"bar" type:"map" flattened:"true"`
metadataOutputService8TestShapeOutputService8TestCaseOperation1Output `json:"-" xml:"-"`
}
type metadataOutputService8TestShapeOutputService8TestCaseOperation1Output struct {
SDKShapeTraits bool `type:"structure"`
}
//
// Tests begin here
//
func TestOutputService1ProtocolTestScalarMembersCase1(t *testing.T) {
svc := NewOutputService1ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Str>myname</Str><FooNum>123</FooNum><FalseBool>false</FalseBool><TrueBool>true</TrueBool><Float>1.2</Float><Double>1.3</Double><Long>200</Long><Char>a</Char><RequestId>request-id</RequestId></OperationNameResponse>"))
req, out := svc.OutputService1TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "a", *out.Char)
assert.Equal(t, 1.3, *out.Double)
assert.Equal(t, false, *out.FalseBool)
assert.Equal(t, 1.2, *out.Float)
assert.Equal(t, int64(200), *out.Long)
assert.Equal(t, int64(123), *out.Num)
assert.Equal(t, "myname", *out.Str)
assert.Equal(t, true, *out.TrueBool)
}
func TestOutputService2ProtocolTestBlobCase1(t *testing.T) {
svc := NewOutputService2ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Blob>dmFsdWU=</Blob><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService2TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "value", string(out.Blob))
}
func TestOutputService3ProtocolTestListsCase1(t *testing.T) {
svc := NewOutputService3ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember><member>abc</member><member>123</member></ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService3TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService4ProtocolTestListWithCustomMemberNameCase1(t *testing.T) {
svc := NewOutputService4ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember><item>abc</item><item>123</item></ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService4TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService5ProtocolTestFlattenedListCase1(t *testing.T) {
svc := NewOutputService5ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><ListMember>abc</ListMember><ListMember>123</ListMember><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService5TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "abc", *out.ListMember[0])
assert.Equal(t, "123", *out.ListMember[1])
}
func TestOutputService6ProtocolTestNormalMapCase1(t *testing.T) {
svc := NewOutputService6ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><entry><key>qux</key><value><foo>bar</foo></value></entry><entry><key>baz</key><value><foo>bam</foo></value></entry></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService6TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *out.Map["baz"].Foo)
assert.Equal(t, "bar", *out.Map["qux"].Foo)
}
func TestOutputService7ProtocolTestFlattenedMapCase1(t *testing.T) {
svc := NewOutputService7ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><key>qux</key><value>bar</value></Map><Map><key>baz</key><value>bam</value></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService7TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *out.Map["baz"])
assert.Equal(t, "bar", *out.Map["qux"])
}
func TestOutputService8ProtocolTestNamedMapCase1(t *testing.T) {
svc := NewOutputService8ProtocolTest(nil)
buf := bytes.NewReader([]byte("<OperationNameResponse><Map><foo>qux</foo><bar>bar</bar></Map><Map><foo>baz</foo><bar>bam</bar></Map><RequestId>requestid</RequestId></OperationNameResponse>"))
req, out := svc.OutputService8TestCaseOperation1Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
// unmarshal response
ec2query.UnmarshalMeta(req)
ec2query.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
assert.Equal(t, "bam", *out.Map["baz"])
assert.Equal(t, "bar", *out.Map["qux"])
}

View File

@ -0,0 +1,33 @@
// Package query provides serialisation of AWS query requests, and responses.
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
)
// Build builds a request for an AWS Query service.
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.Service.APIVersion},
}
if err := queryutil.Parse(body, r.Params, false); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding Query request", err)
return
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,223 @@
package queryutil
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
q := queryParser{isEC2: isEC2}
return q.parseValue(body, reflect.ValueOf(i), "", "")
}
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
type queryParser struct {
isEC2 bool
}
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
value = elemOf(value)
// no need to handle zero values
if !value.IsValid() {
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
return q.parseStruct(v, value, prefix)
case "list":
return q.parseList(v, value, prefix, tag)
case "map":
return q.parseMap(v, value, prefix, tag)
default:
return q.parseScalar(v, value, prefix, tag)
}
}
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
if !value.IsValid() {
return nil
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
value := elemOf(value.Field(i))
field := t.Field(i)
var name string
if q.isEC2 {
name = field.Tag.Get("queryName")
}
if name == "" {
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
if name != "" && q.isEC2 {
name = strings.ToUpper(name[0:1]) + name[1:]
}
}
if name == "" {
name = field.Name
}
if prefix != "" {
name = prefix + "." + name
}
if err := q.parseValue(v, value, name, field.Tag); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".member"
}
for i := 0; i < value.Len(); i++ {
slicePrefix := prefix
if slicePrefix == "" {
slicePrefix = strconv.Itoa(i + 1)
} else {
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1)
}
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".entry"
}
// sort keys for improved serialization consistency.
// this is not strictly necessary for protocol support.
mapKeyValues := value.MapKeys()
mapKeys := map[string]reflect.Value{}
mapKeyNames := make([]string, len(mapKeyValues))
for i, mapKey := range mapKeyValues {
name := mapKey.String()
mapKeys[name] = mapKey
mapKeyNames[i] = name
}
sort.Strings(mapKeyNames)
for i, mapKeyName := range mapKeyNames {
mapKey := mapKeys[mapKeyName]
mapValue := value.MapIndex(mapKey)
kname := tag.Get("locationNameKey")
if kname == "" {
kname = "key"
}
vname := tag.Get("locationNameValue")
if vname == "" {
vname = "value"
}
// serialize key
var keyName string
if prefix == "" {
keyName = strconv.Itoa(i+1) + "." + kname
} else {
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
}
if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
return err
}
// serialize value
var valueName string
if prefix == "" {
valueName = strconv.Itoa(i+1) + "." + vname
} else {
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
}
if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
switch value := r.Interface().(type) {
case string:
v.Set(name, value)
case []byte:
if !r.IsNil() {
v.Set(name, base64.StdEncoding.EncodeToString(value))
}
case bool:
v.Set(name, strconv.FormatBool(value))
case int64:
v.Set(name, strconv.FormatInt(value, 10))
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
v.Set(name, value.UTC().Format(ISO8601UTC))
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name())
}
return nil
}

View File

@ -0,0 +1,29 @@
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response for an AWS Query service.
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals header response values for an AWS Query service.
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}

View File

@ -0,0 +1,33 @@
package query
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"ErrorResponse"`
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,254 @@
// Package rest provides RESTful serialization of AWS requests and responses.
package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"path"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
var errValueNotSet = fmt.Errorf("value not set")
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// Build builds the REST component of a service request.
func Build(r *request.Request) {
if r.ParamsFilled() {
v := reflect.ValueOf(r.Params).Elem()
buildLocationElements(r, v)
buildBody(r, v)
}
}
func buildLocationElements(r *request.Request, v reflect.Value) {
query := r.HTTPRequest.URL.Query()
for i := 0; i < v.NumField(); i++ {
m := v.Field(i)
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
field := v.Type().Field(i)
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
if m.Kind() == reflect.Ptr {
m = m.Elem()
}
if !m.IsValid() {
continue
}
var err error
switch field.Tag.Get("location") {
case "headers": // header maps
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag.Get("locationName"))
case "header":
err = buildHeader(&r.HTTPRequest.Header, m, name)
case "uri":
err = buildURI(r.HTTPRequest.URL, m, name)
case "querystring":
err = buildQueryString(query, m, name)
}
r.Error = err
}
if r.Error != nil {
return
}
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
}
func buildBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := reflect.Indirect(v.FieldByName(payloadName))
if payload.IsValid() && payload.Interface() != nil {
switch reader := payload.Interface().(type) {
case io.ReadSeeker:
r.SetReaderBody(reader)
case []byte:
r.SetBufferBody(reader)
case string:
r.SetStringBody(reader)
default:
r.Error = awserr.New("SerializationError",
"failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
func buildHeader(header *http.Header, v reflect.Value, name string) error {
str, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
header.Add(name, str)
return nil
}
func buildHeaderMap(header *http.Header, v reflect.Value, prefix string) error {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err == errValueNotSet {
continue
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
header.Add(prefix+key.String(), str)
}
return nil
}
func buildURI(u *url.URL, v reflect.Value, name string) error {
value, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
uri := u.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(value, true), -1)
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(value, false), -1)
u.Path = uri
return nil
}
func buildQueryString(query url.Values, v reflect.Value, name string) error {
switch value := v.Interface().(type) {
case []*string:
for _, item := range value {
query.Add(name, *item)
}
case map[string]*string:
for key, item := range value {
query.Add(key, *item)
}
case map[string][]*string:
for key, items := range value {
for _, item := range items {
query.Add(key, *item)
}
}
default:
str, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
query.Set(name, str)
}
return nil
}
func updatePath(url *url.URL, urlPath string) {
scheme, query := url.Scheme, url.RawQuery
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path
urlPath = path.Clean(urlPath)
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""
s := url.String()
url.Scheme = scheme
url.RawQuery = query
// build opaque URI
url.Opaque = s + urlPath
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
buf.WriteByte('%')
buf.WriteString(strings.ToUpper(strconv.FormatUint(uint64(c), 16)))
}
}
return buf.String()
}
func convertType(v reflect.Value) (string, error) {
v = reflect.Indirect(v)
if !v.IsValid() {
return "", errValueNotSet
}
var str string
switch value := v.Interface().(type) {
case string:
str = value
case []byte:
str = base64.StdEncoding.EncodeToString(value)
case bool:
str = strconv.FormatBool(value)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
case time.Time:
str = value.UTC().Format(RFC822)
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return "", err
}
return str, nil
}

View File

@ -0,0 +1,45 @@
package rest
import "reflect"
// PayloadMember returns the payload field member of i if there is one, or nil.
func PayloadMember(i interface{}) interface{} {
if i == nil {
return nil
}
v := reflect.ValueOf(i).Elem()
if !v.IsValid() {
return nil
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
field, _ := v.Type().FieldByName(payloadName)
if field.Tag.Get("type") != "structure" {
return nil
}
payload := v.FieldByName(payloadName)
if payload.IsValid() || (payload.Kind() == reflect.Ptr && !payload.IsNil()) {
return payload.Interface()
}
}
}
return nil
}
// PayloadType returns the type of a payload field member of i if there is one, or "".
func PayloadType(i interface{}) string {
v := reflect.Indirect(reflect.ValueOf(i))
if !v.IsValid() {
return ""
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
if member, ok := v.Type().FieldByName(payloadName); ok {
return member.Tag.Get("type")
}
}
}
return ""
}

View File

@ -0,0 +1,183 @@
package rest
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// Unmarshal unmarshals the REST component of a response in a REST service.
func Unmarshal(r *request.Request) {
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v)
}
}
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalLocationElements(r, v)
}
}
func unmarshalBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := v.FieldByName(payloadName)
if payload.IsValid() {
switch payload.Interface().(type) {
case []byte:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
}
case *string:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
str := string(b)
payload.Set(reflect.ValueOf(&str))
}
default:
switch payload.Type().String() {
case "io.ReadSeeker":
payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
case "aws.ReadSeekCloser", "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
default:
r.Error = awserr.New("SerializationError",
"failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
}
func unmarshalLocationElements(r *request.Request, v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
switch field.Tag.Get("location") {
case "statusCode":
unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
}
}
if r.Error != nil {
return
}
}
}
func unmarshalStatusCode(v reflect.Value, statusCode int) {
if !v.IsValid() {
return
}
switch v.Interface().(type) {
case *int64:
s := int64(statusCode)
v.Set(reflect.ValueOf(&s))
}
}
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
out[k[len(prefix):]] = &v[0]
}
}
r.Set(reflect.ValueOf(out))
}
return nil
}
func unmarshalHeader(v reflect.Value, header string) error {
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
return nil
}
switch v.Interface().(type) {
case *string:
v.Set(reflect.ValueOf(&header))
case []byte:
b, err := base64.StdEncoding.DecodeString(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *bool:
b, err := strconv.ParseBool(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *int64:
i, err := strconv.ParseInt(header, 10, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&i))
case *float64:
f, err := strconv.ParseFloat(header, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&f))
case *time.Time:
t, err := time.Parse(RFC822, header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&t))
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return err
}
return nil
}

View File

@ -0,0 +1,287 @@
// Package xmlutil provides XML serialisation of AWS requests and responses.
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// BuildXML will serialize params into an xml.Encoder.
// Error will be returned if the serialization of any of the params or nested values fails.
func BuildXML(params interface{}, e *xml.Encoder) error {
b := xmlBuilder{encoder: e, namespaces: map[string]string{}}
root := NewXMLElement(xml.Name{})
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil {
return err
}
for _, c := range root.Children {
for _, v := range c {
return StructToXML(e, v, false)
}
}
return nil
}
// Returns the reflection element of a value, if it is a pointer.
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
// A xmlBuilder serializes values from Go code to XML
type xmlBuilder struct {
encoder *xml.Encoder
namespaces map[string]string
}
// buildValue generic XMLNode builder for any type. Will build value for their specific type
// struct, list, map, scalar.
//
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
// type is not provided reflect will be used to determine the value's type.
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
value = elemOf(value)
if !value.IsValid() { // no need to handle zero values
return nil
} else if tag.Get("location") != "" { // don't handle non-body location values
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := value.Type().FieldByName("SDKShapeTraits"); ok {
tag = tag + reflect.StructTag(" ") + field.Tag
}
return b.buildStruct(value, current, tag)
case "list":
return b.buildList(value, current, tag)
case "map":
return b.buildMap(value, current, tag)
default:
return b.buildScalar(value, current, tag)
}
}
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
// types are converted to XMLNodes also.
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if !value.IsValid() {
return nil
}
fieldAdded := false
// unwrap payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := value.Type().FieldByName(payload)
tag = field.Tag
value = elemOf(value.FieldByName(payload))
if !value.IsValid() {
return nil
}
}
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
// there is an xmlNamespace associated with this struct
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" {
ns := xml.Attr{
Name: xml.Name{Local: "xmlns"},
Value: uri,
}
if prefix != "" {
b.namespaces[prefix] = uri // register the namespace
ns.Name.Local = "xmlns:" + prefix
}
child.Attr = append(child.Attr, ns)
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
member := elemOf(value.Field(i))
field := t.Field(i)
mTag := field.Tag
if mTag.Get("location") != "" { // skip non-body members
continue
}
memberName := mTag.Get("locationName")
if memberName == "" {
memberName = field.Name
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`)
}
if err := b.buildValue(member, child, mTag); err != nil {
return err
}
fieldAdded = true
}
if fieldAdded { // only append this child if we have one ore more valid members
current.AddChild(child)
}
return nil
}
// buildList adds the value's list items to the current XMLNode as children nodes. All
// nested values in the list are converted to XMLNodes also.
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted lists
return nil
}
// check for unflattened list member
flattened := tag.Get("flattened") != ""
xname := xml.Name{Local: tag.Get("locationName")}
if flattened {
for i := 0; i < value.Len(); i++ {
child := NewXMLElement(xname)
current.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
} else {
list := NewXMLElement(xname)
current.AddChild(list)
for i := 0; i < value.Len(); i++ {
iname := tag.Get("locationNameList")
if iname == "" {
iname = "member"
}
child := NewXMLElement(xml.Name{Local: iname})
list.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
}
return nil
}
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
// nested values in the map are converted to XMLNodes also.
//
// Error will be returned if it is unable to build the map's values into XMLNodes
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted maps
return nil
}
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
current.AddChild(maproot)
current = maproot
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
// sorting is not required for compliance, but it makes testing easier
keys := make([]string, value.Len())
for i, k := range value.MapKeys() {
keys[i] = k.String()
}
sort.Strings(keys)
for _, k := range keys {
v := value.MapIndex(reflect.ValueOf(k))
mapcur := current
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
child := NewXMLElement(xml.Name{Local: "entry"})
mapcur.AddChild(child)
mapcur = child
}
kchild := NewXMLElement(xml.Name{Local: kname})
kchild.Text = k
vchild := NewXMLElement(xml.Name{Local: vname})
mapcur.AddChild(kchild)
mapcur.AddChild(vchild)
if err := b.buildValue(v, vchild, ""); err != nil {
return err
}
}
return nil
}
// buildScalar will convert the value into a string and append it as a attribute or child
// of the current XMLNode.
//
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
//
// Error will be returned if the value type is unsupported.
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
var str string
switch converted := value.Interface().(type) {
case string:
str = converted
case []byte:
if !value.IsNil() {
str = base64.StdEncoding.EncodeToString(converted)
}
case bool:
str = strconv.FormatBool(converted)
case int64:
str = strconv.FormatInt(converted, 10)
case int:
str = strconv.Itoa(converted)
case float64:
str = strconv.FormatFloat(converted, 'f', -1, 64)
case float32:
str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
str = converted.UTC().Format(ISO8601UTC)
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)",
tag.Get("locationName"), value.Interface(), value.Type().Name())
}
xname := xml.Name{Local: tag.Get("locationName")}
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
attr := xml.Attr{Name: xname, Value: str}
current.Attr = append(current.Attr, attr)
} else { // regular text node
current.AddChild(&XMLNode{Name: xname, Text: str})
}
return nil
}

View File

@ -0,0 +1,260 @@
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
// UnmarshalXML deserializes an xml.Decoder into the container v. V
// needs to match the shape of the XML expected to be decoded.
// If the shape doesn't match unmarshaling will fail.
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
n, _ := XMLToStruct(d, nil)
if n.Children != nil {
for _, root := range n.Children {
for _, c := range root {
if wrappedChild, ok := c.Children[wrapper]; ok {
c = wrappedChild[0] // pull out wrapped element
}
err := parse(reflect.ValueOf(v), c, "")
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
return nil
}
return nil
}
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
rtype := r.Type()
if rtype.Kind() == reflect.Ptr {
rtype = rtype.Elem() // check kind of actual element type
}
t := tag.Get("type")
if t == "" {
switch rtype.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := rtype.FieldByName("SDKShapeTraits"); ok {
tag = field.Tag
}
return parseStruct(r, node, tag)
case "list":
return parseList(r, node, tag)
case "map":
return parseMap(r, node, tag)
default:
return parseScalar(r, node, tag)
}
}
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
// types in the structure will also be deserialized.
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if r.Kind() == reflect.Ptr {
if r.IsNil() { // create the structure if it's nil
s := reflect.New(r.Type().Elem())
r.Set(s)
r = s
}
r = r.Elem()
t = t.Elem()
}
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return parseStruct(r.FieldByName(payload), node, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if c := field.Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
// figure out what this field is called
name := field.Name
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
// try to find the field by name in elements
elems := node.Children[name]
if elems == nil { // try to find the field in attributes
for _, a := range node.Attr {
if name == a.Name.Local {
// turn this into a text node for de-serializing
elems = []*XMLNode{{Text: a.Value}}
}
}
}
member := r.FieldByName(field.Name)
for _, elem := range elems {
err := parse(member, elem, field.Tag)
if err != nil {
return err
}
}
}
return nil
}
// parseList deserializes a list of values from an XML node. Each list entry
// will also be deserialized.
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if tag.Get("flattened") == "" { // look at all item entries
mname := "member"
if name := tag.Get("locationNameList"); name != "" {
mname = name
}
if Children, ok := node.Children[mname]; ok {
if r.IsNil() {
r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
}
for i, c := range Children {
err := parse(r.Index(i), c, "")
if err != nil {
return err
}
}
}
} else { // flattened list means this is a single element
if r.IsNil() {
r.Set(reflect.MakeSlice(t, 0, 0))
}
childR := reflect.Zero(t.Elem())
r.Set(reflect.Append(r, childR))
err := parse(r.Index(r.Len()-1), node, "")
if err != nil {
return err
}
}
return nil
}
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
// will also be deserialized as map entries.
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
if r.IsNil() {
r.Set(reflect.MakeMap(r.Type()))
}
if tag.Get("flattened") == "" { // look at all child entries
for _, entry := range node.Children["entry"] {
parseMapEntry(r, entry, tag)
}
} else { // this element is itself an entry
parseMapEntry(r, node, tag)
}
return nil
}
// parseMapEntry deserializes a map entry from a XML node.
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
keys, ok := node.Children[kname]
values := node.Children[vname]
if ok {
for i, key := range keys {
keyR := reflect.ValueOf(key.Text)
value := values[i]
valueR := reflect.New(r.Type().Elem()).Elem()
parse(valueR, value, "")
r.SetMapIndex(keyR, valueR)
}
}
return nil
}
// parseScaller deserializes an XMLNode value into a concrete type based on the
// interface type of r.
//
// Error is returned if the deserialization fails due to invalid type conversion,
// or unsupported interface type.
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
switch r.Interface().(type) {
case *string:
r.Set(reflect.ValueOf(&node.Text))
return nil
case []byte:
b, err := base64.StdEncoding.DecodeString(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(b))
case *bool:
v, err := strconv.ParseBool(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *int64:
v, err := strconv.ParseInt(node.Text, 10, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *float64:
v, err := strconv.ParseFloat(node.Text, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
t, err := time.Parse(ISO8601UTC, node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&t))
default:
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
}
return nil
}

View File

@ -0,0 +1,105 @@
package xmlutil
import (
"encoding/xml"
"io"
"sort"
)
// A XMLNode contains the values to be encoded or decoded.
type XMLNode struct {
Name xml.Name `json:",omitempty"`
Children map[string][]*XMLNode `json:",omitempty"`
Text string `json:",omitempty"`
Attr []xml.Attr `json:",omitempty"`
}
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
func NewXMLElement(name xml.Name) *XMLNode {
return &XMLNode{
Name: name,
Children: map[string][]*XMLNode{},
Attr: []xml.Attr{},
}
}
// AddChild adds child to the XMLNode.
func (n *XMLNode) AddChild(child *XMLNode) {
if _, ok := n.Children[child.Name.Local]; !ok {
n.Children[child.Name.Local] = []*XMLNode{}
}
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child)
}
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
out := &XMLNode{}
for {
tok, err := d.Token()
if tok == nil || err == io.EOF {
break
}
if err != nil {
return out, err
}
switch typed := tok.(type) {
case xml.CharData:
out.Text = string(typed.Copy())
case xml.StartElement:
el := typed.Copy()
out.Attr = el.Attr
if out.Children == nil {
out.Children = map[string][]*XMLNode{}
}
name := typed.Name.Local
slice := out.Children[name]
if slice == nil {
slice = []*XMLNode{}
}
node, e := XMLToStruct(d, &el)
if e != nil {
return out, e
}
node.Name = typed.Name
slice = append(slice, node)
out.Children[name] = slice
case xml.EndElement:
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
return out, nil
}
}
}
return out, nil
}
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
if node.Text != "" {
e.EncodeToken(xml.CharData([]byte(node.Text)))
} else if sorted {
sortedNames := []string{}
for k := range node.Children {
sortedNames = append(sortedNames, k)
}
sort.Strings(sortedNames)
for _, k := range sortedNames {
for _, v := range node.Children[k] {
StructToXML(e, v, sorted)
}
}
} else {
for _, c := range node.Children {
for _, v := range c {
StructToXML(e, v, sorted)
}
}
}
e.EncodeToken(xml.EndElement{Name: node.Name})
return e.Flush()
}

View File

@ -0,0 +1,43 @@
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestPresignHandler(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View File

@ -0,0 +1,365 @@
// Package v4 implements signing for AWS V4 signer
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
const (
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
var ignoredHeaders = map[string]bool{
"Authorization": true,
"Content-Type": true,
"Content-Length": true,
"User-Agent": true,
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
CredValues credentials.Value
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug aws.LogLevelType
Logger aws.Logger
isPresign bool
formattedTime string
formattedShortTime string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign requests with signature version 4.
//
// Will sign the requests with the service config's Credentials object
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
// object.
func Sign(req *request.Request) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
return
}
region := req.Service.SigningRegion
if region == "" {
region = aws.StringValue(req.Service.Config.Region)
}
name := req.Service.SigningName
if name == "" {
name = req.Service.ServiceName
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
Credentials: req.Service.Config.Credentials,
Debug: req.Service.Config.LogLevel.Value(),
Logger: req.Service.Config.Logger,
}
req.Error = s.sign()
}
func (v4 *signer) sign() error {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isRequestSigned() {
if !v4.Credentials.IsExpired() {
// If the request is already signed, and the credentials have not
// expired yet ignore the signing request.
return nil
}
// The credentials have expired for this request. The current signing
// is invalid, and needs to be request because the request will fail.
if v4.isPresign {
v4.removePresign()
// Update the request's query string to ensure the values stays in
// sync in the case retrieving the new credentials fails.
v4.Request.URL.RawQuery = v4.Query.Encode()
}
}
var err error
v4.CredValues, err = v4.Credentials.Get()
if err != nil {
return err
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.CredValues.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.CredValues.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
}
v4.build()
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo()
}
return nil
}
const logSignInfoMsg = `DEBUG: Request Signiture:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *signer) logSigningInfo() {
signedURLMsg := ""
if v4.isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, v4.Request.URL.String())
}
msg := fmt.Sprintf(logSignInfoMsg, v4.canonicalString, v4.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
v4.buildQuery() // no depends
}
v4.buildCanonicalHeaders() // depends on cred string
v4.buildCanonicalString() // depends on canon headers / signed headers
v4.buildStringToSign() // depends on canon string
v4.buildSignature() // depends on string to sign
if v4.isPresign {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.CredValues.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
v4.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
}
func (v4 *signer) buildTime() {
v4.formattedTime = v4.Time.UTC().Format(timeFormat)
v4.formattedShortTime = v4.Time.UTC().Format(shortTimeFormat)
if v4.isPresign {
duration := int64(v4.ExpireTime / time.Second)
v4.Query.Set("X-Amz-Date", v4.formattedTime)
v4.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
v4.Request.Header.Set("X-Amz-Date", v4.formattedTime)
}
}
func (v4 *signer) buildCredentialString() {
v4.credentialString = strings.Join([]string{
v4.formattedShortTime,
v4.Region,
v4.ServiceName,
"aws4_request",
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.CredValues.AccessKeyID+"/"+v4.credentialString)
}
}
func (v4 *signer) buildQuery() {
for k, h := range v4.Request.Header {
if strings.HasPrefix(http.CanonicalHeaderKey(k), "X-Amz-") {
continue // never hoist x-amz-* headers, they must be signed
}
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // never hoist ignored headers
}
v4.Request.Header.Del(k)
v4.Query.Del(k)
for _, v := range h {
v4.Query.Add(k, v)
}
}
}
func (v4 *signer) buildCanonicalHeaders() {
var headers []string
headers = append(headers, "host")
for k := range v4.Request.Header {
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // ignored header
}
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
v4.signedHeaders = strings.Join(headers, ";")
if v4.isPresign {
v4.Query.Set("X-Amz-SignedHeaders", v4.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
headerValues[i] = "host:" + v4.Request.URL.Host
} else {
headerValues[i] = k + ":" +
strings.Join(v4.Request.Header[http.CanonicalHeaderKey(k)], ",")
}
}
v4.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (v4 *signer) buildCanonicalString() {
v4.Request.URL.RawQuery = strings.Replace(v4.Query.Encode(), "+", "%20", -1)
uri := v4.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = v4.Request.URL.Path
}
if uri == "" {
uri = "/"
}
if v4.ServiceName != "s3" {
uri = rest.EscapePath(uri, false)
}
v4.canonicalString = strings.Join([]string{
v4.Request.Method,
uri,
v4.Request.URL.RawQuery,
v4.canonicalHeaders + "\n",
v4.signedHeaders,
v4.bodyDigest(),
}, "\n")
}
func (v4 *signer) buildStringToSign() {
v4.stringToSign = strings.Join([]string{
authHeaderPrefix,
v4.formattedTime,
v4.credentialString,
hex.EncodeToString(makeSha256([]byte(v4.canonicalString))),
}, "\n")
}
func (v4 *signer) buildSignature() {
secret := v4.CredValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(v4.stringToSign))
v4.signature = hex.EncodeToString(signature)
}
func (v4 *signer) bodyDigest() string {
hash := v4.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if v4.isPresign && v4.ServiceName == "s3" {
hash = "UNSIGNED-PAYLOAD"
} else if v4.Body == nil {
hash = hex.EncodeToString(makeSha256([]byte{}))
} else {
hash = hex.EncodeToString(makeSha256Reader(v4.Body))
}
v4.Request.Header.Add("X-Amz-Content-Sha256", hash)
}
return hash
}
// isRequestSigned returns if the request is currently signed or presigned
func (v4 *signer) isRequestSigned() bool {
if v4.isPresign && v4.Query.Get("X-Amz-Signature") != "" {
return true
}
if v4.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// unsign removes signing flags for both signed and presigned requests.
func (v4 *signer) removePresign() {
v4.Query.Del("X-Amz-Algorithm")
v4.Query.Del("X-Amz-Signature")
v4.Query.Del("X-Amz-Security-Token")
v4.Query.Del("X-Amz-Date")
v4.Query.Del("X-Amz-Expires")
v4.Query.Del("X-Amz-Credential")
v4.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
io.Copy(hash, reader)
return hash.Sum(nil)
}

View File

@ -0,0 +1,247 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/stretchr/testify/assert"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "5eeedebf6f995145ce56daa02902d10485246d3defb34f97b973c1f40ab82d36"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=69ada33fec48180dab153576e4dd80c4e04124f80dda3eccfed8a67c2b91ed5e"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
svc := service.New(&aws.Config{Credentials: credentials.AnonymousCredentials})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
svc := service.New(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
svc := service.New(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := service.New(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{credentials.Value{"AKID", "SECRET", "SESSION"}}
creds := credentials.NewCredentials(provider)
svc := service.New(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

View File

@ -6,15 +6,15 @@ package cloudwatch
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
const opDeleteAlarms = "DeleteAlarms"
// DeleteAlarmsRequest generates a request for the DeleteAlarms operation.
func (c *CloudWatch) DeleteAlarmsRequest(input *DeleteAlarmsInput) (req *aws.Request, output *DeleteAlarmsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DeleteAlarmsRequest(input *DeleteAlarmsInput) (req *request.Request, output *DeleteAlarmsOutput) {
op := &request.Operation{
Name: opDeleteAlarms,
HTTPMethod: "POST",
HTTPPath: "/",
@ -40,12 +40,12 @@ func (c *CloudWatch) DeleteAlarms(input *DeleteAlarmsInput) (*DeleteAlarmsOutput
const opDescribeAlarmHistory = "DescribeAlarmHistory"
// DescribeAlarmHistoryRequest generates a request for the DescribeAlarmHistory operation.
func (c *CloudWatch) DescribeAlarmHistoryRequest(input *DescribeAlarmHistoryInput) (req *aws.Request, output *DescribeAlarmHistoryOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmHistoryRequest(input *DescribeAlarmHistoryInput) (req *request.Request, output *DescribeAlarmHistoryOutput) {
op := &request.Operation{
Name: opDescribeAlarmHistory,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "MaxRecords",
@ -82,12 +82,12 @@ func (c *CloudWatch) DescribeAlarmHistoryPages(input *DescribeAlarmHistoryInput,
const opDescribeAlarms = "DescribeAlarms"
// DescribeAlarmsRequest generates a request for the DescribeAlarms operation.
func (c *CloudWatch) DescribeAlarmsRequest(input *DescribeAlarmsInput) (req *aws.Request, output *DescribeAlarmsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmsRequest(input *DescribeAlarmsInput) (req *request.Request, output *DescribeAlarmsOutput) {
op := &request.Operation{
Name: opDescribeAlarms,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "MaxRecords",
@ -124,8 +124,8 @@ func (c *CloudWatch) DescribeAlarmsPages(input *DescribeAlarmsInput, fn func(p *
const opDescribeAlarmsForMetric = "DescribeAlarmsForMetric"
// DescribeAlarmsForMetricRequest generates a request for the DescribeAlarmsForMetric operation.
func (c *CloudWatch) DescribeAlarmsForMetricRequest(input *DescribeAlarmsForMetricInput) (req *aws.Request, output *DescribeAlarmsForMetricOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmsForMetricRequest(input *DescribeAlarmsForMetricInput) (req *request.Request, output *DescribeAlarmsForMetricOutput) {
op := &request.Operation{
Name: opDescribeAlarmsForMetric,
HTTPMethod: "POST",
HTTPPath: "/",
@ -152,8 +152,8 @@ func (c *CloudWatch) DescribeAlarmsForMetric(input *DescribeAlarmsForMetricInput
const opDisableAlarmActions = "DisableAlarmActions"
// DisableAlarmActionsRequest generates a request for the DisableAlarmActions operation.
func (c *CloudWatch) DisableAlarmActionsRequest(input *DisableAlarmActionsInput) (req *aws.Request, output *DisableAlarmActionsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DisableAlarmActionsRequest(input *DisableAlarmActionsInput) (req *request.Request, output *DisableAlarmActionsOutput) {
op := &request.Operation{
Name: opDisableAlarmActions,
HTTPMethod: "POST",
HTTPPath: "/",
@ -180,8 +180,8 @@ func (c *CloudWatch) DisableAlarmActions(input *DisableAlarmActionsInput) (*Disa
const opEnableAlarmActions = "EnableAlarmActions"
// EnableAlarmActionsRequest generates a request for the EnableAlarmActions operation.
func (c *CloudWatch) EnableAlarmActionsRequest(input *EnableAlarmActionsInput) (req *aws.Request, output *EnableAlarmActionsOutput) {
op := &aws.Operation{
func (c *CloudWatch) EnableAlarmActionsRequest(input *EnableAlarmActionsInput) (req *request.Request, output *EnableAlarmActionsOutput) {
op := &request.Operation{
Name: opEnableAlarmActions,
HTTPMethod: "POST",
HTTPPath: "/",
@ -207,8 +207,8 @@ func (c *CloudWatch) EnableAlarmActions(input *EnableAlarmActionsInput) (*Enable
const opGetMetricStatistics = "GetMetricStatistics"
// GetMetricStatisticsRequest generates a request for the GetMetricStatistics operation.
func (c *CloudWatch) GetMetricStatisticsRequest(input *GetMetricStatisticsInput) (req *aws.Request, output *GetMetricStatisticsOutput) {
op := &aws.Operation{
func (c *CloudWatch) GetMetricStatisticsRequest(input *GetMetricStatisticsInput) (req *request.Request, output *GetMetricStatisticsOutput) {
op := &request.Operation{
Name: opGetMetricStatistics,
HTTPMethod: "POST",
HTTPPath: "/",
@ -259,12 +259,12 @@ func (c *CloudWatch) GetMetricStatistics(input *GetMetricStatisticsInput) (*GetM
const opListMetrics = "ListMetrics"
// ListMetricsRequest generates a request for the ListMetrics operation.
func (c *CloudWatch) ListMetricsRequest(input *ListMetricsInput) (req *aws.Request, output *ListMetricsOutput) {
op := &aws.Operation{
func (c *CloudWatch) ListMetricsRequest(input *ListMetricsInput) (req *request.Request, output *ListMetricsOutput) {
op := &request.Operation{
Name: opListMetrics,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "",
@ -301,8 +301,8 @@ func (c *CloudWatch) ListMetricsPages(input *ListMetricsInput, fn func(p *ListMe
const opPutMetricAlarm = "PutMetricAlarm"
// PutMetricAlarmRequest generates a request for the PutMetricAlarm operation.
func (c *CloudWatch) PutMetricAlarmRequest(input *PutMetricAlarmInput) (req *aws.Request, output *PutMetricAlarmOutput) {
op := &aws.Operation{
func (c *CloudWatch) PutMetricAlarmRequest(input *PutMetricAlarmInput) (req *request.Request, output *PutMetricAlarmOutput) {
op := &request.Operation{
Name: opPutMetricAlarm,
HTTPMethod: "POST",
HTTPPath: "/",
@ -334,8 +334,8 @@ func (c *CloudWatch) PutMetricAlarm(input *PutMetricAlarmInput) (*PutMetricAlarm
const opPutMetricData = "PutMetricData"
// PutMetricDataRequest generates a request for the PutMetricData operation.
func (c *CloudWatch) PutMetricDataRequest(input *PutMetricDataInput) (req *aws.Request, output *PutMetricDataOutput) {
op := &aws.Operation{
func (c *CloudWatch) PutMetricDataRequest(input *PutMetricDataInput) (req *request.Request, output *PutMetricDataOutput) {
op := &request.Operation{
Name: opPutMetricData,
HTTPMethod: "POST",
HTTPPath: "/",
@ -374,8 +374,8 @@ func (c *CloudWatch) PutMetricData(input *PutMetricDataInput) (*PutMetricDataOut
const opSetAlarmState = "SetAlarmState"
// SetAlarmStateRequest generates a request for the SetAlarmState operation.
func (c *CloudWatch) SetAlarmStateRequest(input *SetAlarmStateInput) (req *aws.Request, output *SetAlarmStateOutput) {
op := &aws.Operation{
func (c *CloudWatch) SetAlarmStateRequest(input *SetAlarmStateInput) (req *request.Request, output *SetAlarmStateOutput) {
op := &request.Operation{
Name: opSetAlarmState,
HTTPMethod: "POST",
HTTPPath: "/",
@ -406,16 +406,16 @@ func (c *CloudWatch) SetAlarmState(input *SetAlarmStateInput) (*SetAlarmStateOut
// returns this data type as part of the DescribeAlarmHistoryResult data type.
type AlarmHistoryItem struct {
// The descriptive name for the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// Machine-readable data about the alarm in JSON format.
HistoryData *string `type:"string"`
HistoryData *string `min:"1" type:"string"`
// The type of alarm history item.
HistoryItemType *string `type:"string" enum:"HistoryItemType"`
// A human-readable summary of the alarm history.
HistorySummary *string `type:"string"`
HistorySummary *string `min:"1" type:"string"`
// The time stamp for the alarm history item. Amazon CloudWatch uses Coordinated
// Universal Time (UTC) when returning time stamps, which do not accommodate
@ -528,7 +528,7 @@ func (s DeleteAlarmsOutput) GoString() string {
type DescribeAlarmHistoryInput struct {
// The name of the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// The ending date to retrieve alarm history.
EndDate *time.Time `type:"timestamp" timestampFormat:"iso8601"`
@ -537,7 +537,7 @@ type DescribeAlarmHistoryInput struct {
HistoryItemType *string `type:"string" enum:"HistoryItemType"`
// The maximum number of alarm history records to retrieve.
MaxRecords *int64 `type:"integer"`
MaxRecords *int64 `min:"1" type:"integer"`
// The token returned by a previous call to indicate that there is more data
// available.
@ -593,13 +593,13 @@ type DescribeAlarmsForMetricInput struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace of the metric.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The period in seconds over which the statistic is applied.
Period *int64 `type:"integer"`
Period *int64 `min:"60" type:"integer"`
// The statistic for the metric.
Statistic *string `type:"string" enum:"Statistic"`
@ -648,17 +648,17 @@ func (s DescribeAlarmsForMetricOutput) GoString() string {
type DescribeAlarmsInput struct {
// The action name prefix.
ActionPrefix *string `type:"string"`
ActionPrefix *string `min:"1" type:"string"`
// The alarm name prefix. AlarmNames cannot be specified if this parameter is
// specified.
AlarmNamePrefix *string `type:"string"`
AlarmNamePrefix *string `min:"1" type:"string"`
// A list of alarm names to retrieve information for.
AlarmNames []*string `type:"list"`
// The maximum number of alarm descriptions to retrieve.
MaxRecords *int64 `type:"integer"`
MaxRecords *int64 `min:"1" type:"integer"`
// The token returned by a previous call to indicate that there is more data
// available.
@ -715,10 +715,10 @@ func (s DescribeAlarmsOutput) GoString() string {
// For examples that use one or more dimensions, see PutMetricData.
type Dimension struct {
// The name of the dimension.
Name *string `type:"string" required:"true"`
Name *string `min:"1" type:"string" required:"true"`
// The value representing the dimension measurement
Value *string `type:"string" required:"true"`
Value *string `min:"1" type:"string" required:"true"`
metadataDimension `json:"-" xml:"-"`
}
@ -740,10 +740,10 @@ func (s Dimension) GoString() string {
// The DimensionFilter data type is used to filter ListMetrics results.
type DimensionFilter struct {
// The dimension name to be matched.
Name *string `type:"string" required:"true"`
Name *string `min:"1" type:"string" required:"true"`
// The value of the dimension to be matched.
Value *string `type:"string"`
Value *string `min:"1" type:"string"`
metadataDimensionFilter `json:"-" xml:"-"`
}
@ -850,14 +850,14 @@ type GetMetricStatisticsInput struct {
EndTime *time.Time `type:"timestamp" timestampFormat:"iso8601" required:"true"`
// The name of the metric, with or without spaces.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace of the metric, with or without spaces.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The granularity, in seconds, of the returned datapoints. Period must be at
// least 60 seconds and must be a multiple of 60. The default value is 60.
Period *int64 `type:"integer" required:"true"`
Period *int64 `min:"60" type:"integer" required:"true"`
// The time stamp to use for determining the first datapoint to return. The
// value specified is inclusive; results include datapoints with the time stamp
@ -869,7 +869,7 @@ type GetMetricStatisticsInput struct {
// in the Amazon CloudWatch Developer Guide.
//
// Valid Values: Average | Sum | SampleCount | Maximum | Minimum
Statistics []*string `type:"list" required:"true"`
Statistics []*string `min:"1" type:"list" required:"true"`
// The unit for the metric.
Unit *string `type:"string" enum:"StandardUnit"`
@ -921,10 +921,10 @@ type ListMetricsInput struct {
Dimensions []*DimensionFilter `type:"list"`
// The name of the metric to filter against.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace to filter against.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
// The token returned by a previous call to indicate that there is more data
// available.
@ -984,10 +984,10 @@ type Metric struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace of the metric.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
metadataMetric `json:"-" xml:"-"`
}
@ -1013,15 +1013,15 @@ type MetricAlarm struct {
// state.
ActionsEnabled *bool `type:"boolean"`
// The Amazon Resource Name (ARN) of the alarm.
AlarmARN *string `locationName:"AlarmArn" type:"string"`
// The list of actions to execute when this alarm transitions into an ALARM
// state from any other state. Each action is specified as an Amazon Resource
// Number (ARN). Currently the only actions supported are publishing to an Amazon
// SNS topic and triggering an Auto Scaling policy.
AlarmActions []*string `type:"list"`
// The Amazon Resource Name (ARN) of the alarm.
AlarmArn *string `min:"1" type:"string"`
// The time stamp of the last update to the alarm configuration. Amazon CloudWatch
// uses Coordinated Universal Time (UTC) when returning time stamps, which do
// not accommodate seasonal adjustments such as daylight savings time. For more
@ -1033,7 +1033,7 @@ type MetricAlarm struct {
AlarmDescription *string `type:"string"`
// The name of the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// The arithmetic operation to use when comparing the specified Statistic and
// Threshold. The specified Statistic value is used as the first operand.
@ -1043,7 +1043,7 @@ type MetricAlarm struct {
Dimensions []*Dimension `type:"list"`
// The number of periods over which data is compared to the specified threshold.
EvaluationPeriods *int64 `type:"integer"`
EvaluationPeriods *int64 `min:"1" type:"integer"`
// The list of actions to execute when this alarm transitions into an INSUFFICIENT_DATA
// state from any other state. Each action is specified as an Amazon Resource
@ -1054,10 +1054,10 @@ type MetricAlarm struct {
InsufficientDataActions []*string `type:"list"`
// The name of the alarm's metric.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace of alarm's associated metric.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
// The list of actions to execute when this alarm transitions into an OK state
// from any other state. Each action is specified as an Amazon Resource Number
@ -1066,7 +1066,7 @@ type MetricAlarm struct {
OKActions []*string `type:"list"`
// The period in seconds over which the statistic is applied.
Period *int64 `type:"integer"`
Period *int64 `min:"60" type:"integer"`
// A human-readable explanation for the alarm's state.
StateReason *string `type:"string"`
@ -1119,7 +1119,7 @@ type MetricDatum struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// A set of statistical values describing the metric.
StatisticValues *StatisticSet `type:"structure"`
@ -1176,7 +1176,7 @@ type PutMetricAlarmInput struct {
// The descriptive name for the alarm. This name must be unique within the user's
// AWS account
AlarmName *string `type:"string" required:"true"`
AlarmName *string `min:"1" type:"string" required:"true"`
// The arithmetic operation to use when comparing the specified Statistic and
// Threshold. The specified Statistic value is used as the first operand.
@ -1186,7 +1186,7 @@ type PutMetricAlarmInput struct {
Dimensions []*Dimension `type:"list"`
// The number of periods over which data is compared to the specified threshold.
EvaluationPeriods *int64 `type:"integer" required:"true"`
EvaluationPeriods *int64 `min:"1" type:"integer" required:"true"`
// The list of actions to execute when this alarm transitions into an INSUFFICIENT_DATA
// state from any other state. Each action is specified as an Amazon Resource
@ -1195,10 +1195,10 @@ type PutMetricAlarmInput struct {
InsufficientDataActions []*string `type:"list"`
// The name for the alarm's associated metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace for the alarm's associated metric.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The list of actions to execute when this alarm transitions into an OK state
// from any other state. Each action is specified as an Amazon Resource Number
@ -1207,7 +1207,7 @@ type PutMetricAlarmInput struct {
OKActions []*string `type:"list"`
// The period in seconds over which the specified statistic is applied.
Period *int64 `type:"integer" required:"true"`
Period *int64 `min:"60" type:"integer" required:"true"`
// The statistic to apply to the alarm's associated metric.
Statistic *string `type:"string" required:"true" enum:"Statistic"`
@ -1258,7 +1258,7 @@ type PutMetricDataInput struct {
MetricData []*MetricDatum `type:"list" required:"true"`
// The namespace for the metric data.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
metadataPutMetricDataInput `json:"-" xml:"-"`
}
@ -1298,7 +1298,7 @@ func (s PutMetricDataOutput) GoString() string {
type SetAlarmStateInput struct {
// The descriptive name for the alarm. This name must be unique within the user's
// AWS account. The maximum length is 255 characters.
AlarmName *string `type:"string" required:"true"`
AlarmName *string `min:"1" type:"string" required:"true"`
// The reason that this alarm is set to this specific state (in human-readable
// text format)

View File

@ -4,59 +4,59 @@
package cloudwatchiface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
// CloudWatchAPI is the interface type for cloudwatch.CloudWatch.
type CloudWatchAPI interface {
DeleteAlarmsRequest(*cloudwatch.DeleteAlarmsInput) (*aws.Request, *cloudwatch.DeleteAlarmsOutput)
DeleteAlarmsRequest(*cloudwatch.DeleteAlarmsInput) (*request.Request, *cloudwatch.DeleteAlarmsOutput)
DeleteAlarms(*cloudwatch.DeleteAlarmsInput) (*cloudwatch.DeleteAlarmsOutput, error)
DescribeAlarmHistoryRequest(*cloudwatch.DescribeAlarmHistoryInput) (*aws.Request, *cloudwatch.DescribeAlarmHistoryOutput)
DescribeAlarmHistoryRequest(*cloudwatch.DescribeAlarmHistoryInput) (*request.Request, *cloudwatch.DescribeAlarmHistoryOutput)
DescribeAlarmHistory(*cloudwatch.DescribeAlarmHistoryInput) (*cloudwatch.DescribeAlarmHistoryOutput, error)
DescribeAlarmHistoryPages(*cloudwatch.DescribeAlarmHistoryInput, func(*cloudwatch.DescribeAlarmHistoryOutput, bool) bool) error
DescribeAlarmsRequest(*cloudwatch.DescribeAlarmsInput) (*aws.Request, *cloudwatch.DescribeAlarmsOutput)
DescribeAlarmsRequest(*cloudwatch.DescribeAlarmsInput) (*request.Request, *cloudwatch.DescribeAlarmsOutput)
DescribeAlarms(*cloudwatch.DescribeAlarmsInput) (*cloudwatch.DescribeAlarmsOutput, error)
DescribeAlarmsPages(*cloudwatch.DescribeAlarmsInput, func(*cloudwatch.DescribeAlarmsOutput, bool) bool) error
DescribeAlarmsForMetricRequest(*cloudwatch.DescribeAlarmsForMetricInput) (*aws.Request, *cloudwatch.DescribeAlarmsForMetricOutput)
DescribeAlarmsForMetricRequest(*cloudwatch.DescribeAlarmsForMetricInput) (*request.Request, *cloudwatch.DescribeAlarmsForMetricOutput)
DescribeAlarmsForMetric(*cloudwatch.DescribeAlarmsForMetricInput) (*cloudwatch.DescribeAlarmsForMetricOutput, error)
DisableAlarmActionsRequest(*cloudwatch.DisableAlarmActionsInput) (*aws.Request, *cloudwatch.DisableAlarmActionsOutput)
DisableAlarmActionsRequest(*cloudwatch.DisableAlarmActionsInput) (*request.Request, *cloudwatch.DisableAlarmActionsOutput)
DisableAlarmActions(*cloudwatch.DisableAlarmActionsInput) (*cloudwatch.DisableAlarmActionsOutput, error)
EnableAlarmActionsRequest(*cloudwatch.EnableAlarmActionsInput) (*aws.Request, *cloudwatch.EnableAlarmActionsOutput)
EnableAlarmActionsRequest(*cloudwatch.EnableAlarmActionsInput) (*request.Request, *cloudwatch.EnableAlarmActionsOutput)
EnableAlarmActions(*cloudwatch.EnableAlarmActionsInput) (*cloudwatch.EnableAlarmActionsOutput, error)
GetMetricStatisticsRequest(*cloudwatch.GetMetricStatisticsInput) (*aws.Request, *cloudwatch.GetMetricStatisticsOutput)
GetMetricStatisticsRequest(*cloudwatch.GetMetricStatisticsInput) (*request.Request, *cloudwatch.GetMetricStatisticsOutput)
GetMetricStatistics(*cloudwatch.GetMetricStatisticsInput) (*cloudwatch.GetMetricStatisticsOutput, error)
ListMetricsRequest(*cloudwatch.ListMetricsInput) (*aws.Request, *cloudwatch.ListMetricsOutput)
ListMetricsRequest(*cloudwatch.ListMetricsInput) (*request.Request, *cloudwatch.ListMetricsOutput)
ListMetrics(*cloudwatch.ListMetricsInput) (*cloudwatch.ListMetricsOutput, error)
ListMetricsPages(*cloudwatch.ListMetricsInput, func(*cloudwatch.ListMetricsOutput, bool) bool) error
PutMetricAlarmRequest(*cloudwatch.PutMetricAlarmInput) (*aws.Request, *cloudwatch.PutMetricAlarmOutput)
PutMetricAlarmRequest(*cloudwatch.PutMetricAlarmInput) (*request.Request, *cloudwatch.PutMetricAlarmOutput)
PutMetricAlarm(*cloudwatch.PutMetricAlarmInput) (*cloudwatch.PutMetricAlarmOutput, error)
PutMetricDataRequest(*cloudwatch.PutMetricDataInput) (*aws.Request, *cloudwatch.PutMetricDataOutput)
PutMetricDataRequest(*cloudwatch.PutMetricDataInput) (*request.Request, *cloudwatch.PutMetricDataOutput)
PutMetricData(*cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error)
SetAlarmStateRequest(*cloudwatch.SetAlarmStateInput) (*aws.Request, *cloudwatch.SetAlarmStateOutput)
SetAlarmStateRequest(*cloudwatch.SetAlarmStateInput) (*request.Request, *cloudwatch.SetAlarmStateOutput)
SetAlarmState(*cloudwatch.SetAlarmStateInput) (*cloudwatch.SetAlarmStateOutput, error)
}

View File

@ -8,8 +8,6 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
@ -28,22 +26,14 @@ func ExampleCloudWatch_DeleteAlarms() {
resp, err := svc.DeleteAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarmHistory() {
@ -60,22 +50,14 @@ func ExampleCloudWatch_DescribeAlarmHistory() {
resp, err := svc.DescribeAlarmHistory(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarms() {
@ -95,22 +77,14 @@ func ExampleCloudWatch_DescribeAlarms() {
resp, err := svc.DescribeAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarmsForMetric() {
@ -133,22 +107,14 @@ func ExampleCloudWatch_DescribeAlarmsForMetric() {
resp, err := svc.DescribeAlarmsForMetric(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DisableAlarmActions() {
@ -163,22 +129,14 @@ func ExampleCloudWatch_DisableAlarmActions() {
resp, err := svc.DisableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_EnableAlarmActions() {
@ -193,22 +151,14 @@ func ExampleCloudWatch_EnableAlarmActions() {
resp, err := svc.EnableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_GetMetricStatistics() {
@ -236,22 +186,14 @@ func ExampleCloudWatch_GetMetricStatistics() {
resp, err := svc.GetMetricStatistics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_ListMetrics() {
@ -272,22 +214,14 @@ func ExampleCloudWatch_ListMetrics() {
resp, err := svc.ListMetrics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_PutMetricAlarm() {
@ -328,22 +262,14 @@ func ExampleCloudWatch_PutMetricAlarm() {
resp, err := svc.PutMetricAlarm(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_PutMetricData() {
@ -377,22 +303,14 @@ func ExampleCloudWatch_PutMetricData() {
resp, err := svc.PutMetricData(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_SetAlarmState() {
@ -407,20 +325,12 @@ func ExampleCloudWatch_SetAlarmState() {
resp, err := svc.SetAlarmState(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}

View File

@ -4,8 +4,12 @@ package cloudwatch
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/private/protocol/query"
"github.com/aws/aws-sdk-go/private/signer/v4"
)
// This is the Amazon CloudWatch API Reference. This guide provides detailed
@ -49,21 +53,23 @@ import (
// AWS Ruby Developer Center (http://aws.amazon.com/ruby/) AWS Windows and .NET
// Developer Center (http://aws.amazon.com/net/)
type CloudWatch struct {
*aws.Service
*service.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
var initService func(*service.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
var initRequest func(*request.Request)
// New returns a new CloudWatch client.
func New(config *aws.Config) *CloudWatch {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "monitoring",
APIVersion: "2010-08-01",
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "monitoring",
APIVersion: "2010-08-01",
},
}
service.Initialize()
@ -84,8 +90,8 @@ func New(config *aws.Config) *CloudWatch {
// newRequest creates a new request for a CloudWatch operation and runs any
// custom request initialization.
func (c *CloudWatch) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
func (c *CloudWatch) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {

File diff suppressed because it is too large Load Diff

View File

@ -3,19 +3,19 @@ package ec2
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
func init() {
initRequest = func(r *aws.Request) {
initRequest = func(r *request.Request) {
if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter
r.Handlers.Build.PushFront(fillPresignedURL)
}
}
}
func fillPresignedURL(r *aws.Request) {
func fillPresignedURL(r *request.Request) {
if !r.ParamsFilled() {
return
}
@ -23,7 +23,7 @@ func fillPresignedURL(r *aws.Request) {
params := r.Params.(*CopySnapshotInput)
// Stop if PresignedURL/DestinationRegion is set
if params.PresignedURL != nil || params.DestinationRegion != nil {
if params.PresignedUrl != nil || params.DestinationRegion != nil {
return
}
@ -53,5 +53,5 @@ func fillPresignedURL(r *aws.Request) {
}
// We have our URL, set it on params
params.PresignedURL = &url
params.PresignedUrl = &url
}

View File

@ -6,7 +6,7 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/stretchr/testify/assert"
)
@ -24,7 +24,7 @@ func TestCopySnapshotPresignedURL(t *testing.T) {
req, _ := svc.CopySnapshotRequest(&ec2.CopySnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceSnapshotID: aws.String("snap-id"),
SourceSnapshotId: aws.String("snap-id"),
})
req.Sign()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -4,8 +4,12 @@ package ec2
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/ec2query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/service"
"github.com/aws/aws-sdk-go/aws/service/serviceinfo"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/private/signer/v4"
)
// Amazon Elastic Compute Cloud (Amazon EC2) provides resizable computing capacity
@ -13,21 +17,23 @@ import (
// need to invest in hardware up front, so you can develop and deploy applications
// faster.
type EC2 struct {
*aws.Service
*service.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
var initService func(*service.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
var initRequest func(*request.Request)
// New returns a new EC2 client.
func New(config *aws.Config) *EC2 {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "ec2",
APIVersion: "2015-04-15",
service := &service.Service{
ServiceInfo: serviceinfo.ServiceInfo{
Config: defaults.DefaultConfig.Merge(config),
ServiceName: "ec2",
APIVersion: "2015-10-01",
},
}
service.Initialize()
@ -48,8 +54,8 @@ func New(config *aws.Config) *EC2 {
// newRequest creates a new request for a EC2 operation and runs any
// custom request initialization.
func (c *EC2) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
func (c *EC2) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {

View File

@ -328,9 +328,9 @@ func build(pkg string, tags []string) {
func ldflags() string {
var b bytes.Buffer
b.WriteString("-w")
b.WriteString(fmt.Sprintf(" -X main.version '%s'", version))
b.WriteString(fmt.Sprintf(" -X main.commit '%s'", getGitSha()))
b.WriteString(fmt.Sprintf(" -X main.buildstamp %d", buildStamp()))
b.WriteString(fmt.Sprintf(" -X main.version=%s", version))
b.WriteString(fmt.Sprintf(" -X main.commit=%s", getGitSha()))
b.WriteString(fmt.Sprintf(" -X main.buildstamp=%d", buildStamp()))
return b.String()
}

View File

@ -112,7 +112,7 @@ func handleDescribeInstances(req *cwRequest, c *middleware.Context) {
params.Filters = reqParam.Parameters.Filters
}
if len(reqParam.Parameters.InstanceIds) > 0 {
params.InstanceIDs = reqParam.Parameters.InstanceIds
params.InstanceIds = reqParam.Parameters.InstanceIds
}
resp, err := svc.DescribeInstances(params)