merge with master, conflicts fixed

This commit is contained in:
Nick Christus
2015-12-03 20:34:05 -06:00
1080 changed files with 40208 additions and 13999 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
node_modules
npm-debug.log
coverage/
.aws-config.json
awsconfig

View File

@@ -1,4 +1,24 @@
# 2.5 (unreleased)
# 2.6.0 (unreleased)
### New Table Panel
* **table**: New powerful and flexible table panel, closes [#215](https://github.com/grafana/grafana/issues/215)
### Enhancements
* **CloudWatch**: Support for multiple AWS Credentials, closes [#3053](https://github.com/grafana/grafana/issues/3053), [#3080](https://github.com/grafana/grafana/issues/3080)
* **Elasticsearch**: Support for dynamic daily indices for annotations, closes [#3061](https://github.com/grafana/grafana/issues/3061)
* **Graph Panel**: Option to hide series with all zeroes from legend and tooltip, closes [#1381](https://github.com/grafana/grafana/issues/1381), [#3336](https://github.com/grafana/grafana/issues/3336)
### Bug Fixes
* **cloudwatch**: fix for handling of period for long time ranges, fixes [#3086](https://github.com/grafana/grafana/issues/3086)
* **dashboard**: fix for collapse row by clicking on row title, fixes [#3065](https://github.com/grafana/grafana/issues/3065)
* **influxdb**: fix for relative time ranges `last x months` and `last x years`, fixes [#3067](https://github.com/grafana/grafana/issues/3067)
* **graph**: layout fix for color picker when right side legend was enabled, fixes [#3093](https://github.com/grafana/grafana/issues/3093)
* **elasticsearch**: disabling elastic query (via eye) caused error, fixes [#3300](https://github.com/grafana/grafana/issues/3300)
### Breaking changes
* **elasticsearch**: Manual json edited queries are not supported any more (They very barely worked in 2.5)
# 2.5 (2015-10-28)
**New Feature: Mix data sources**
- A built in data source is now available named `-- Mixed --`, When picked in the metrics tab,
@@ -11,12 +31,12 @@ it allows you to add queries of differnet data source types & instances to the s
**New Feature: New and much improved time picker**
- Support for quick ranges like `Today`, `This day last week`, `This week`, `The day so far`, etc.
- Muck improved UI and improved support for UTC, [Issue #2761](https://github.com/grafana/grafana/issues/2761) for more info.
- Improved UI and improved support for UTC, [Issue #2761](https://github.com/grafana/grafana/issues/2761) for more info.
**User Onboarding**
- Org admin can now send email invites (or invite links) to people who are not yet Grafana users
- Sign up flow now supports email verification (if enabled)
- See [Issue #2353](https://github.com/grafana/grafana/issues/2354) for more info.
- See [Issue #2353](https://github.com/grafana/grafana/issues/2353) for more info.
**Other new Features && Enhancements**
- [Pull #2720](https://github.com/grafana/grafana/pull/2720). Admin: Initial basic quota support (per Org)
@@ -28,6 +48,8 @@ it allows you to add queries of differnet data source types & instances to the s
- [Issue #2708](https://github.com/grafana/grafana/issues/2708). InfluxDB: You can now set math expression for select clauses.
- [Issue #1575](https://github.com/grafana/grafana/issues/1575). Drilldown link: now you can click on the external link icon in the panel header to access drilldown links!
- [Issue #1646](https://github.com/grafana/grafana/issues/1646). OpenTSDB: Fetch list of aggregators from OpenTSDB
- [Issue #2955](https://github.com/grafana/grafana/issues/2955). Graph: More axis units (Length, Volume, Temperature, Pressure, etc), thanks @greglook
- [Issue #2928](https://github.com/grafana/grafana/issues/2928). LDAP: Support for searching for groups memberships, i.e. POSIX (no memberOf) schemas, also multiple ldap servers, and root ca cert, thanks @abligh
**Fixes**
- [Issue #2413](https://github.com/grafana/grafana/issues/2413). InfluxDB 0.9: Fix for handling empty series object in response from influxdb

69
Godeps/Godeps.json generated
View File

@@ -1,6 +1,6 @@
{
"ImportPath": "github.com/grafana/grafana",
"GoVersion": "go1.4.2",
"GoVersion": "go1.5",
"Packages": [
"./pkg/..."
],
@@ -20,53 +20,63 @@
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/endpoints",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/ec2query",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/waiter",
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/cloudwatch",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/ec2",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
"Comment": "v0.10.4-18-gce51895",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7"
},
{
"ImportPath": "github.com/davecgh/go-spew/spew",
"Rev": "2df174808ee097f90d259e432cc04442cf60be21"
},
{
"ImportPath": "github.com/go-ini/ini",
"Comment": "v0-48-g060d7da",
"Rev": "060d7da055ba6ec5ea7a31f116332fe5efa04ce0"
},
{
"ImportPath": "github.com/go-ldap/ldap",
"Comment": "v1-19-g83e6542",
@@ -90,6 +100,11 @@
"ImportPath": "github.com/gosimple/slug",
"Rev": "8d258463b4459f161f51d6a357edacd3eef9d663"
},
{
"ImportPath": "github.com/jmespath/go-jmespath",
"Comment": "0.2.2",
"Rev": "3433f3ea46d9f8019119e7dd41274e112a2359a9"
},
{
"ImportPath": "github.com/jtolds/gls",
"Rev": "f1ac7f4f24f50328e6bc838ca4437d1612a0243c"
@@ -124,10 +139,6 @@
"ImportPath": "github.com/streadway/amqp",
"Rev": "150b7f24d6ad507e6026c13d85ce1f1391ac7400"
},
{
"ImportPath": "github.com/vaughan0/go-ini",
"Rev": "a98ad7ee00ec53921f08832bc06ecf7fd600e6a1"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "972f0c5fbe4ae29e666c3f78c3ed42ae7a448b0a"

View File

@@ -113,7 +113,7 @@ func newRequestError(err Error, statusCode int, requestID string) *requestError
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]",
extra := fmt.Sprintf("status code: %d, request id: %s",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}

View File

@@ -57,16 +57,13 @@ func rcopy(dst, src reflect.Value, root bool) {
}
}
case reflect.Struct:
if !root {
dst.Set(reflect.New(src.Type()).Elem())
}
t := dst.Type()
for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Name
srcval := src.FieldByName(name)
if srcval.IsValid() {
rcopy(dst.FieldByName(name), srcval, false)
srcVal := src.FieldByName(name)
dstVal := dst.FieldByName(name)
if srcVal.IsValid() && dstVal.CanSet() {
rcopy(dstVal, srcVal, false)
}
}
case reflect.Slice:

View File

@@ -77,6 +77,28 @@ func TestCopy(t *testing.T) {
assert.NotEqual(t, f2.C, f1.C)
}
func TestCopyNestedWithUnexported(t *testing.T) {
type Bar struct {
a int
B int
}
type Foo struct {
A string
B Bar
}
f1 := &Foo{A: "string", B: Bar{a: 1, B: 2}}
var f2 Foo
awsutil.Copy(&f2, f1)
// Values match
assert.Equal(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.B.a, f1.B.a)
assert.Equal(t, f2.B.B, f2.B.B)
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
@@ -136,6 +158,8 @@ func TestCopyDifferentStructs(t *testing.T) {
C map[string]*int
SrcUnique string
SameNameDiffType int
unexportedPtr *int
ExportedPtr *int
}
type DstFoo struct {
A int
@@ -143,6 +167,8 @@ func TestCopyDifferentStructs(t *testing.T) {
C map[string]*int
DstUnique int
SameNameDiffType string
unexportedPtr *int
ExportedPtr *int
}
// Create the initial value
@@ -159,6 +185,8 @@ func TestCopyDifferentStructs(t *testing.T) {
},
SrcUnique: "unique",
SameNameDiffType: 1,
unexportedPtr: &int1,
ExportedPtr: &int2,
}
// Do the copy
@@ -173,6 +201,10 @@ func TestCopyDifferentStructs(t *testing.T) {
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
assert.Equal(t, int1, *f1.unexportedPtr)
assert.Nil(t, f2.unexportedPtr)
assert.Equal(t, int2, *f1.ExportedPtr)
assert.Equal(t, int2, *f2.ExportedPtr)
}
func ExampleCopyOf() {

View File

@@ -0,0 +1,27 @@
package awsutil
import (
"reflect"
)
// DeepEqual returns if the two values are deeply equal like reflect.DeepEqual.
// In addition to this, this method will also dereference the input values if
// possible so the DeepEqual performed will not fail if one parameter is a
// pointer and the other is not.
//
// DeepEqual will not perform indirection of nested values of the input parameters.
func DeepEqual(a, b interface{}) bool {
ra := reflect.Indirect(reflect.ValueOf(a))
rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type the are equal
// If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid {
// Both values must be valid to be equal
return false
}
return reflect.DeepEqual(ra.Interface(), rb.Interface())
}

View File

@@ -0,0 +1,29 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
cases := []struct {
a, b interface{}
equal bool
}{
{"a", "a", true},
{"a", "b", false},
{"a", aws.String(""), false},
{"a", nil, false},
{"a", aws.String("a"), true},
{(*bool)(nil), (*bool)(nil), true},
{(*bool)(nil), (*string)(nil), false},
{nil, nil, true},
}
for i, c := range cases {
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}

View File

@@ -5,6 +5,8 @@ import (
"regexp"
"strconv"
"strings"
"github.com/jmespath/go-jmespath"
)
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
@@ -16,7 +18,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
if len(pathparts) > 1 {
for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, create, caseSensitive)
if vals != nil && len(vals) > 0 {
if len(vals) > 0 {
return vals
}
}
@@ -142,46 +144,67 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
return values
}
// ValuesAtPath returns a list of objects at the lexical path inside of a structure
func ValuesAtPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false, true); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
// ValuesAtPath returns a list of values at the case insensitive lexical
// path inside of a structure.
func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
result, err := jmespath.Search(path, i)
if err != nil {
return nil, err
}
return nil
v := reflect.ValueOf(result)
if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) {
return nil, nil
}
if s, ok := result.([]interface{}); ok {
return s, err
}
if v.Kind() == reflect.Map && v.Len() == 0 {
return nil, nil
}
if v.Kind() == reflect.Slice {
out := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
out[i] = v.Index(i).Interface()
}
return out, nil
}
return []interface{}{result}, nil
}
// ValuesAtAnyPath returns a list of objects at the case-insensitive lexical
// path inside of a structure
func ValuesAtAnyPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false, false); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
}
return nil
}
// SetValueAtPath sets an object at the lexical path inside of a structure
// SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, true); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
}
}
}
// SetValueAtAnyPath sets an object at the case insensitive lexical path inside
// of a structure
func SetValueAtAnyPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
setValue(rval, v)
}
}
}
func setValue(dstVal reflect.Value, src interface{}) {
if dstVal.Kind() == reflect.Ptr {
dstVal = reflect.Indirect(dstVal)
}
srcVal := reflect.ValueOf(src)
if !srcVal.IsValid() { // src is literal nil
if dstVal.CanAddr() {
// Convert to pointer so that pointer's value can be nil'ed
// dstVal = dstVal.Addr()
}
dstVal.Set(reflect.Zero(dstVal.Type()))
} else if srcVal.Kind() == reflect.Ptr {
if srcVal.IsNil() {
srcVal = reflect.Zero(dstVal.Type())
} else {
srcVal = reflect.ValueOf(src).Elem()
}
dstVal.Set(srcVal)
} else {
dstVal.Set(srcVal)
}
}

View File

@@ -13,6 +13,7 @@ type Struct struct {
B *Struct
D *Struct
C string
E map[string]string
}
var data = Struct{
@@ -21,30 +22,69 @@ var data = Struct{
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
var data2 = Struct{A: []Struct{
{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}}
func TestValueAtPathSuccess(t *testing.T) {
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "C"))
assert.Equal(t, []interface{}{"value1"}, awsutil.ValuesAtPath(data, "A[0].C"))
assert.Equal(t, []interface{}{"value2"}, awsutil.ValuesAtPath(data, "A[1].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[2].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtAnyPath(data, "a[2].c"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[-1].C"))
assert.Equal(t, []interface{}{"value1", "value2", "value3"}, awsutil.ValuesAtPath(data, "A[].C"))
assert.Equal(t, []interface{}{"terminal"}, awsutil.ValuesAtPath(data, "B . B . C"))
assert.Equal(t, []interface{}{"terminal", "terminal2"}, awsutil.ValuesAtPath(data, "B.*.C"))
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "A.D.X || C"))
var testCases = []struct {
expect []interface{}
data interface{}
path string
}{
{[]interface{}{"initial"}, data, "C"},
{[]interface{}{"value1"}, data, "A[0].C"},
{[]interface{}{"value2"}, data, "A[1].C"},
{[]interface{}{"value3"}, data, "A[2].C"},
{[]interface{}{"value3"}, data, "a[2].c"},
{[]interface{}{"value3"}, data, "A[-1].C"},
{[]interface{}{"value1", "value2", "value3"}, data, "A[].C"},
{[]interface{}{"terminal"}, data, "B . B . C"},
{[]interface{}{"initial"}, data, "A.D.X || C"},
{[]interface{}{"initial"}, data, "A[0].B || C"},
{[]interface{}{
Struct{A: []Struct{{C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}, {C: "1"}}},
Struct{A: []Struct{{C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}, {C: "2"}}},
}, data2, "A"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
func TestValueAtPathFailure(t *testing.T) {
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "C.x"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, ".x"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "X.Y.Z"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[100].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[3].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "B.B.C.Z"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "z[-1].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(nil, "A.B.C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(Struct{}, "A"))
var testCases = []struct {
expect []interface{}
errContains string
data interface{}
path string
}{
{nil, "", data, "C.x"},
{nil, "SyntaxError: Invalid token: tDot", data, ".x"},
{nil, "", data, "X.Y.Z"},
{nil, "", data, "A[100].C"},
{nil, "", data, "A[3].C"},
{nil, "", data, "B.B.C.Z"},
{nil, "", data, "z[-1].C"},
{nil, "", nil, "A.B.C"},
{[]interface{}{}, "", Struct{}, "A"},
{nil, "", data, "A[0].B.C"},
{nil, "", data, "D"},
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
continue
} else {
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
}
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
func TestSetValueAtPathSuccess(t *testing.T) {
@@ -61,8 +101,8 @@ func TestSetValueAtPathSuccess(t *testing.T) {
assert.Equal(t, "test0", s.B.D.C)
var s2 Struct
awsutil.SetValueAtAnyPath(&s2, "b.b.c", "test0")
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
awsutil.SetValueAtAnyPath(&s2, "A", []Struct{{}})
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
}

View File

@@ -0,0 +1,89 @@
package awsutil
import (
"bytes"
"fmt"
"reflect"
"strings"
)
// StringValue returns the string representation of a value.
func StringValue(i interface{}) string {
var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
if i < len(names)-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
stringValue(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
stringValue(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

View File

@@ -0,0 +1,111 @@
package client
import (
"fmt"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
// A Config provides configuration to a service client instance.
type Config struct {
Config *aws.Config
Handlers request.Handlers
Endpoint, SigningRegion string
}
// ConfigProvider provides a generic way for a service client to receive
// the ClientConfig without circular dependencies.
type ConfigProvider interface {
ClientConfig(serviceName string, cfgs ...*aws.Config) Config
}
// A Client implements the base client request and response handling
// used by all service clients.
type Client struct {
request.Retryer
metadata.ClientInfo
Config aws.Config
Handlers request.Handlers
}
// New will return a pointer to a new initialized service client.
func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, options ...func(*Client)) *Client {
svc := &Client{
Config: cfg,
ClientInfo: info,
Handlers: handlers,
}
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
svc.AddDebugHandlers()
for _, option := range options {
option(svc)
}
return svc
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
func (c *Client) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
return request.New(c.Config, c.ClientInfo, c.Handlers, c.Retryer, operation, params, data)
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (c *Client) AddDebugHandlers() {
if !c.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
c.Handlers.Send.PushFront(logRequest)
c.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ClientInfo.ServiceName, r.Operation.Name, msg))
}

View File

@@ -0,0 +1,45 @@
package client
import (
"math"
"math/rand"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
//
// type retryer struct {
// service.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() uint { return 100 }
type DefaultRetryer struct {
NumMaxRetries int
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (rand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// ShouldRetry returns if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
return r.IsErrorRetryable()
}

View File

@@ -0,0 +1,12 @@
package metadata
// ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct {
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
}

View File

@@ -2,48 +2,22 @@ package aws
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(DefaultRetries).
WithLogger(NewDefaultLogger()).
WithLogLevel(LogOff)
// UseServiceDefaultRetries instructs the config to use the service's own default
// number of retries. This will be the default action if Config.MaxRetries
// is nil also.
const UseServiceDefaultRetries = -1
// A Config provides service configuration for service clients. By default,
// all clients will use the {DefaultConfig} structure.
// all clients will use the {defaults.DefaultConfig} structure.
type Config struct {
// The credentials object to use when signing requests. Defaults to
// {DefaultChainCredentials}.
// a chain of credential providers to search for credentials in environment
// variables, shared credential file, and EC2 Instance Roles.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
@@ -102,6 +76,8 @@ type Config struct {
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
SleepDelay func(time.Duration)
}
// NewConfig returns a new Config pointer that can be chained with builder methods to
@@ -190,15 +166,24 @@ func (c *Config) WithS3ForcePathStyle(force bool) *Config {
return c
}
// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
func (c Config) Merge(other *Config) *Config {
if other == nil {
return &c
}
// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
c.SleepDelay = fn
return c
}
dst := c
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
mergeInConfig(c, other)
}
}
func mergeInConfig(dst *Config, other *Config) {
if other == nil {
return
}
if other.Credentials != nil {
dst.Credentials = other.Credentials
@@ -244,11 +229,20 @@ func (c Config) Merge(other *Config) *Config {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
return &dst
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() *Config {
dst := c
return &dst
// Copy will return a shallow copy of the Config object. If any additional
// configurations are provided they will be merged into the new config returned.
func (c *Config) Copy(cfgs ...*Config) *Config {
dst := &Config{}
dst.MergeIn(c)
for _, cfg := range cfgs {
dst.MergeIn(cfg)
}
return dst
}

View File

@@ -4,18 +4,11 @@ import (
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var testCredentials = credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
var copyTestConfig = Config{
Credentials: testCredentials,
@@ -25,7 +18,7 @@ var copyTestConfig = Config{
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(DefaultRetries),
MaxRetries: Int(3),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
@@ -38,6 +31,11 @@ func TestCopy(t *testing.T) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
got.Region = String("other")
if got.Region == want.Region {
t.Errorf("Expect setting copy values not not reflect in source")
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
@@ -76,7 +74,8 @@ var mergeTests = []struct {
func TestMerge(t *testing.T) {
for i, tt := range mergeTests {
got := tt.cfg.Merge(tt.in)
got := tt.cfg.Copy()
got.MergeIn(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %d %+v", i, tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)

View File

@@ -1,10 +1,9 @@
package aws_test
package aws
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
@@ -18,20 +17,20 @@ func TestStringSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.StringSlice(in)
out := StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueSlice(out)
out2 := StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesStringValueSlice = [][]*string{
{aws.String("a"), aws.String("b"), nil, aws.String("c")},
{String("a"), String("b"), nil, String("c")},
}
func TestStringValueSlice(t *testing.T) {
@@ -39,7 +38,7 @@ func TestStringValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.StringValueSlice(in)
out := StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -49,7 +48,7 @@ func TestStringValueSlice(t *testing.T) {
}
}
out2 := aws.StringSlice(out)
out2 := StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -70,13 +69,13 @@ func TestStringMap(t *testing.T) {
if in == nil {
continue
}
out := aws.StringMap(in)
out := StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueMap(out)
out2 := StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -91,13 +90,13 @@ func TestBoolSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolSlice(in)
out := BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueSlice(out)
out2 := BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -110,7 +109,7 @@ func TestBoolValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolValueSlice(in)
out := BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -120,7 +119,7 @@ func TestBoolValueSlice(t *testing.T) {
}
}
out2 := aws.BoolSlice(out)
out2 := BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -141,13 +140,13 @@ func TestBoolMap(t *testing.T) {
if in == nil {
continue
}
out := aws.BoolMap(in)
out := BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueMap(out)
out2 := BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -162,13 +161,13 @@ func TestIntSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.IntSlice(in)
out := IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueSlice(out)
out2 := IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -181,7 +180,7 @@ func TestIntValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.IntValueSlice(in)
out := IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -191,7 +190,7 @@ func TestIntValueSlice(t *testing.T) {
}
}
out2 := aws.IntSlice(out)
out2 := IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -212,13 +211,13 @@ func TestIntMap(t *testing.T) {
if in == nil {
continue
}
out := aws.IntMap(in)
out := IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueMap(out)
out2 := IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -233,13 +232,13 @@ func TestInt64Slice(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64Slice(in)
out := Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueSlice(out)
out2 := Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -252,7 +251,7 @@ func TestInt64ValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64ValueSlice(in)
out := Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -262,7 +261,7 @@ func TestInt64ValueSlice(t *testing.T) {
}
}
out2 := aws.Int64Slice(out)
out2 := Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -283,13 +282,13 @@ func TestInt64Map(t *testing.T) {
if in == nil {
continue
}
out := aws.Int64Map(in)
out := Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueMap(out)
out2 := Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -304,13 +303,13 @@ func TestFloat64Slice(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64Slice(in)
out := Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueSlice(out)
out2 := Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -323,7 +322,7 @@ func TestFloat64ValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64ValueSlice(in)
out := Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -333,7 +332,7 @@ func TestFloat64ValueSlice(t *testing.T) {
}
}
out2 := aws.Float64Slice(out)
out2 := Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -354,13 +353,13 @@ func TestFloat64Map(t *testing.T) {
if in == nil {
continue
}
out := aws.Float64Map(in)
out := Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueMap(out)
out2 := Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -375,13 +374,13 @@ func TestTimeSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeSlice(in)
out := TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueSlice(out)
out2 := TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
@@ -394,7 +393,7 @@ func TestTimeValueSlice(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeValueSlice(in)
out := TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
@@ -404,7 +403,7 @@ func TestTimeValueSlice(t *testing.T) {
}
}
out2 := aws.TimeSlice(out)
out2 := TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
@@ -425,13 +424,13 @@ func TestTimeMap(t *testing.T) {
if in == nil {
continue
}
out := aws.TimeMap(in)
out := TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueMap(out)
out2 := TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}

View File

@@ -1,4 +1,4 @@
package aws
package corehandlers
import (
"bytes"
@@ -8,25 +8,23 @@ import (
"net/http"
"net/url"
"regexp"
"runtime"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
}
// BuildContentLength builds the content length of a request based on the body,
// BuildContentLengthHandler builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
@@ -40,36 +38,38 @@ func BuildContentLength(r *Request) {
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}}
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil {
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
@@ -79,7 +79,7 @@ func SendHandler(r *Request) {
return
}
}
if r.HTTPRequest == nil {
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
@@ -90,68 +90,50 @@ func SendHandler(r *Request) {
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = Bool(true) // network errors are retryable
r.Retryable = aws.Bool(true) // network errors are retryable
}
}
}}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
}}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
r.Retryable = Bool(r.Service.ShouldRetry(r))
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
r.RetryDelay = r.RetryRules(r)
r.Config.SleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
if r.IsErrorExpired() {
r.Config.Credentials.Expire()
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
}}
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && StringValue(r.Service.Config.Region) == "" {
r.Error = ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
var ValidateEndpointHandler = request.NamedHandler{Name: "core.ValidateEndpointHandler", Fn: func(r *request.Request) {
if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" {
r.Error = aws.ErrMissingRegion
} else if r.ClientInfo.Endpoint == "" {
r.Error = aws.ErrMissingEndpoint
}
}
}}

View File

@@ -0,0 +1,113 @@
package corehandlers_test
import (
"fmt"
"net/http"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := awstesting.NewClient()
svc.Handlers.Clear()
svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, aws.ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewCredentials(credProvider),
MaxRetries: aws.Int(1),
})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}
type testSendHandlerTransport struct{}
func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("mock error")
}
func TestSendHandlerError(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
HTTPClient: &http.Client{
Transport: &testSendHandlerTransport{},
},
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
r.Send()
assert.Error(t, r.Error)
assert.NotNil(t, r.HTTPResponse)
}

View File

@@ -0,0 +1,144 @@
package corehandlers
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// ValidateParametersHandler is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
var ValidateParametersHandler = request.NamedHandler{Name: "core.ValidateParametersHandler", Fn: func(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
err := validateField(f, fvalue, validateFieldRequired, validateFieldMin)
if err != nil {
v.errors = append(v.errors, fmt.Sprintf("%s: %s", err.Error(), path+prefix+f.Name))
continue
}
v.validateAny(fvalue, path+prefix+f.Name)
}
}
type validatorFunc func(f reflect.StructField, fvalue reflect.Value) error
func validateField(f reflect.StructField, fvalue reflect.Value, funcs ...validatorFunc) error {
for _, fn := range funcs {
if err := fn(f, fvalue); err != nil {
return err
}
}
return nil
}
// Validates that a field has a valid value provided for required fields.
func validateFieldRequired(f reflect.StructField, fvalue reflect.Value) error {
if f.Tag.Get("required") == "" {
return nil
}
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return fmt.Errorf("missing required parameter")
}
default:
if !fvalue.IsValid() {
return fmt.Errorf("missing required parameter")
}
}
return nil
}
// Validates that if a value is provided for a field, that value must be at
// least a minimum length.
func validateFieldMin(f reflect.StructField, fvalue reflect.Value) error {
minStr := f.Tag.Get("min")
if minStr == "" {
return nil
}
min, _ := strconv.ParseInt(minStr, 10, 64)
kind := fvalue.Kind()
if kind == reflect.Ptr {
if fvalue.IsNil() {
return nil
}
fvalue = fvalue.Elem()
}
switch fvalue.Kind() {
case reflect.String:
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
case reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return nil
}
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
// TODO min can also apply to number minimum value.
}
return nil
}

View File

@@ -0,0 +1,134 @@
package corehandlers_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/stretchr/testify/require"
)
var testSvc = func() *client.Client {
s := &client.Client{
Config: aws.Config{},
ClientInfo: metadata.ClientInfo{
ServiceName: "mock-service",
APIVersion: "2015-01-01",
},
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}
type testInput struct {
StringField string `min:"5"`
PtrStrField *string `min:"2"`
ListField []string `min:"3"`
MapField map[string]string `min:"4"`
}
var testsFieldMin = []struct {
err awserr.Error
in testInput
}{
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 5: StringField", nil),
in: testInput{StringField: "abcd"},
},
{
err: awserr.New("InvalidParameter", "2 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}},
},
{
err: awserr.New("InvalidParameter", "3 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField\n- field too short, minimum length 4: MapField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 2: PtrStrField", nil),
in: testInput{StringField: "abcde", PtrStrField: aws.String("v")},
},
{
err: nil,
in: testInput{StringField: "abcde", PtrStrField: aws.String("value"),
ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
},
}
func TestValidateFieldMinParameter(t *testing.T) {
for i, c := range testsFieldMin {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Equal(t, c.err, req.Error, "%d case failed", i)
}
}

View File

@@ -53,8 +53,8 @@ import (
"time"
)
// Create an empty Credential object that can be used as dummy placeholder
// credentials for requests that do not need signed.
// AnonymousCredentials is an empty Credential object that can be used as
// dummy placeholder credentials for requests that do not need signed.
//
// This Credentials can be used to configure a service to not sign requests
// when making service API calls. For example, when accessing public

View File

@@ -1,162 +0,0 @@
package credentials
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
// Timeout: 10 * time.Second,
// },
// // Use default EC2 Role metadata endpoint, Alternate endpoints can be
// // specified setting Endpoint to something else.
// Endpoint: "",
// // Do not use early expiry of credentials. If a non zero value is
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
type EC2RoleProvider struct {
Expiry
// Endpoint must be fully quantified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
// things such as timeout.
//
// Endpoint is the URL that the EC2RoleProvider will connect to when retrieving
// role and credentials.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
Client: client,
ExpiryWindow: window,
})
}
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
}
credsList, err := requestCredList(m.Client, m.Endpoint)
if err != nil {
return Value{}, err
}
if len(credsList) == 0 {
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
if err != nil {
return Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
}, nil
}
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
}
return credsList, nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
if err != nil {
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
return respCreds, nil
}

View File

@@ -1,108 +0,0 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,173 @@
package ec2rolecreds
import (
"bufio"
"encoding/json"
"fmt"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
)
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &ec2rolecreds.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
// Timeout: 10 * time.Second,
// },
// // Do not use early expiry of credentials. If a non zero value is
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
type EC2RoleProvider struct {
credentials.Expiry
// Required EC2Metadata client to use when connecting to EC2 metadata service.
Client *ec2metadata.EC2Metadata
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping
// the EC2RoleProvider. Takes a ConfigProvider to create a EC2Metadata client.
// The ConfigProvider is satisfied by the session.Session type.
func NewCredentials(c client.ConfigProvider, options ...func(*EC2RoleProvider)) *credentials.Credentials {
p := &EC2RoleProvider{
Client: ec2metadata.New(c),
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping
// the EC2RoleProvider. Takes a EC2Metadata client to use when connecting to EC2
// metadata service.
func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*EC2RoleProvider)) *credentials.Credentials {
p := &EC2RoleProvider{
Client: client,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client)
if err != nil {
return credentials.Value{}, err
}
if len(credsList) == 0 {
return credentials.Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, credsName)
if err != nil {
return credentials.Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return credentials.Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
}, nil
}
// A ec2RoleCredRespBody provides the shape for unmarshalling credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
// Error state
Code string
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("EC2RoleRequestError", "failed to list EC2 Roles", err)
}
credsList := []string{}
s := bufio.NewScanner(strings.NewReader(resp))
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read list of EC2 Roles", err)
}
return credsList, nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
respCreds := ec2RoleCredRespBody{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
if respCreds.Code != "Success" {
// If an error code was returned something failed requesting the role.
return ec2RoleCredRespBody{}, awserr.New(respCreds.Code, respCreds.Message, nil)
}
return respCreds, nil
}

View File

@@ -0,0 +1,159 @@
package ec2rolecreds_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
)
const credsRespTmpl = `{
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s",
"LastUpdated" : "2009-11-23T0:00:00Z"
}`
const credsFailRespTmpl = `{
"Code": "ErrorCode",
"Message": "ErrorMsg",
"LastUpdated": "2009-11-23T0:00:00Z"
}`
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
fmt.Fprintf(w, credsFailRespTmpl)
} else {
fmt.Fprintf(w, credsRespTmpl, expireOn)
}
} else {
http.Error(w, "bad request", http.StatusBadRequest)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", true)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
ExpiryWindow: time.Hour * 1,
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC3RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z", false)
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
}
}

View File

@@ -6,3 +6,7 @@ aws_session_token = token
[no_token]
aws_access_key_id = accessKey
aws_secret_access_key = secret
[with_colon]
aws_access_key_id: accessKey
aws_secret_access_key: secret

View File

@@ -5,7 +5,7 @@ import (
"os"
"path/filepath"
"github.com/vaughan0/go-ini"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr"
)
@@ -22,8 +22,12 @@ var (
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {
// Path to the shared credentials file. If empty will default to current user's
// home directory.
// Path to the shared credentials file.
//
// If empty will look for "AWS_SHARED_CREDENTIALS_FILE" env variable. If the
// env value is empty will default to current user's home directory.
// Linux/OSX: "$HOME/.aws/credentials"
// Windows: "%USERPROFILE%\.aws\credentials"
Filename string
// AWS Profile to extract credentials from the shared credentials file. If empty
@@ -72,32 +76,36 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.LoadFile(filename)
config, err := ini.Load(filename)
if err != nil {
return Value{}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile := config.Section(profile)
id, ok := iniProfile["aws_access_key_id"]
if !ok {
return Value{}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
nil)
iniProfile, err := config.GetSection(profile)
if err != nil {
return Value{}, awserr.New("SharedCredsLoad", "failed to get profile", err)
}
secret, ok := iniProfile["aws_secret_access_key"]
if !ok {
id, err := iniProfile.GetKey("aws_access_key_id")
if err != nil {
return Value{}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err)
}
secret, err := iniProfile.GetKey("aws_secret_access_key")
if err != nil {
return Value{}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
token := iniProfile["aws_session_token"]
// Default to empty string if not found
token := iniProfile.Key("aws_session_token")
return Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
AccessKeyID: id.String(),
SecretAccessKey: secret.String(),
SessionToken: token.String(),
}, nil
}
@@ -106,6 +114,10 @@ func loadProfile(filename, profile string) (Value, error) {
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" {
return p.Filename, nil
}
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")

View File

@@ -31,6 +31,19 @@ func TestSharedCredentialsProviderIsExpired(t *testing.T) {
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "example.ini")
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
@@ -56,6 +69,18 @@ func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "with_colon"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
@@ -66,12 +91,10 @@ func BenchmarkSharedCredentialsProvider(b *testing.B) {
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
})
}
}

View File

@@ -6,10 +6,12 @@ package stscreds
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
@@ -17,31 +19,17 @@ type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
//
// Example how to configure a service to use this provider:
//
// config := &aws.Config{
// Credentials: stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second),
// })
// // Use config for creating your AWS service.
//
// Example how to obtain customised credentials:
//
// provider := &stscreds.Provider{
// // Extend the duration to 1 hour.
// Duration: time.Hour,
// // Custom role name.
// RoleSessionName: "custom-session-name",
// }
// creds := credentials.NewCredentials(provider)
//
type AssumeRoleProvider struct {
credentials.Expiry
// Custom STS client. If not set the default STS client will be used.
// STS client to make assume role request with.
Client AssumeRoler
// Role to be assumed.
@@ -53,6 +41,9 @@ type AssumeRoleProvider struct {
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// Optional ExternalID to pass along, defaults to nil if not set.
ExternalID *string
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
@@ -66,43 +57,62 @@ type AssumeRoleProvider struct {
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// The sts and roleARN parameters are used for building the "AssumeRole" call.
// Pass nil as sts to use the default client.
// Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type.
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: sts.New(c),
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewCredentials(client AssumeRoler, roleARN string, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&AssumeRoleProvider{
Client: client,
RoleARN: roleARN,
ExpiryWindow: window,
})
// Takes an AssumeRoler which can be satisfiede by the STS client.
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: svc,
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.Client == nil {
p.Client = sts.New(nil)
}
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = 15 * time.Minute
p.Duration = DefaultDuration
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
})
if err != nil {
@@ -113,7 +123,7 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
AccessKeyID: *roleOutput.Credentials.AccessKeyId,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil

View File

@@ -1,11 +1,12 @@
package stscreds
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
@@ -16,7 +17,7 @@ func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput,
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
AccessKeyId: input.RoleArn,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
@@ -47,12 +48,9 @@ func BenchmarkAssumeRoleProvider(b *testing.B) {
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
if _, err := p.Retrieve(); err != nil {
b.Fatal(err)
}
})
}
}

View File

@@ -0,0 +1,76 @@
// Package defaults is a collection of helpers to retrieve the SDK's default
// configuration and handlers.
package defaults
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
// A Defaults provides a collection of default values for SDK clients.
type Defaults struct {
Config *aws.Config
Handlers request.Handlers
}
// Get returns the SDK's default values with Config and handlers pre-configured.
func Get() Defaults {
cfg := Config()
handlers := Handlers()
cfg.Credentials = CredChain(cfg, handlers)
return Defaults{
Config: cfg,
Handlers: handlers,
}
}
// Config returns the default configuration.
func Config() *aws.Config {
return aws.NewConfig().
WithCredentials(credentials.AnonymousCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(aws.UseServiceDefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)
}
// Handlers returns the default request handlers.
func Handlers() request.Handlers {
var handlers request.Handlers
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.SendHandler)
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)
return handlers
}
// CredChain returns the default credential chain.
func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credentials {
endpoint, signingRegion := endpoints.EndpointForRegion(ec2metadata.ServiceName, *cfg.Region, true)
return credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.NewClient(*cfg, handlers, endpoint, signingRegion),
ExpiryWindow: 5 * time.Minute,
},
})
}

View File

@@ -0,0 +1,43 @@
package ec2metadata
import (
"path"
"github.com/aws/aws-sdk-go/aws/request"
)
// GetMetadata uses the path provided to request
func (c *EC2Metadata) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
return output.Content, req.Send()
}
// Region returns the region the instance is running in.
func (c *EC2Metadata) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
if err != nil {
return "", err
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}
// Available returns if the application has access to the EC2 Metadata service.
// Can be used to determine if application is running within an EC2 Instance and
// the metadata service is available.
func (c *EC2Metadata) Available() bool {
if _, err := c.GetMetadata("instance-id"); err != nil {
return false
}
return true
}

View File

@@ -0,0 +1,101 @@
package ec2metadata_test
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
)
func initTestServer(path string, resp string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != path {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.Write([]byte(resp))
}))
}
func TestEndpoint(t *testing.T) {
c := ec2metadata.New(session.New())
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", "testpath"),
}
req := c.NewRequest(op, nil, nil)
assert.Equal(t, "http://169.254.169.254/latest", req.ClientInfo.Endpoint)
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
}
func TestGetMetadata(t *testing.T) {
server := initTestServer(
"/latest/meta-data/some/path",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetMetadata("some/path")
assert.NoError(t, err)
assert.Equal(t, "success", resp)
}
func TestGetRegion(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"us-west-2a", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
assert.NoError(t, err)
assert.Equal(t, "us-west-2", region)
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",
"instance-id",
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
available := c.Available()
assert.True(t, available)
}
func TestMetadataNotAvailable(t *testing.T) {
c := ec2metadata.New(session.New())
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
r.Error = awserr.New("RequestError", "send request failed", nil)
r.Retryable = aws.Bool(true) // network errors are retryable
})
available := c.Available()
assert.False(t, available)
}

View File

@@ -0,0 +1,116 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
package ec2metadata
import (
"io/ioutil"
"net"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
*client.Client
}
// New creates a new instance of the EC2Metadata client with a session.
// This client is safe to use across multiple goroutines.
//
// Example:
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
//
// // Create a EC2Metadata client with additional configuration
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
c := p.ClientConfig(ServiceName, cfgs...)
return NewClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
}
// NewClient returns a new EC2Metadata client. Should be used to create
// a client when not using a session. Generally using just New with a session
// is preferred.
func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata {
// If the default http client is provided, replace it with a custom
// client using default timeouts.
if cfg.HTTPClient == http.DefaultClient {
cfg.HTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
},
}
}
svc := &EC2Metadata{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
Endpoint: endpoint,
APIVersion: "latest",
},
handlers,
),
}
svc.Handlers.Unmarshal.PushBack(unmarshalHandler)
svc.Handlers.UnmarshalError.PushBack(unmarshalError)
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Add additional options to the service config
for _, option := range opts {
option(svc.Client)
}
return svc
}
type metadataOutput struct {
Content string
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
}
data := r.Data.(*metadataOutput)
data.Content = string(b)
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
_, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
}
// TODO extract the error...
}
func validateEndpointHandler(r *request.Request) {
if r.ClientInfo.Endpoint == "" {
r.Error = aws.ErrMissingEndpoint
}
}

View File

@@ -0,0 +1,17 @@
package aws
import "github.com/aws/aws-sdk-go/aws/awserr"
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

View File

@@ -1,81 +0,0 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: Int(1)})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}

View File

@@ -1,85 +0,0 @@
package aws
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
type Handlers struct {
Validate HandlerList
Build HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
}
// copy returns of this handler's lists.
func (h *Handlers) copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
}
}
// Clear removes callback functions for all handlers
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
}
// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
}
}

View File

@@ -1,31 +0,0 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

View File

@@ -62,6 +62,15 @@ const (
// see the body content of requests and responses made while using the SDK
// Will also enable LogDebug.
LogDebugWithHTTPBody
// LogDebugWithRequestRetries states the SDK should log when service requests will
// be retried. This should be used to log when you want to log when service
// requests are being retried. Will also enable LogDebug.
LogDebugWithRequestRetries
// LogDebugWithRequestErrors states the SDK should log when service requests fail
// to build, send, validate, or unmarshal.
LogDebugWithRequestErrors
)
// A Logger is a minimalistic interface for the SDK to log messages to. Should

View File

@@ -1,89 +0,0 @@
package aws
import (
"fmt"
"reflect"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}

View File

@@ -1,84 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}

View File

@@ -0,0 +1,140 @@
package request
import (
"fmt"
"strings"
)
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
type Handlers struct {
Validate HandlerList
Build HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
}
// Copy returns of this handler's lists.
func (h *Handlers) Copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
}
}
// Clear removes callback functions for all handlers
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []NamedHandler
}
// A NamedHandler is a struct that contains a name and function callback.
type NamedHandler struct {
Name string
Fn func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]NamedHandler{}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []NamedHandler{}
}
// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
}
// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
l.list = append(l.list, n)
}
// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
}
// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
}
}
l.list = newlist
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f.Fn(r)
}
}
// MakeAddToUserAgentHandler will add the name/version pair to the User-Agent request
// header. If the extra parameters are provided they will be added as metadata to the
// name/version pair resulting in the following format.
// "name/version (extra0; extra1; ...)"
// The user agent part will be concatenated with this current request's user agent string.
func MakeAddToUserAgentHandler(name, version string, extra ...string) func(*Request) {
ua := fmt.Sprintf("%s/%s", name, version)
if len(extra) > 0 {
ua += fmt.Sprintf(" (%s)", strings.Join(extra, "; "))
}
return func(r *Request) {
AddToUserAgent(r, ua)
}
}
// MakeAddToUserAgentFreeFormHandler adds the input to the User-Agent request header.
// The input string will be concatenated with the current request's user agent string.
func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) {
return func(r *Request) {
AddToUserAgent(r, s)
}
}

View File

@@ -0,0 +1,47 @@
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &request.Request{}
l := request.HandlerList{}
l.PushBack(func(r *request.Request) { r.Data = nil })
l.PushFront(func(r *request.Request) { r.Data = aws.Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}
func TestNamedHandlers(t *testing.T) {
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "NotName", Fn: func(r *request.Request) {}}
l.PushBackNamed(named)
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
assert.Equal(t, 4, l.Len())
l.Remove(named)
assert.Equal(t, 2, l.Len())
}

View File

@@ -1,7 +1,8 @@
package aws
package request
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
@@ -10,25 +11,29 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
)
// A Request is the service request to be made.
type Request struct {
*Service
Handlers Handlers
Config aws.Config
ClientInfo metadata.ClientInfo
Handlers Handlers
Retryer
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
BodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount uint
RetryCount int
Retryable *bool
RetryDelay time.Duration
@@ -51,13 +56,15 @@ type Paginator struct {
TruncationToken string
}
// NewRequest returns a new Request pointer for the service API
// New returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
retryer Retryer, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
@@ -68,11 +75,14 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
}
httpReq, _ := http.NewRequest(method, "", nil)
httpReq.URL, _ = url.Parse(service.Endpoint + p)
httpReq.URL, _ = url.Parse(clientInfo.Endpoint + p)
r := &Request{
Service: service,
Handlers: service.Handlers.copy(),
Config: cfg,
ClientInfo: clientInfo,
Handlers: handlers.Copy(),
Retryer: retryer,
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
@@ -89,7 +99,7 @@ func NewRequest(service *Service, operation *Operation, params interface{}, data
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.Service.MaxRetries()
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
@@ -134,6 +144,20 @@ func (r *Request) Presign(expireTime time.Duration) (string, error) {
return r.HTTPRequest.URL.String(), nil
}
func debugLogReqError(r *Request, stage string, retrying bool, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
retryStr := "not retrying"
if retrying {
retryStr = "will retry"
}
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
@@ -149,6 +173,7 @@ func (r *Request) Build() error {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Request", false, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
@@ -165,6 +190,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
debugLogReqError(r, "Build Request", false, r.Error)
return r.Error
}
@@ -183,42 +209,57 @@ func (r *Request) Send() error {
return r.Error
}
if BoolValue(r.Retryable) {
if aws.BoolValue(r.Retryable) {
if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) {
r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d",
r.ClientInfo.ServiceName, r.Operation.Name, r.RetryCount))
}
// Re-seek the body back to the original point in for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Send Request", false, r.Error)
return r.Error
}
debugLogReqError(r, "Send Request", true, err)
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Validate Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Validate Response", true, err)
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
err := r.Error
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
debugLogReqError(r, "Unmarshal Response", false, r.Error)
return r.Error
}
debugLogReqError(r, "Unmarshal Response", true, err)
continue
}
@@ -228,85 +269,11 @@ func (r *Request) Send() error {
return nil
}
// HasNextPage returns true if this request has more pages of data available.
func (r *Request) HasNextPage() bool {
return r.nextPageTokens() != nil
}
// nextPageTokens returns the tokens to use when asking for the next page of
// data.
func (r *Request) nextPageTokens() []interface{} {
if r.Operation.Paginator == nil {
return nil
}
if r.Operation.TruncationToken != "" {
tr := awsutil.ValuesAtAnyPath(r.Data, r.Operation.TruncationToken)
if tr == nil || len(tr) == 0 {
return nil
}
switch v := tr[0].(type) {
case bool:
if v == false {
return nil
}
}
}
found := false
tokens := make([]interface{}, len(r.Operation.OutputTokens))
for i, outtok := range r.Operation.OutputTokens {
v := awsutil.ValuesAtAnyPath(r.Data, outtok)
if v != nil && len(v) > 0 {
found = true
tokens[i] = v[0]
}
}
if found {
return tokens
}
return nil
}
// NextPage returns a new Request that can be executed to return the next
// page of result data. Call .Send() on this request to execute it.
func (r *Request) NextPage() *Request {
tokens := r.nextPageTokens()
if tokens == nil {
return nil
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i])
}
return nr
}
// EachPage iterates over each page of a paginated request object. The fn
// parameter should be a function with the following sample signature:
//
// func(page *T, lastPage bool) bool {
// return true // return false to stop iterating
// }
//
// Where "T" is the structure type matching the output structure of the given
// operation. For example, a request object generated by
// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput
// as the structure "T". The lastPage value represents whether the page is
// the last page of data or not. The return value of this function should
// return true to keep iterating or false to stop.
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
for page := r; page != nil; page = page.NextPage() {
page.Send()
shouldContinue := fn(page.Data, !page.HasNextPage())
if page.Error != nil || !shouldContinue {
return page.Error
}
}
return nil
// AddToUserAgent adds the string to the end of the request's current user agent.
func AddToUserAgent(r *Request, s string) {
curUA := r.HTTPRequest.Header.Get("User-Agent")
if len(curUA) > 0 {
s = curUA + " " + s
}
r.HTTPRequest.Header.Set("User-Agent", s)
}

View File

@@ -0,0 +1,96 @@
package request
import (
"reflect"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
//type Paginater interface {
// HasNextPage() bool
// NextPage() *Request
// EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error
//}
// HasNextPage returns true if this request has more pages of data available.
func (r *Request) HasNextPage() bool {
return len(r.nextPageTokens()) > 0
}
// nextPageTokens returns the tokens to use when asking for the next page of
// data.
func (r *Request) nextPageTokens() []interface{} {
if r.Operation.Paginator == nil {
return nil
}
if r.Operation.TruncationToken != "" {
tr, _ := awsutil.ValuesAtPath(r.Data, r.Operation.TruncationToken)
if len(tr) == 0 {
return nil
}
switch v := tr[0].(type) {
case *bool:
if !aws.BoolValue(v) {
return nil
}
case bool:
if v == false {
return nil
}
}
}
tokens := []interface{}{}
for _, outToken := range r.Operation.OutputTokens {
v, _ := awsutil.ValuesAtPath(r.Data, outToken)
if len(v) > 0 {
tokens = append(tokens, v[0])
}
}
return tokens
}
// NextPage returns a new Request that can be executed to return the next
// page of result data. Call .Send() on this request to execute it.
func (r *Request) NextPage() *Request {
tokens := r.nextPageTokens()
if len(tokens) == 0 {
return nil
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := New(r.Config, r.ClientInfo, r.Handlers, r.Retryer, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtPath(nr.Params, intok, tokens[i])
}
return nr
}
// EachPage iterates over each page of a paginated request object. The fn
// parameter should be a function with the following sample signature:
//
// func(page *T, lastPage bool) bool {
// return true // return false to stop iterating
// }
//
// Where "T" is the structure type matching the output structure of the given
// operation. For example, a request object generated by
// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput
// as the structure "T". The lastPage value represents whether the page is
// the last page of data or not. The return value of this function should
// return true to keep iterating or false to stop.
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
for page := r; page != nil; page = page.NextPage() {
page.Send()
shouldContinue := fn(page.Data, !page.HasNextPage())
if page.Error != nil || !shouldContinue {
return page.Error
}
}
return nil
}

View File

@@ -1,20 +1,108 @@
package aws_test
package request_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
// Use DynamoDB methods for simplicity
func TestPaginationQueryPage(t *testing.T) {
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []map[string]*dynamodb.AttributeValue{}, []map[string]*dynamodb.AttributeValue{}, 0, false
reqNum := 0
resps := []*dynamodb.QueryOutput{
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
map[string]*dynamodb.AttributeValue{
"key": {S: aws.String("key1")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
map[string]*dynamodb.AttributeValue{
"key": {S: aws.String("key2")},
},
},
},
{
LastEvaluatedKey: map[string]*dynamodb.AttributeValue{},
Count: aws.Int64(1),
Items: []map[string]*dynamodb.AttributeValue{
map[string]*dynamodb.AttributeValue{
"key": {S: aws.String("key3")},
},
},
},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.QueryInput)
if in == nil {
tokens = append(tokens, nil)
} else if len(in.ExclusiveStartKey) != 0 {
tokens = append(tokens, in.ExclusiveStartKey)
}
})
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.QueryInput{
Limit: aws.Int64(2),
TableName: aws.String("tablename"),
}
err := db.QueryPages(params, func(p *dynamodb.QueryOutput, last bool) bool {
numPages++
for _, item := range p.Items {
pages = append(pages, item)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Nil(t, err)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}},
map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}},
}, tokens)
assert.Equal(t,
[]map[string]*dynamodb.AttributeValue{
map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key1")}},
map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key2")}},
map[string]*dynamodb.AttributeValue{"key": {S: aws.String("key3")}},
}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, params.ExclusiveStartKey)
}
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(nil)
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
@@ -28,7 +116,7 @@ func TestPagination(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
@@ -36,7 +124,7 @@ func TestPagination(t *testing.T) {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@@ -66,7 +154,7 @@ func TestPagination(t *testing.T) {
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(nil)
db := dynamodb.New(unit.Session)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
@@ -80,7 +168,7 @@ func TestPaginationEachPage(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
db.Handlers.Build.PushBack(func(r *request.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
@@ -88,7 +176,7 @@ func TestPaginationEachPage(t *testing.T) {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@@ -119,7 +207,7 @@ func TestPaginationEachPage(t *testing.T) {
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(nil)
db := dynamodb.New(unit.Session)
numPages, gotToEnd := 0, false
reqNum := 0
@@ -133,7 +221,7 @@ func TestPaginationEarlyExit(t *testing.T) {
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
@@ -159,12 +247,12 @@ func TestPaginationEarlyExit(t *testing.T) {
}
func TestSkipPagination(t *testing.T) {
client := s3.New(nil)
client := s3.New(unit.Session)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = &s3.HeadBucketOutput{}
})
@@ -184,10 +272,9 @@ func TestSkipPagination(t *testing.T) {
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
count := 0
client := s3.New(nil)
client := s3.New(unit.Session)
reqNum := &count
reqNum := 0
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
@@ -199,9 +286,9 @@ func TestPaginationTruncation(t *testing.T) {
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[*reqNum]
*reqNum++
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
@@ -216,7 +303,7 @@ func TestPaginationTruncation(t *testing.T) {
assert.Nil(t, err)
// Try again without truncation token at all
count = 0
reqNum = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Bool(true)
results = []string{}
@@ -249,7 +336,7 @@ var benchResps = []*dynamodb.ListTablesOutput{
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(nil)
db := dynamodb.New(unit.Session)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
@@ -260,7 +347,7 @@ var benchDb = func() *dynamodb.DynamoDB {
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
@@ -289,7 +376,7 @@ func BenchmarkCodegenIterator(b *testing.B) {
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
db.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = benchResps[reqNum]
reqNum++
})

View File

@@ -1,4 +1,4 @@
package aws
package request_test
import (
"bytes"
@@ -7,12 +7,17 @@ import (
"io"
"io/ioutil"
"net/http"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
type testData struct {
@@ -23,7 +28,7 @@ func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
func unmarshal(req *request.Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
@@ -31,7 +36,7 @@ func unmarshal(req *Request) {
return
}
func unmarshalError(req *Request) {
func unmarshalError(req *request.Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
@@ -71,17 +76,17 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
@@ -97,17 +102,17 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
@@ -116,16 +121,16 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(NewConfig().WithMaxRetries(10))
s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
@@ -140,7 +145,7 @@ func TestRequest4xxUnretryable(t *testing.T) {
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
sleepDelay := func(delay time.Duration) {
delays = append(delays, delay)
}
@@ -152,16 +157,16 @@ func TestRequestExhaustRetries(t *testing.T) {
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(NewConfig().WithMaxRetries(DefaultRetries))
s := awstesting.NewClient(aws.NewConfig().WithSleepDelay(sleepDelay))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
@@ -190,7 +195,7 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s := awstesting.NewClient(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
@@ -198,21 +203,21 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *Request) {
s.Handlers.AfterRetry.PushBack(func(r *request.Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
s.Handlers.Sign.PushBack(func(r *request.Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
@@ -223,3 +228,34 @@ func TestRequestRecoverExpiredCreds(t *testing.T) {
assert.Equal(t, 1, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
func TestMakeAddtoUserAgentHandler(t *testing.T) {
fn := request.MakeAddToUserAgentHandler("name", "version", "extra1", "extra2")
r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}}
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
fn(r)
assert.Equal(t, "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"))
}
func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) {
fn := request.MakeAddToUserAgentFreeFormHandler("name/version (extra1; extra2)")
r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}}
r.HTTPRequest.Header.Set("User-Agent", "foo/bar")
fn(r)
assert.Equal(t, "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"))
}
func TestRequestUserAgent(t *testing.T) {
s := awstesting.NewClient(&aws.Config{Region: aws.String("us-east-1")})
// s.Handlers.Validate.Clear()
req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{})
req.HTTPRequest.Header.Set("User-Agent", "foo/bar")
assert.NoError(t, req.Build())
expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)",
aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH)
assert.Equal(t, expectUA, req.HTTPRequest.Header.Get("User-Agent"))
}

View File

@@ -0,0 +1,74 @@
package request
import (
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Retryer is an interface to control retry logic for a given service.
// The default implementation used by most services is the service.DefaultRetryer
// structure, which contains basic retry logic using exponential backoff.
type Retryer interface {
RetryRules(*Request) time.Duration
ShouldRetry(*Request) bool
MaxRetries() int
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"RequestTimeout": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
"LimitExceededException": {}, // Deleting 10+ DynamoDb tables at once
"TooManyRequestsException": {}, // Lambda functions
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// IsErrorRetryable returns whether the error is retryable, based on its Code.
// Returns false if the request has no Error set.
func (r *Request) IsErrorRetryable() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}
// IsErrorExpired returns whether the error code is a credential expiry error.
// Returns false if the request has no Error set.
func (r *Request) IsErrorExpired() bool {
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeExpiredCreds(err.Code())
}
}
return false
}

View File

@@ -1,194 +0,0 @@
package aws
import (
"fmt"
"math"
"math/rand"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new Server object initialized.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *Request) {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ServiceName, r.Operation.Name, msg))
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (s *Service) MaxRetries() uint {
if IntValue(s.Config.MaxRetries) < 0 {
return s.DefaultMaxRetries
}
return uint(IntValue(s.Config.MaxRetries))
}
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}

View File

@@ -0,0 +1,105 @@
// Package session provides a way to create service clients with shared configuration
// and handlers.
package session
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
// A Session provides a central location to create service clients from and
// store configurations and request handlers for those services.
//
// Sessions are safe to create service clients concurrently, but it is not safe
// to mutate the session concurrently.
type Session struct {
Config *aws.Config
Handlers request.Handlers
}
// New creates a new instance of the handlers merging in the provided Configs
// on top of the SDK's default configurations. Once the session is created it
// can be mutated to modify Configs or Handlers. The session is safe to be read
// concurrently, but it should not be written to concurrently.
//
// Example:
// // Create a session with the default config and request handlers.
// sess := session.New()
//
// // Create a session with a custom region
// sess := session.New(&aws.Config{Region: aws.String("us-east-1")})
//
// // Create a session, and add additional handlers for all service
// // clients created with the session to inherit. Adds logging handler.
// sess := session.New()
// sess.Handlers.Send.PushFront(func(r *request.Request) {
// // Log every request made and its payload
// logger.Println("Request: %s/%s, Payload: %s", r.ClientInfo.ServiceName, r.Operation, r.Params)
// })
//
// // Create a S3 client instance from a session
// sess := session.New()
// svc := s3.New(sess)
func New(cfgs ...*aws.Config) *Session {
def := defaults.Get()
s := &Session{
Config: def.Config,
Handlers: def.Handlers,
}
s.Config.MergeIn(cfgs...)
initHandlers(s)
return s
}
func initHandlers(s *Session) {
// Add the Validate parameter handler if it is not disabled.
s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler)
if !aws.BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBackNamed(corehandlers.ValidateParametersHandler)
}
}
// Copy creates and returns a copy of the current session, coping the config
// and handlers. If any additional configs are provided they will be merged
// on top of the session's copied config.
//
// Example:
// // Create a copy of the current session, configured for the us-west-2 region.
// sess.Copy(&aws.Config{Region: aws.String("us-west-2"})
func (s *Session) Copy(cfgs ...*aws.Config) *Session {
newSession := &Session{
Config: s.Config.Copy(cfgs...),
Handlers: s.Handlers.Copy(),
}
initHandlers(newSession)
return newSession
}
// ClientConfig satisfies the client.ConfigProvider interface and is used to
// configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client.
//
// Example:
// sess := session.New()
// s3.New(sess)
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...)
endpoint, signingRegion := endpoints.NormalizeEndpoint(
aws.StringValue(s.Config.Endpoint), serviceName,
aws.StringValue(s.Config.Region), aws.BoolValue(s.Config.DisableSSL))
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
Endpoint: endpoint,
SigningRegion: signingRegion,
}
}

View File

@@ -0,0 +1,20 @@
package session_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
)
func TestNewDefaultSession(t *testing.T) {
s := session.New(&aws.Config{Region: aws.String("region")})
assert.Equal(t, "region", *s.Config.Region)
assert.Equal(t, http.DefaultClient, s.Config.HTTPClient)
assert.NotNil(t, s.Config.Logger)
assert.Equal(t, aws.LogOff, *s.Config.LogLevel)
}

View File

@@ -2,9 +2,10 @@ package aws
import (
"io"
"sync"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeakerCloser
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
@@ -53,3 +54,35 @@ func (r ReaderSeekerCloser) Close() error {
}
return nil
}
// A WriteAtBuffer provides a in memory buffer supporting the io.WriterAt interface
// Can be used with the s3manager.Downloader to download content to a buffer
// in memory. Safe to use concurrently.
type WriteAtBuffer struct {
buf []byte
m sync.Mutex
}
// WriteAt writes a slice of bytes to a buffer starting at the position provided
// The number of bytes written will be returned, or error. Can overwrite previous
// written slices if the write ats overlap.
func (b *WriteAtBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
b.m.Lock()
defer b.m.Unlock()
expLen := pos + int64(len(p))
if int64(len(b.buf)) < expLen {
newBuf := make([]byte, expLen)
copy(newBuf, b.buf)
b.buf = newBuf
}
copy(b.buf[pos:], p)
return len(p), nil
}
// Bytes returns a slice of bytes written to the buffer.
func (b *WriteAtBuffer) Bytes() []byte {
b.m.Lock()
defer b.m.Unlock()
return b.buf[:len(b.buf):len(b.buf)]
}

View File

@@ -0,0 +1,56 @@
package aws
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteAtBuffer(t *testing.T) {
b := &WriteAtBuffer{}
n, err := b.WriteAt([]byte{1}, 0)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{1, 1, 1}, 5)
assert.NoError(t, err)
assert.Equal(t, 3, n)
n, err = b.WriteAt([]byte{2}, 1)
assert.NoError(t, err)
assert.Equal(t, 1, n)
n, err = b.WriteAt([]byte{3}, 2)
assert.NoError(t, err)
assert.Equal(t, 1, n)
assert.Equal(t, []byte{1, 2, 3, 0, 0, 1, 1, 1}, b.Bytes())
}
func BenchmarkWriteAtBuffer(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
}
func BenchmarkWriteAtBufferParallel(b *testing.B) {
buf := &WriteAtBuffer{}
r := rand.New(rand.NewSource(1))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
to := r.Intn(10) * 4096
bs := make([]byte, to)
buf.WriteAt(bs, r.Int63n(10)*4096)
}
})
}

View File

@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.7.3"
const SDKVersion = "0.10.4"

View File

@@ -0,0 +1,65 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import (
"fmt"
"regexp"
"strings"
)
// NormalizeEndpoint takes and endpoint and service API information to return a
// normalized endpoint and signing region. If the endpoint is not an empty string
// the service name and region will be used to look up the service's API endpoint.
// If the endpoint is provided the scheme will be added if it is not present.
func NormalizeEndpoint(endpoint, serviceName, region string, disableSSL bool) (normEndpoint, signingRegion string) {
if endpoint == "" {
return EndpointForRegion(serviceName, region, disableSSL)
}
return AddScheme(endpoint, disableSSL), ""
}
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// if the service and region pair are not found endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string, disableSSL bool) (endpoint, signingRegion string) {
derivedKeys := []string{
region + "/" + svcName,
region + "/*",
"*/" + svcName,
"*/*",
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
break
}
}
return AddScheme(endpoint, disableSSL), signingRegion
}
// Regular expression to determine if the endpoint string is prefixed with a scheme.
var schemeRE = regexp.MustCompile("^([^:]+)://")
// AddScheme adds the HTTP or HTTPS schemes to a endpoint URL if there is no
// scheme. If disableSSL is true HTTP will be added instead of the default HTTPS.
func AddScheme(endpoint string, disableSSL bool) string {
if endpoint != "" && !schemeRE.MatchString(endpoint) {
scheme := "https"
if disableSSL {
scheme = "http"
}
endpoint = fmt.Sprintf("%s://%s", scheme, endpoint)
}
return endpoint
}

View File

@@ -0,0 +1,89 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/data.iot": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/ec2metadata": {
"endpoint": "http://169.254.169.254/latest",
"signingRegion": "us-east-1"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/waf": {
"endpoint": "waf.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"us-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-west-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-northeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"sa-east-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com",
"signatureVersion": "v4"
}
}
}

View File

@@ -0,0 +1,101 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/data.iot": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/ec2metadata": {
Endpoint: "http://169.254.169.254/latest",
SigningRegion: "us-east-1",
},
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/waf": {
Endpoint: "waf.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
},
}

View File

@@ -0,0 +1,41 @@
package endpoints_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/private/endpoints"
)
func TestGenericEndpoint(t *testing.T) {
name := "service"
region := "mock-region-1"
ep, sr := endpoints.EndpointForRegion(name, region, false)
assert.Equal(t, fmt.Sprintf("https://%s.%s.amazonaws.com", name, region), ep)
assert.Empty(t, sr)
}
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "waf"}
for _, name := range svcs {
ep, sr := endpoints.EndpointForRegion(name, region, false)
assert.Equal(t, fmt.Sprintf("https://%s.amazonaws.com", name), ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3", "waf"}
for _, name := range svcs {
ep, sr := endpoints.EndpointForRegion(name, region, false)
assert.Equal(t, fmt.Sprintf("https://%s.%s.amazonaws.com.cn", name, region), ep)
assert.Empty(t, sr)
}
}

View File

@@ -0,0 +1,32 @@
// Package ec2query provides serialisation of AWS EC2 requests and responses.
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
)
// Build builds a request for the EC2 protocol.
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.ClientInfo.APIVersion},
}
if err := queryutil.Parse(body, r.Params, true); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err)
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

View File

@@ -0,0 +1,85 @@
// +build bench
package ec2query_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/service/ec2"
)
func BenchmarkEC2QueryBuild_Complex_ec2AuthorizeSecurityGroupEgress(b *testing.B) {
params := &ec2.AuthorizeSecurityGroupEgressInput{
GroupId: aws.String("String"), // Required
CidrIp: aws.String("String"),
DryRun: aws.Bool(true),
FromPort: aws.Int64(1),
IpPermissions: []*ec2.IpPermission{
{ // Required
FromPort: aws.Int64(1),
IpProtocol: aws.String("String"),
IpRanges: []*ec2.IpRange{
{ // Required
CidrIp: aws.String("String"),
},
// More values...
},
PrefixListIds: []*ec2.PrefixListId{
{ // Required
PrefixListId: aws.String("String"),
},
// More values...
},
ToPort: aws.Int64(1),
UserIdGroupPairs: []*ec2.UserIdGroupPair{
{ // Required
GroupId: aws.String("String"),
GroupName: aws.String("String"),
UserId: aws.String("String"),
},
// More values...
},
},
// More values...
},
IpProtocol: aws.String("String"),
SourceSecurityGroupName: aws.String("String"),
SourceSecurityGroupOwnerId: aws.String("String"),
ToPort: aws.Int64(1),
}
benchEC2QueryBuild(b, "AuthorizeSecurityGroupEgress", params)
}
func BenchmarkEC2QueryBuild_Simple_ec2AttachNetworkInterface(b *testing.B) {
params := &ec2.AttachNetworkInterfaceInput{
DeviceIndex: aws.Int64(1), // Required
InstanceId: aws.String("String"), // Required
NetworkInterfaceId: aws.String("String"), // Required
DryRun: aws.Bool(true),
}
benchEC2QueryBuild(b, "AttachNetworkInterface", params)
}
func benchEC2QueryBuild(b *testing.B, opName string, params interface{}) {
svc := awstesting.NewClient()
svc.ServiceName = "ec2"
svc.APIVersion = "2015-04-15"
for i := 0; i < b.N; i++ {
r := svc.NewRequest(&request.Operation{
Name: opName,
HTTPMethod: "POST",
HTTPPath: "/",
}, params, nil)
ec2query.Build(r)
if r.Error != nil {
b.Fatal("Unexpected error", r.Error)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response body for the EC2 protocol.
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals response headers for the EC2 protocol.
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Response"`
Code string `xml:"Errors>Error>Code"`
Message string `xml:"Errors>Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals a response error for the EC2 protocol.
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,33 @@
// Package query provides serialisation of AWS query requests, and responses.
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/query.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
)
// Build builds a request for an AWS Query service.
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.ClientInfo.APIVersion},
}
if err := queryutil.Parse(body, r.Params, false); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding Query request", err)
return
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,223 @@
package queryutil
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
q := queryParser{isEC2: isEC2}
return q.parseValue(body, reflect.ValueOf(i), "", "")
}
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
type queryParser struct {
isEC2 bool
}
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
value = elemOf(value)
// no need to handle zero values
if !value.IsValid() {
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
return q.parseStruct(v, value, prefix)
case "list":
return q.parseList(v, value, prefix, tag)
case "map":
return q.parseMap(v, value, prefix, tag)
default:
return q.parseScalar(v, value, prefix, tag)
}
}
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
if !value.IsValid() {
return nil
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
elemValue := elemOf(value.Field(i))
field := t.Field(i)
var name string
if q.isEC2 {
name = field.Tag.Get("queryName")
}
if name == "" {
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
if name != "" && q.isEC2 {
name = strings.ToUpper(name[0:1]) + name[1:]
}
}
if name == "" {
name = field.Name
}
if prefix != "" {
name = prefix + "." + name
}
if err := q.parseValue(v, elemValue, name, field.Tag); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".member"
}
for i := 0; i < value.Len(); i++ {
slicePrefix := prefix
if slicePrefix == "" {
slicePrefix = strconv.Itoa(i + 1)
} else {
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1)
}
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".entry"
}
// sort keys for improved serialization consistency.
// this is not strictly necessary for protocol support.
mapKeyValues := value.MapKeys()
mapKeys := map[string]reflect.Value{}
mapKeyNames := make([]string, len(mapKeyValues))
for i, mapKey := range mapKeyValues {
name := mapKey.String()
mapKeys[name] = mapKey
mapKeyNames[i] = name
}
sort.Strings(mapKeyNames)
for i, mapKeyName := range mapKeyNames {
mapKey := mapKeys[mapKeyName]
mapValue := value.MapIndex(mapKey)
kname := tag.Get("locationNameKey")
if kname == "" {
kname = "key"
}
vname := tag.Get("locationNameValue")
if vname == "" {
vname = "value"
}
// serialize key
var keyName string
if prefix == "" {
keyName = strconv.Itoa(i+1) + "." + kname
} else {
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
}
if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
return err
}
// serialize value
var valueName string
if prefix == "" {
valueName = strconv.Itoa(i+1) + "." + vname
} else {
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
}
if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
switch value := r.Interface().(type) {
case string:
v.Set(name, value)
case []byte:
if !r.IsNil() {
v.Set(name, base64.StdEncoding.EncodeToString(value))
}
case bool:
v.Set(name, strconv.FormatBool(value))
case int64:
v.Set(name, strconv.FormatInt(value, 10))
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
v.Set(name, value.UTC().Format(ISO8601UTC))
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name())
}
return nil
}

View File

@@ -0,0 +1,29 @@
package query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/query.json unmarshal_test.go
import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response for an AWS Query service.
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals header response values for an AWS Query service.
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}

View File

@@ -0,0 +1,33 @@
package query
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"ErrorResponse"`
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,254 @@
// Package rest provides RESTful serialization of AWS requests and responses.
package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"path"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
var errValueNotSet = fmt.Errorf("value not set")
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// Build builds the REST component of a service request.
func Build(r *request.Request) {
if r.ParamsFilled() {
v := reflect.ValueOf(r.Params).Elem()
buildLocationElements(r, v)
buildBody(r, v)
}
}
func buildLocationElements(r *request.Request, v reflect.Value) {
query := r.HTTPRequest.URL.Query()
for i := 0; i < v.NumField(); i++ {
m := v.Field(i)
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
field := v.Type().Field(i)
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
if m.Kind() == reflect.Ptr {
m = m.Elem()
}
if !m.IsValid() {
continue
}
var err error
switch field.Tag.Get("location") {
case "headers": // header maps
err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag.Get("locationName"))
case "header":
err = buildHeader(&r.HTTPRequest.Header, m, name)
case "uri":
err = buildURI(r.HTTPRequest.URL, m, name)
case "querystring":
err = buildQueryString(query, m, name)
}
r.Error = err
}
if r.Error != nil {
return
}
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
}
func buildBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := reflect.Indirect(v.FieldByName(payloadName))
if payload.IsValid() && payload.Interface() != nil {
switch reader := payload.Interface().(type) {
case io.ReadSeeker:
r.SetReaderBody(reader)
case []byte:
r.SetBufferBody(reader)
case string:
r.SetStringBody(reader)
default:
r.Error = awserr.New("SerializationError",
"failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
func buildHeader(header *http.Header, v reflect.Value, name string) error {
str, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
header.Add(name, str)
return nil
}
func buildHeaderMap(header *http.Header, v reflect.Value, prefix string) error {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err == errValueNotSet {
continue
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
header.Add(prefix+key.String(), str)
}
return nil
}
func buildURI(u *url.URL, v reflect.Value, name string) error {
value, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
uri := u.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(value, true), -1)
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(value, false), -1)
u.Path = uri
return nil
}
func buildQueryString(query url.Values, v reflect.Value, name string) error {
switch value := v.Interface().(type) {
case []*string:
for _, item := range value {
query.Add(name, *item)
}
case map[string]*string:
for key, item := range value {
query.Add(key, *item)
}
case map[string][]*string:
for key, items := range value {
for _, item := range items {
query.Add(key, *item)
}
}
default:
str, err := convertType(v)
if err == errValueNotSet {
return nil
} else if err != nil {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
query.Set(name, str)
}
return nil
}
func updatePath(url *url.URL, urlPath string) {
scheme, query := url.Scheme, url.RawQuery
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path
urlPath = path.Clean(urlPath)
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""
s := url.String()
url.Scheme = scheme
url.RawQuery = query
// build opaque URI
url.Opaque = s + urlPath
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
buf.WriteByte('%')
buf.WriteString(strings.ToUpper(strconv.FormatUint(uint64(c), 16)))
}
}
return buf.String()
}
func convertType(v reflect.Value) (string, error) {
v = reflect.Indirect(v)
if !v.IsValid() {
return "", errValueNotSet
}
var str string
switch value := v.Interface().(type) {
case string:
str = value
case []byte:
str = base64.StdEncoding.EncodeToString(value)
case bool:
str = strconv.FormatBool(value)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
case time.Time:
str = value.UTC().Format(RFC822)
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return "", err
}
return str, nil
}

View File

@@ -0,0 +1,45 @@
package rest
import "reflect"
// PayloadMember returns the payload field member of i if there is one, or nil.
func PayloadMember(i interface{}) interface{} {
if i == nil {
return nil
}
v := reflect.ValueOf(i).Elem()
if !v.IsValid() {
return nil
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
field, _ := v.Type().FieldByName(payloadName)
if field.Tag.Get("type") != "structure" {
return nil
}
payload := v.FieldByName(payloadName)
if payload.IsValid() || (payload.Kind() == reflect.Ptr && !payload.IsNil()) {
return payload.Interface()
}
}
}
return nil
}
// PayloadType returns the type of a payload field member of i if there is one, or "".
func PayloadType(i interface{}) string {
v := reflect.Indirect(reflect.ValueOf(i))
if !v.IsValid() {
return ""
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
if member, ok := v.Type().FieldByName(payloadName); ok {
return member.Tag.Get("type")
}
}
}
return ""
}

View File

@@ -0,0 +1,183 @@
package rest
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// Unmarshal unmarshals the REST component of a response in a REST service.
func Unmarshal(r *request.Request) {
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v)
}
}
// UnmarshalMeta unmarshals the REST metadata of a response in a REST service
func UnmarshalMeta(r *request.Request) {
r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalLocationElements(r, v)
}
}
func unmarshalBody(r *request.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := v.FieldByName(payloadName)
if payload.IsValid() {
switch payload.Interface().(type) {
case []byte:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
}
case *string:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
str := string(b)
payload.Set(reflect.ValueOf(&str))
}
default:
switch payload.Type().String() {
case "io.ReadSeeker":
payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
case "aws.ReadSeekCloser", "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
default:
r.Error = awserr.New("SerializationError",
"failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
}
func unmarshalLocationElements(r *request.Request, v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
switch field.Tag.Get("location") {
case "statusCode":
unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
}
}
if r.Error != nil {
return
}
}
}
func unmarshalStatusCode(v reflect.Value, statusCode int) {
if !v.IsValid() {
return
}
switch v.Interface().(type) {
case *int64:
s := int64(statusCode)
v.Set(reflect.ValueOf(&s))
}
}
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
out[k[len(prefix):]] = &v[0]
}
}
r.Set(reflect.ValueOf(out))
}
return nil
}
func unmarshalHeader(v reflect.Value, header string) error {
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
return nil
}
switch v.Interface().(type) {
case *string:
v.Set(reflect.ValueOf(&header))
case []byte:
b, err := base64.StdEncoding.DecodeString(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *bool:
b, err := strconv.ParseBool(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *int64:
i, err := strconv.ParseInt(header, 10, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&i))
case *float64:
f, err := strconv.ParseFloat(header, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&f))
case *time.Time:
t, err := time.Parse(RFC822, header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&t))
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return err
}
return nil
}

View File

@@ -0,0 +1,287 @@
// Package xmlutil provides XML serialisation of AWS requests and responses.
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// BuildXML will serialize params into an xml.Encoder.
// Error will be returned if the serialization of any of the params or nested values fails.
func BuildXML(params interface{}, e *xml.Encoder) error {
b := xmlBuilder{encoder: e, namespaces: map[string]string{}}
root := NewXMLElement(xml.Name{})
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil {
return err
}
for _, c := range root.Children {
for _, v := range c {
return StructToXML(e, v, false)
}
}
return nil
}
// Returns the reflection element of a value, if it is a pointer.
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
// A xmlBuilder serializes values from Go code to XML
type xmlBuilder struct {
encoder *xml.Encoder
namespaces map[string]string
}
// buildValue generic XMLNode builder for any type. Will build value for their specific type
// struct, list, map, scalar.
//
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
// type is not provided reflect will be used to determine the value's type.
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
value = elemOf(value)
if !value.IsValid() { // no need to handle zero values
return nil
} else if tag.Get("location") != "" { // don't handle non-body location values
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := value.Type().FieldByName("SDKShapeTraits"); ok {
tag = tag + reflect.StructTag(" ") + field.Tag
}
return b.buildStruct(value, current, tag)
case "list":
return b.buildList(value, current, tag)
case "map":
return b.buildMap(value, current, tag)
default:
return b.buildScalar(value, current, tag)
}
}
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
// types are converted to XMLNodes also.
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if !value.IsValid() {
return nil
}
fieldAdded := false
// unwrap payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := value.Type().FieldByName(payload)
tag = field.Tag
value = elemOf(value.FieldByName(payload))
if !value.IsValid() {
return nil
}
}
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
// there is an xmlNamespace associated with this struct
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" {
ns := xml.Attr{
Name: xml.Name{Local: "xmlns"},
Value: uri,
}
if prefix != "" {
b.namespaces[prefix] = uri // register the namespace
ns.Name.Local = "xmlns:" + prefix
}
child.Attr = append(child.Attr, ns)
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
member := elemOf(value.Field(i))
field := t.Field(i)
mTag := field.Tag
if mTag.Get("location") != "" { // skip non-body members
continue
}
memberName := mTag.Get("locationName")
if memberName == "" {
memberName = field.Name
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`)
}
if err := b.buildValue(member, child, mTag); err != nil {
return err
}
fieldAdded = true
}
if fieldAdded { // only append this child if we have one ore more valid members
current.AddChild(child)
}
return nil
}
// buildList adds the value's list items to the current XMLNode as children nodes. All
// nested values in the list are converted to XMLNodes also.
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted lists
return nil
}
// check for unflattened list member
flattened := tag.Get("flattened") != ""
xname := xml.Name{Local: tag.Get("locationName")}
if flattened {
for i := 0; i < value.Len(); i++ {
child := NewXMLElement(xname)
current.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
} else {
list := NewXMLElement(xname)
current.AddChild(list)
for i := 0; i < value.Len(); i++ {
iname := tag.Get("locationNameList")
if iname == "" {
iname = "member"
}
child := NewXMLElement(xml.Name{Local: iname})
list.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
}
return nil
}
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
// nested values in the map are converted to XMLNodes also.
//
// Error will be returned if it is unable to build the map's values into XMLNodes
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted maps
return nil
}
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
current.AddChild(maproot)
current = maproot
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
// sorting is not required for compliance, but it makes testing easier
keys := make([]string, value.Len())
for i, k := range value.MapKeys() {
keys[i] = k.String()
}
sort.Strings(keys)
for _, k := range keys {
v := value.MapIndex(reflect.ValueOf(k))
mapcur := current
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
child := NewXMLElement(xml.Name{Local: "entry"})
mapcur.AddChild(child)
mapcur = child
}
kchild := NewXMLElement(xml.Name{Local: kname})
kchild.Text = k
vchild := NewXMLElement(xml.Name{Local: vname})
mapcur.AddChild(kchild)
mapcur.AddChild(vchild)
if err := b.buildValue(v, vchild, ""); err != nil {
return err
}
}
return nil
}
// buildScalar will convert the value into a string and append it as a attribute or child
// of the current XMLNode.
//
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
//
// Error will be returned if the value type is unsupported.
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
var str string
switch converted := value.Interface().(type) {
case string:
str = converted
case []byte:
if !value.IsNil() {
str = base64.StdEncoding.EncodeToString(converted)
}
case bool:
str = strconv.FormatBool(converted)
case int64:
str = strconv.FormatInt(converted, 10)
case int:
str = strconv.Itoa(converted)
case float64:
str = strconv.FormatFloat(converted, 'f', -1, 64)
case float32:
str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
str = converted.UTC().Format(ISO8601UTC)
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)",
tag.Get("locationName"), value.Interface(), value.Type().Name())
}
xname := xml.Name{Local: tag.Get("locationName")}
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
attr := xml.Attr{Name: xname, Value: str}
current.Attr = append(current.Attr, attr)
} else { // regular text node
current.AddChild(&XMLNode{Name: xname, Text: str})
}
return nil
}

View File

@@ -0,0 +1,260 @@
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
// UnmarshalXML deserializes an xml.Decoder into the container v. V
// needs to match the shape of the XML expected to be decoded.
// If the shape doesn't match unmarshaling will fail.
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
n, _ := XMLToStruct(d, nil)
if n.Children != nil {
for _, root := range n.Children {
for _, c := range root {
if wrappedChild, ok := c.Children[wrapper]; ok {
c = wrappedChild[0] // pull out wrapped element
}
err := parse(reflect.ValueOf(v), c, "")
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
return nil
}
return nil
}
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
rtype := r.Type()
if rtype.Kind() == reflect.Ptr {
rtype = rtype.Elem() // check kind of actual element type
}
t := tag.Get("type")
if t == "" {
switch rtype.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := rtype.FieldByName("SDKShapeTraits"); ok {
tag = field.Tag
}
return parseStruct(r, node, tag)
case "list":
return parseList(r, node, tag)
case "map":
return parseMap(r, node, tag)
default:
return parseScalar(r, node, tag)
}
}
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
// types in the structure will also be deserialized.
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if r.Kind() == reflect.Ptr {
if r.IsNil() { // create the structure if it's nil
s := reflect.New(r.Type().Elem())
r.Set(s)
r = s
}
r = r.Elem()
t = t.Elem()
}
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return parseStruct(r.FieldByName(payload), node, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if c := field.Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
// figure out what this field is called
name := field.Name
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
// try to find the field by name in elements
elems := node.Children[name]
if elems == nil { // try to find the field in attributes
for _, a := range node.Attr {
if name == a.Name.Local {
// turn this into a text node for de-serializing
elems = []*XMLNode{{Text: a.Value}}
}
}
}
member := r.FieldByName(field.Name)
for _, elem := range elems {
err := parse(member, elem, field.Tag)
if err != nil {
return err
}
}
}
return nil
}
// parseList deserializes a list of values from an XML node. Each list entry
// will also be deserialized.
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if tag.Get("flattened") == "" { // look at all item entries
mname := "member"
if name := tag.Get("locationNameList"); name != "" {
mname = name
}
if Children, ok := node.Children[mname]; ok {
if r.IsNil() {
r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
}
for i, c := range Children {
err := parse(r.Index(i), c, "")
if err != nil {
return err
}
}
}
} else { // flattened list means this is a single element
if r.IsNil() {
r.Set(reflect.MakeSlice(t, 0, 0))
}
childR := reflect.Zero(t.Elem())
r.Set(reflect.Append(r, childR))
err := parse(r.Index(r.Len()-1), node, "")
if err != nil {
return err
}
}
return nil
}
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
// will also be deserialized as map entries.
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
if r.IsNil() {
r.Set(reflect.MakeMap(r.Type()))
}
if tag.Get("flattened") == "" { // look at all child entries
for _, entry := range node.Children["entry"] {
parseMapEntry(r, entry, tag)
}
} else { // this element is itself an entry
parseMapEntry(r, node, tag)
}
return nil
}
// parseMapEntry deserializes a map entry from a XML node.
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
keys, ok := node.Children[kname]
values := node.Children[vname]
if ok {
for i, key := range keys {
keyR := reflect.ValueOf(key.Text)
value := values[i]
valueR := reflect.New(r.Type().Elem()).Elem()
parse(valueR, value, "")
r.SetMapIndex(keyR, valueR)
}
}
return nil
}
// parseScaller deserializes an XMLNode value into a concrete type based on the
// interface type of r.
//
// Error is returned if the deserialization fails due to invalid type conversion,
// or unsupported interface type.
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
switch r.Interface().(type) {
case *string:
r.Set(reflect.ValueOf(&node.Text))
return nil
case []byte:
b, err := base64.StdEncoding.DecodeString(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(b))
case *bool:
v, err := strconv.ParseBool(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *int64:
v, err := strconv.ParseInt(node.Text, 10, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *float64:
v, err := strconv.ParseFloat(node.Text, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
t, err := time.Parse(ISO8601UTC, node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&t))
default:
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
}
return nil
}

View File

@@ -0,0 +1,105 @@
package xmlutil
import (
"encoding/xml"
"io"
"sort"
)
// A XMLNode contains the values to be encoded or decoded.
type XMLNode struct {
Name xml.Name `json:",omitempty"`
Children map[string][]*XMLNode `json:",omitempty"`
Text string `json:",omitempty"`
Attr []xml.Attr `json:",omitempty"`
}
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
func NewXMLElement(name xml.Name) *XMLNode {
return &XMLNode{
Name: name,
Children: map[string][]*XMLNode{},
Attr: []xml.Attr{},
}
}
// AddChild adds child to the XMLNode.
func (n *XMLNode) AddChild(child *XMLNode) {
if _, ok := n.Children[child.Name.Local]; !ok {
n.Children[child.Name.Local] = []*XMLNode{}
}
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child)
}
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
out := &XMLNode{}
for {
tok, err := d.Token()
if tok == nil || err == io.EOF {
break
}
if err != nil {
return out, err
}
switch typed := tok.(type) {
case xml.CharData:
out.Text = string(typed.Copy())
case xml.StartElement:
el := typed.Copy()
out.Attr = el.Attr
if out.Children == nil {
out.Children = map[string][]*XMLNode{}
}
name := typed.Name.Local
slice := out.Children[name]
if slice == nil {
slice = []*XMLNode{}
}
node, e := XMLToStruct(d, &el)
if e != nil {
return out, e
}
node.Name = typed.Name
slice = append(slice, node)
out.Children[name] = slice
case xml.EndElement:
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
return out, nil
}
}
}
return out, nil
}
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
if node.Text != "" {
e.EncodeToken(xml.CharData([]byte(node.Text)))
} else if sorted {
sortedNames := []string{}
for k := range node.Children {
sortedNames = append(sortedNames, k)
}
sort.Strings(sortedNames)
for _, k := range sortedNames {
for _, v := range node.Children[k] {
StructToXML(e, v, sorted)
}
}
} else {
for _, c := range node.Children {
for _, v := range c {
StructToXML(e, v, sorted)
}
}
}
e.EncodeToken(xml.EndElement{Name: node.Name})
return e.Flush()
}

View File

@@ -0,0 +1,42 @@
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestPresignHandler(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View File

@@ -0,0 +1,365 @@
// Package v4 implements signing for AWS V4 signer
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
const (
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
var ignoredHeaders = map[string]bool{
"Authorization": true,
"Content-Type": true,
"Content-Length": true,
"User-Agent": true,
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
CredValues credentials.Value
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug aws.LogLevelType
Logger aws.Logger
isPresign bool
formattedTime string
formattedShortTime string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign requests with signature version 4.
//
// Will sign the requests with the service config's Credentials object
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
// object.
func Sign(req *request.Request) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials {
return
}
region := req.ClientInfo.SigningRegion
if region == "" {
region = aws.StringValue(req.Config.Region)
}
name := req.ClientInfo.SigningName
if name == "" {
name = req.ClientInfo.ServiceName
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
Credentials: req.Config.Credentials,
Debug: req.Config.LogLevel.Value(),
Logger: req.Config.Logger,
}
req.Error = s.sign()
}
func (v4 *signer) sign() error {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isRequestSigned() {
if !v4.Credentials.IsExpired() {
// If the request is already signed, and the credentials have not
// expired yet ignore the signing request.
return nil
}
// The credentials have expired for this request. The current signing
// is invalid, and needs to be request because the request will fail.
if v4.isPresign {
v4.removePresign()
// Update the request's query string to ensure the values stays in
// sync in the case retrieving the new credentials fails.
v4.Request.URL.RawQuery = v4.Query.Encode()
}
}
var err error
v4.CredValues, err = v4.Credentials.Get()
if err != nil {
return err
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.CredValues.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.CredValues.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
}
v4.build()
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo()
}
return nil
}
const logSignInfoMsg = `DEBUG: Request Signiture:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *signer) logSigningInfo() {
signedURLMsg := ""
if v4.isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, v4.Request.URL.String())
}
msg := fmt.Sprintf(logSignInfoMsg, v4.canonicalString, v4.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
v4.buildQuery() // no depends
}
v4.buildCanonicalHeaders() // depends on cred string
v4.buildCanonicalString() // depends on canon headers / signed headers
v4.buildStringToSign() // depends on canon string
v4.buildSignature() // depends on string to sign
if v4.isPresign {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.CredValues.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
v4.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
}
func (v4 *signer) buildTime() {
v4.formattedTime = v4.Time.UTC().Format(timeFormat)
v4.formattedShortTime = v4.Time.UTC().Format(shortTimeFormat)
if v4.isPresign {
duration := int64(v4.ExpireTime / time.Second)
v4.Query.Set("X-Amz-Date", v4.formattedTime)
v4.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
v4.Request.Header.Set("X-Amz-Date", v4.formattedTime)
}
}
func (v4 *signer) buildCredentialString() {
v4.credentialString = strings.Join([]string{
v4.formattedShortTime,
v4.Region,
v4.ServiceName,
"aws4_request",
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.CredValues.AccessKeyID+"/"+v4.credentialString)
}
}
func (v4 *signer) buildQuery() {
for k, h := range v4.Request.Header {
if strings.HasPrefix(http.CanonicalHeaderKey(k), "X-Amz-") {
continue // never hoist x-amz-* headers, they must be signed
}
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // never hoist ignored headers
}
v4.Request.Header.Del(k)
v4.Query.Del(k)
for _, v := range h {
v4.Query.Add(k, v)
}
}
}
func (v4 *signer) buildCanonicalHeaders() {
var headers []string
headers = append(headers, "host")
for k := range v4.Request.Header {
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // ignored header
}
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
v4.signedHeaders = strings.Join(headers, ";")
if v4.isPresign {
v4.Query.Set("X-Amz-SignedHeaders", v4.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
headerValues[i] = "host:" + v4.Request.URL.Host
} else {
headerValues[i] = k + ":" +
strings.Join(v4.Request.Header[http.CanonicalHeaderKey(k)], ",")
}
}
v4.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (v4 *signer) buildCanonicalString() {
v4.Request.URL.RawQuery = strings.Replace(v4.Query.Encode(), "+", "%20", -1)
uri := v4.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = v4.Request.URL.Path
}
if uri == "" {
uri = "/"
}
if v4.ServiceName != "s3" {
uri = rest.EscapePath(uri, false)
}
v4.canonicalString = strings.Join([]string{
v4.Request.Method,
uri,
v4.Request.URL.RawQuery,
v4.canonicalHeaders + "\n",
v4.signedHeaders,
v4.bodyDigest(),
}, "\n")
}
func (v4 *signer) buildStringToSign() {
v4.stringToSign = strings.Join([]string{
authHeaderPrefix,
v4.formattedTime,
v4.credentialString,
hex.EncodeToString(makeSha256([]byte(v4.canonicalString))),
}, "\n")
}
func (v4 *signer) buildSignature() {
secret := v4.CredValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(v4.stringToSign))
v4.signature = hex.EncodeToString(signature)
}
func (v4 *signer) bodyDigest() string {
hash := v4.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if v4.isPresign && v4.ServiceName == "s3" {
hash = "UNSIGNED-PAYLOAD"
} else if v4.Body == nil {
hash = hex.EncodeToString(makeSha256([]byte{}))
} else {
hash = hex.EncodeToString(makeSha256Reader(v4.Body))
}
v4.Request.Header.Add("X-Amz-Content-Sha256", hash)
}
return hash
}
// isRequestSigned returns if the request is currently signed or presigned
func (v4 *signer) isRequestSigned() bool {
if v4.isPresign && v4.Query.Get("X-Amz-Signature") != "" {
return true
}
if v4.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// unsign removes signing flags for both signed and presigned requests.
func (v4 *signer) removePresign() {
v4.Query.Del("X-Amz-Algorithm")
v4.Query.Del("X-Amz-Signature")
v4.Query.Del("X-Amz-Security-Token")
v4.Query.Del("X-Amz-Date")
v4.Query.Del("X-Amz-Expires")
v4.Query.Del("X-Amz-Credential")
v4.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
io.Copy(hash, reader)
return hash.Sum(nil)
}

View File

@@ -0,0 +1,252 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "5eeedebf6f995145ce56daa02902d10485246d3defb34f97b973c1f40ab82d36"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=69ada33fec48180dab153576e4dd80c4e04124f80dda3eccfed8a67c2b91ed5e"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{Credentials: credentials.AnonymousCredentials})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
svc := awstesting.NewClient(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{Value: credentials.Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
}}
creds := credentials.NewCredentials(provider)
svc := awstesting.NewClient(&aws.Config{Credentials: creds})
r := svc.NewRequest(
&request.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

View File

@@ -0,0 +1,103 @@
package waiter
import (
"fmt"
"reflect"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
// A Config provides a collection of configuration values to setup a generated
// waiter code with.
type Config struct {
Name string
Delay int
MaxAttempts int
Operation string
Acceptors []WaitAcceptor
}
// A WaitAcceptor provides the information needed to wait for an API operation
// to complete.
type WaitAcceptor struct {
Expected interface{}
Matcher string
State string
Argument string
}
// A Waiter provides waiting for an operation to complete.
type Waiter struct {
Config
Client interface{}
Input interface{}
}
// Wait waits for an operation to complete, expire max attempts, or fail. Error
// is returned if the operation fails.
func (w *Waiter) Wait() error {
client := reflect.ValueOf(w.Client)
in := reflect.ValueOf(w.Input)
method := client.MethodByName(w.Config.Operation + "Request")
for i := 0; i < w.MaxAttempts; i++ {
res := method.Call([]reflect.Value{in})
req := res[0].Interface().(*request.Request)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Waiter"))
if err := req.Send(); err != nil {
return err
}
for _, a := range w.Acceptors {
result := false
switch a.Matcher {
case "pathAll":
if vals, _ := awsutil.ValuesAtPath(req.Data, a.Argument); req.Error == nil && vals != nil {
result = true
for _, val := range vals {
if !awsutil.DeepEqual(val, a.Expected) {
result = false
break
}
}
}
case "pathAny":
if vals, _ := awsutil.ValuesAtPath(req.Data, a.Argument); req.Error == nil && vals != nil {
for _, val := range vals {
if awsutil.DeepEqual(val, a.Expected) {
result = true
break
}
}
}
case "status":
s := a.Expected.(int)
result = s == req.HTTPResponse.StatusCode
}
if result {
switch a.State {
case "success":
return nil // waiter completed
case "failure":
if req.Error == nil {
return awserr.New("ResourceNotReady",
fmt.Sprintf("failed waiting for successful resource state"), nil)
}
return req.Error // waiter failed
case "retry":
// do nothing, just retry
}
break
}
}
time.Sleep(time.Second * time.Duration(w.Delay))
}
return awserr.New("ResourceNotReady",
fmt.Sprintf("exceeded %d wait attempts", w.MaxAttempts), nil)
}

View File

@@ -0,0 +1,183 @@
package waiter_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/private/waiter"
)
type mockClient struct {
*client.Client
}
type MockInput struct{}
type MockOutput struct {
States []*MockState
}
type MockState struct {
State *string
}
func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) {
op := &request.Operation{
Name: "Mock",
HTTPMethod: "POST",
HTTPPath: "/",
}
if input == nil {
input = &MockInput{}
}
output := &MockOutput{}
req := c.NewRequest(op, input, output)
req.Data = output
return req, output
}
var mockAcceptors = []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "States[].State",
Expected: "running",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "States[].State",
Expected: "stopping",
},
}
func TestWaiter(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 1
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 1
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
waiterCfg := waiter.Config{
Operation: "Mock",
Delay: 0,
MaxAttempts: 10,
Acceptors: mockAcceptors,
}
w := waiter.Waiter{
Client: svc,
Input: &MockInput{},
Config: waiterCfg,
}
err := w.Wait()
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterFailure(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 1
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 1
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("stopping")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
waiterCfg := waiter.Config{
Operation: "Mock",
Delay: 0,
MaxAttempts: 10,
Acceptors: mockAcceptors,
}
w := waiter.Waiter{
Client: svc,
Input: &MockInput{},
Config: waiterCfg,
}
err := w.Wait().(awserr.Error)
assert.Error(t, err)
assert.Equal(t, "ResourceNotReady", err.Code())
assert.Equal(t, "failed waiting for successful resource state", err.Message())
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}

View File

@@ -6,15 +6,15 @@ package cloudwatch
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
const opDeleteAlarms = "DeleteAlarms"
// DeleteAlarmsRequest generates a request for the DeleteAlarms operation.
func (c *CloudWatch) DeleteAlarmsRequest(input *DeleteAlarmsInput) (req *aws.Request, output *DeleteAlarmsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DeleteAlarmsRequest(input *DeleteAlarmsInput) (req *request.Request, output *DeleteAlarmsOutput) {
op := &request.Operation{
Name: opDeleteAlarms,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -40,12 +40,12 @@ func (c *CloudWatch) DeleteAlarms(input *DeleteAlarmsInput) (*DeleteAlarmsOutput
const opDescribeAlarmHistory = "DescribeAlarmHistory"
// DescribeAlarmHistoryRequest generates a request for the DescribeAlarmHistory operation.
func (c *CloudWatch) DescribeAlarmHistoryRequest(input *DescribeAlarmHistoryInput) (req *aws.Request, output *DescribeAlarmHistoryOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmHistoryRequest(input *DescribeAlarmHistoryInput) (req *request.Request, output *DescribeAlarmHistoryOutput) {
op := &request.Operation{
Name: opDescribeAlarmHistory,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "MaxRecords",
@@ -74,6 +74,7 @@ func (c *CloudWatch) DescribeAlarmHistory(input *DescribeAlarmHistoryInput) (*De
func (c *CloudWatch) DescribeAlarmHistoryPages(input *DescribeAlarmHistoryInput, fn func(p *DescribeAlarmHistoryOutput, lastPage bool) (shouldContinue bool)) error {
page, _ := c.DescribeAlarmHistoryRequest(input)
page.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Paginator"))
return page.EachPage(func(p interface{}, lastPage bool) bool {
return fn(p.(*DescribeAlarmHistoryOutput), lastPage)
})
@@ -82,12 +83,12 @@ func (c *CloudWatch) DescribeAlarmHistoryPages(input *DescribeAlarmHistoryInput,
const opDescribeAlarms = "DescribeAlarms"
// DescribeAlarmsRequest generates a request for the DescribeAlarms operation.
func (c *CloudWatch) DescribeAlarmsRequest(input *DescribeAlarmsInput) (req *aws.Request, output *DescribeAlarmsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmsRequest(input *DescribeAlarmsInput) (req *request.Request, output *DescribeAlarmsOutput) {
op := &request.Operation{
Name: opDescribeAlarms,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "MaxRecords",
@@ -116,6 +117,7 @@ func (c *CloudWatch) DescribeAlarms(input *DescribeAlarmsInput) (*DescribeAlarms
func (c *CloudWatch) DescribeAlarmsPages(input *DescribeAlarmsInput, fn func(p *DescribeAlarmsOutput, lastPage bool) (shouldContinue bool)) error {
page, _ := c.DescribeAlarmsRequest(input)
page.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Paginator"))
return page.EachPage(func(p interface{}, lastPage bool) bool {
return fn(p.(*DescribeAlarmsOutput), lastPage)
})
@@ -124,8 +126,8 @@ func (c *CloudWatch) DescribeAlarmsPages(input *DescribeAlarmsInput, fn func(p *
const opDescribeAlarmsForMetric = "DescribeAlarmsForMetric"
// DescribeAlarmsForMetricRequest generates a request for the DescribeAlarmsForMetric operation.
func (c *CloudWatch) DescribeAlarmsForMetricRequest(input *DescribeAlarmsForMetricInput) (req *aws.Request, output *DescribeAlarmsForMetricOutput) {
op := &aws.Operation{
func (c *CloudWatch) DescribeAlarmsForMetricRequest(input *DescribeAlarmsForMetricInput) (req *request.Request, output *DescribeAlarmsForMetricOutput) {
op := &request.Operation{
Name: opDescribeAlarmsForMetric,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -152,8 +154,8 @@ func (c *CloudWatch) DescribeAlarmsForMetric(input *DescribeAlarmsForMetricInput
const opDisableAlarmActions = "DisableAlarmActions"
// DisableAlarmActionsRequest generates a request for the DisableAlarmActions operation.
func (c *CloudWatch) DisableAlarmActionsRequest(input *DisableAlarmActionsInput) (req *aws.Request, output *DisableAlarmActionsOutput) {
op := &aws.Operation{
func (c *CloudWatch) DisableAlarmActionsRequest(input *DisableAlarmActionsInput) (req *request.Request, output *DisableAlarmActionsOutput) {
op := &request.Operation{
Name: opDisableAlarmActions,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -180,8 +182,8 @@ func (c *CloudWatch) DisableAlarmActions(input *DisableAlarmActionsInput) (*Disa
const opEnableAlarmActions = "EnableAlarmActions"
// EnableAlarmActionsRequest generates a request for the EnableAlarmActions operation.
func (c *CloudWatch) EnableAlarmActionsRequest(input *EnableAlarmActionsInput) (req *aws.Request, output *EnableAlarmActionsOutput) {
op := &aws.Operation{
func (c *CloudWatch) EnableAlarmActionsRequest(input *EnableAlarmActionsInput) (req *request.Request, output *EnableAlarmActionsOutput) {
op := &request.Operation{
Name: opEnableAlarmActions,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -207,8 +209,8 @@ func (c *CloudWatch) EnableAlarmActions(input *EnableAlarmActionsInput) (*Enable
const opGetMetricStatistics = "GetMetricStatistics"
// GetMetricStatisticsRequest generates a request for the GetMetricStatistics operation.
func (c *CloudWatch) GetMetricStatisticsRequest(input *GetMetricStatisticsInput) (req *aws.Request, output *GetMetricStatisticsOutput) {
op := &aws.Operation{
func (c *CloudWatch) GetMetricStatisticsRequest(input *GetMetricStatisticsInput) (req *request.Request, output *GetMetricStatisticsOutput) {
op := &request.Operation{
Name: opGetMetricStatistics,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -259,12 +261,12 @@ func (c *CloudWatch) GetMetricStatistics(input *GetMetricStatisticsInput) (*GetM
const opListMetrics = "ListMetrics"
// ListMetricsRequest generates a request for the ListMetrics operation.
func (c *CloudWatch) ListMetricsRequest(input *ListMetricsInput) (req *aws.Request, output *ListMetricsOutput) {
op := &aws.Operation{
func (c *CloudWatch) ListMetricsRequest(input *ListMetricsInput) (req *request.Request, output *ListMetricsOutput) {
op := &request.Operation{
Name: opListMetrics,
HTTPMethod: "POST",
HTTPPath: "/",
Paginator: &aws.Paginator{
Paginator: &request.Paginator{
InputTokens: []string{"NextToken"},
OutputTokens: []string{"NextToken"},
LimitToken: "",
@@ -293,6 +295,7 @@ func (c *CloudWatch) ListMetrics(input *ListMetricsInput) (*ListMetricsOutput, e
func (c *CloudWatch) ListMetricsPages(input *ListMetricsInput, fn func(p *ListMetricsOutput, lastPage bool) (shouldContinue bool)) error {
page, _ := c.ListMetricsRequest(input)
page.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Paginator"))
return page.EachPage(func(p interface{}, lastPage bool) bool {
return fn(p.(*ListMetricsOutput), lastPage)
})
@@ -301,8 +304,8 @@ func (c *CloudWatch) ListMetricsPages(input *ListMetricsInput, fn func(p *ListMe
const opPutMetricAlarm = "PutMetricAlarm"
// PutMetricAlarmRequest generates a request for the PutMetricAlarm operation.
func (c *CloudWatch) PutMetricAlarmRequest(input *PutMetricAlarmInput) (req *aws.Request, output *PutMetricAlarmOutput) {
op := &aws.Operation{
func (c *CloudWatch) PutMetricAlarmRequest(input *PutMetricAlarmInput) (req *request.Request, output *PutMetricAlarmOutput) {
op := &request.Operation{
Name: opPutMetricAlarm,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -334,8 +337,8 @@ func (c *CloudWatch) PutMetricAlarm(input *PutMetricAlarmInput) (*PutMetricAlarm
const opPutMetricData = "PutMetricData"
// PutMetricDataRequest generates a request for the PutMetricData operation.
func (c *CloudWatch) PutMetricDataRequest(input *PutMetricDataInput) (req *aws.Request, output *PutMetricDataOutput) {
op := &aws.Operation{
func (c *CloudWatch) PutMetricDataRequest(input *PutMetricDataInput) (req *request.Request, output *PutMetricDataOutput) {
op := &request.Operation{
Name: opPutMetricData,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -374,8 +377,8 @@ func (c *CloudWatch) PutMetricData(input *PutMetricDataInput) (*PutMetricDataOut
const opSetAlarmState = "SetAlarmState"
// SetAlarmStateRequest generates a request for the SetAlarmState operation.
func (c *CloudWatch) SetAlarmStateRequest(input *SetAlarmStateInput) (req *aws.Request, output *SetAlarmStateOutput) {
op := &aws.Operation{
func (c *CloudWatch) SetAlarmStateRequest(input *SetAlarmStateInput) (req *request.Request, output *SetAlarmStateOutput) {
op := &request.Operation{
Name: opSetAlarmState,
HTTPMethod: "POST",
HTTPPath: "/",
@@ -406,16 +409,16 @@ func (c *CloudWatch) SetAlarmState(input *SetAlarmStateInput) (*SetAlarmStateOut
// returns this data type as part of the DescribeAlarmHistoryResult data type.
type AlarmHistoryItem struct {
// The descriptive name for the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// Machine-readable data about the alarm in JSON format.
HistoryData *string `type:"string"`
HistoryData *string `min:"1" type:"string"`
// The type of alarm history item.
HistoryItemType *string `type:"string" enum:"HistoryItemType"`
// A human-readable summary of the alarm history.
HistorySummary *string `type:"string"`
HistorySummary *string `min:"1" type:"string"`
// The time stamp for the alarm history item. Amazon CloudWatch uses Coordinated
// Universal Time (UTC) when returning time stamps, which do not accommodate
@@ -528,7 +531,7 @@ func (s DeleteAlarmsOutput) GoString() string {
type DescribeAlarmHistoryInput struct {
// The name of the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// The ending date to retrieve alarm history.
EndDate *time.Time `type:"timestamp" timestampFormat:"iso8601"`
@@ -537,7 +540,7 @@ type DescribeAlarmHistoryInput struct {
HistoryItemType *string `type:"string" enum:"HistoryItemType"`
// The maximum number of alarm history records to retrieve.
MaxRecords *int64 `type:"integer"`
MaxRecords *int64 `min:"1" type:"integer"`
// The token returned by a previous call to indicate that there is more data
// available.
@@ -593,13 +596,13 @@ type DescribeAlarmsForMetricInput struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace of the metric.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The period in seconds over which the statistic is applied.
Period *int64 `type:"integer"`
Period *int64 `min:"60" type:"integer"`
// The statistic for the metric.
Statistic *string `type:"string" enum:"Statistic"`
@@ -648,17 +651,17 @@ func (s DescribeAlarmsForMetricOutput) GoString() string {
type DescribeAlarmsInput struct {
// The action name prefix.
ActionPrefix *string `type:"string"`
ActionPrefix *string `min:"1" type:"string"`
// The alarm name prefix. AlarmNames cannot be specified if this parameter is
// specified.
AlarmNamePrefix *string `type:"string"`
AlarmNamePrefix *string `min:"1" type:"string"`
// A list of alarm names to retrieve information for.
AlarmNames []*string `type:"list"`
// The maximum number of alarm descriptions to retrieve.
MaxRecords *int64 `type:"integer"`
MaxRecords *int64 `min:"1" type:"integer"`
// The token returned by a previous call to indicate that there is more data
// available.
@@ -715,10 +718,10 @@ func (s DescribeAlarmsOutput) GoString() string {
// For examples that use one or more dimensions, see PutMetricData.
type Dimension struct {
// The name of the dimension.
Name *string `type:"string" required:"true"`
Name *string `min:"1" type:"string" required:"true"`
// The value representing the dimension measurement
Value *string `type:"string" required:"true"`
Value *string `min:"1" type:"string" required:"true"`
metadataDimension `json:"-" xml:"-"`
}
@@ -740,10 +743,10 @@ func (s Dimension) GoString() string {
// The DimensionFilter data type is used to filter ListMetrics results.
type DimensionFilter struct {
// The dimension name to be matched.
Name *string `type:"string" required:"true"`
Name *string `min:"1" type:"string" required:"true"`
// The value of the dimension to be matched.
Value *string `type:"string"`
Value *string `min:"1" type:"string"`
metadataDimensionFilter `json:"-" xml:"-"`
}
@@ -850,14 +853,14 @@ type GetMetricStatisticsInput struct {
EndTime *time.Time `type:"timestamp" timestampFormat:"iso8601" required:"true"`
// The name of the metric, with or without spaces.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace of the metric, with or without spaces.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The granularity, in seconds, of the returned datapoints. Period must be at
// least 60 seconds and must be a multiple of 60. The default value is 60.
Period *int64 `type:"integer" required:"true"`
Period *int64 `min:"60" type:"integer" required:"true"`
// The time stamp to use for determining the first datapoint to return. The
// value specified is inclusive; results include datapoints with the time stamp
@@ -869,7 +872,7 @@ type GetMetricStatisticsInput struct {
// in the Amazon CloudWatch Developer Guide.
//
// Valid Values: Average | Sum | SampleCount | Maximum | Minimum
Statistics []*string `type:"list" required:"true"`
Statistics []*string `min:"1" type:"list" required:"true"`
// The unit for the metric.
Unit *string `type:"string" enum:"StandardUnit"`
@@ -921,10 +924,10 @@ type ListMetricsInput struct {
Dimensions []*DimensionFilter `type:"list"`
// The name of the metric to filter against.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace to filter against.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
// The token returned by a previous call to indicate that there is more data
// available.
@@ -984,10 +987,10 @@ type Metric struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace of the metric.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
metadataMetric `json:"-" xml:"-"`
}
@@ -1013,15 +1016,15 @@ type MetricAlarm struct {
// state.
ActionsEnabled *bool `type:"boolean"`
// The Amazon Resource Name (ARN) of the alarm.
AlarmARN *string `locationName:"AlarmArn" type:"string"`
// The list of actions to execute when this alarm transitions into an ALARM
// state from any other state. Each action is specified as an Amazon Resource
// Number (ARN). Currently the only actions supported are publishing to an Amazon
// SNS topic and triggering an Auto Scaling policy.
AlarmActions []*string `type:"list"`
// The Amazon Resource Name (ARN) of the alarm.
AlarmArn *string `min:"1" type:"string"`
// The time stamp of the last update to the alarm configuration. Amazon CloudWatch
// uses Coordinated Universal Time (UTC) when returning time stamps, which do
// not accommodate seasonal adjustments such as daylight savings time. For more
@@ -1033,7 +1036,7 @@ type MetricAlarm struct {
AlarmDescription *string `type:"string"`
// The name of the alarm.
AlarmName *string `type:"string"`
AlarmName *string `min:"1" type:"string"`
// The arithmetic operation to use when comparing the specified Statistic and
// Threshold. The specified Statistic value is used as the first operand.
@@ -1043,7 +1046,7 @@ type MetricAlarm struct {
Dimensions []*Dimension `type:"list"`
// The number of periods over which data is compared to the specified threshold.
EvaluationPeriods *int64 `type:"integer"`
EvaluationPeriods *int64 `min:"1" type:"integer"`
// The list of actions to execute when this alarm transitions into an INSUFFICIENT_DATA
// state from any other state. Each action is specified as an Amazon Resource
@@ -1054,10 +1057,10 @@ type MetricAlarm struct {
InsufficientDataActions []*string `type:"list"`
// The name of the alarm's metric.
MetricName *string `type:"string"`
MetricName *string `min:"1" type:"string"`
// The namespace of alarm's associated metric.
Namespace *string `type:"string"`
Namespace *string `min:"1" type:"string"`
// The list of actions to execute when this alarm transitions into an OK state
// from any other state. Each action is specified as an Amazon Resource Number
@@ -1066,7 +1069,7 @@ type MetricAlarm struct {
OKActions []*string `type:"list"`
// The period in seconds over which the statistic is applied.
Period *int64 `type:"integer"`
Period *int64 `min:"60" type:"integer"`
// A human-readable explanation for the alarm's state.
StateReason *string `type:"string"`
@@ -1119,7 +1122,7 @@ type MetricDatum struct {
Dimensions []*Dimension `type:"list"`
// The name of the metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// A set of statistical values describing the metric.
StatisticValues *StatisticSet `type:"structure"`
@@ -1176,7 +1179,7 @@ type PutMetricAlarmInput struct {
// The descriptive name for the alarm. This name must be unique within the user's
// AWS account
AlarmName *string `type:"string" required:"true"`
AlarmName *string `min:"1" type:"string" required:"true"`
// The arithmetic operation to use when comparing the specified Statistic and
// Threshold. The specified Statistic value is used as the first operand.
@@ -1186,7 +1189,7 @@ type PutMetricAlarmInput struct {
Dimensions []*Dimension `type:"list"`
// The number of periods over which data is compared to the specified threshold.
EvaluationPeriods *int64 `type:"integer" required:"true"`
EvaluationPeriods *int64 `min:"1" type:"integer" required:"true"`
// The list of actions to execute when this alarm transitions into an INSUFFICIENT_DATA
// state from any other state. Each action is specified as an Amazon Resource
@@ -1195,10 +1198,10 @@ type PutMetricAlarmInput struct {
InsufficientDataActions []*string `type:"list"`
// The name for the alarm's associated metric.
MetricName *string `type:"string" required:"true"`
MetricName *string `min:"1" type:"string" required:"true"`
// The namespace for the alarm's associated metric.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
// The list of actions to execute when this alarm transitions into an OK state
// from any other state. Each action is specified as an Amazon Resource Number
@@ -1207,7 +1210,7 @@ type PutMetricAlarmInput struct {
OKActions []*string `type:"list"`
// The period in seconds over which the specified statistic is applied.
Period *int64 `type:"integer" required:"true"`
Period *int64 `min:"60" type:"integer" required:"true"`
// The statistic to apply to the alarm's associated metric.
Statistic *string `type:"string" required:"true" enum:"Statistic"`
@@ -1258,7 +1261,7 @@ type PutMetricDataInput struct {
MetricData []*MetricDatum `type:"list" required:"true"`
// The namespace for the metric data.
Namespace *string `type:"string" required:"true"`
Namespace *string `min:"1" type:"string" required:"true"`
metadataPutMetricDataInput `json:"-" xml:"-"`
}
@@ -1298,7 +1301,7 @@ func (s PutMetricDataOutput) GoString() string {
type SetAlarmStateInput struct {
// The descriptive name for the alarm. This name must be unique within the user's
// AWS account. The maximum length is 255 characters.
AlarmName *string `type:"string" required:"true"`
AlarmName *string `min:"1" type:"string" required:"true"`
// The reason that this alarm is set to this specific state (in human-readable
// text format)

View File

@@ -4,59 +4,61 @@
package cloudwatchiface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
// CloudWatchAPI is the interface type for cloudwatch.CloudWatch.
type CloudWatchAPI interface {
DeleteAlarmsRequest(*cloudwatch.DeleteAlarmsInput) (*aws.Request, *cloudwatch.DeleteAlarmsOutput)
DeleteAlarmsRequest(*cloudwatch.DeleteAlarmsInput) (*request.Request, *cloudwatch.DeleteAlarmsOutput)
DeleteAlarms(*cloudwatch.DeleteAlarmsInput) (*cloudwatch.DeleteAlarmsOutput, error)
DescribeAlarmHistoryRequest(*cloudwatch.DescribeAlarmHistoryInput) (*aws.Request, *cloudwatch.DescribeAlarmHistoryOutput)
DescribeAlarmHistoryRequest(*cloudwatch.DescribeAlarmHistoryInput) (*request.Request, *cloudwatch.DescribeAlarmHistoryOutput)
DescribeAlarmHistory(*cloudwatch.DescribeAlarmHistoryInput) (*cloudwatch.DescribeAlarmHistoryOutput, error)
DescribeAlarmHistoryPages(*cloudwatch.DescribeAlarmHistoryInput, func(*cloudwatch.DescribeAlarmHistoryOutput, bool) bool) error
DescribeAlarmsRequest(*cloudwatch.DescribeAlarmsInput) (*aws.Request, *cloudwatch.DescribeAlarmsOutput)
DescribeAlarmsRequest(*cloudwatch.DescribeAlarmsInput) (*request.Request, *cloudwatch.DescribeAlarmsOutput)
DescribeAlarms(*cloudwatch.DescribeAlarmsInput) (*cloudwatch.DescribeAlarmsOutput, error)
DescribeAlarmsPages(*cloudwatch.DescribeAlarmsInput, func(*cloudwatch.DescribeAlarmsOutput, bool) bool) error
DescribeAlarmsForMetricRequest(*cloudwatch.DescribeAlarmsForMetricInput) (*aws.Request, *cloudwatch.DescribeAlarmsForMetricOutput)
DescribeAlarmsForMetricRequest(*cloudwatch.DescribeAlarmsForMetricInput) (*request.Request, *cloudwatch.DescribeAlarmsForMetricOutput)
DescribeAlarmsForMetric(*cloudwatch.DescribeAlarmsForMetricInput) (*cloudwatch.DescribeAlarmsForMetricOutput, error)
DisableAlarmActionsRequest(*cloudwatch.DisableAlarmActionsInput) (*aws.Request, *cloudwatch.DisableAlarmActionsOutput)
DisableAlarmActionsRequest(*cloudwatch.DisableAlarmActionsInput) (*request.Request, *cloudwatch.DisableAlarmActionsOutput)
DisableAlarmActions(*cloudwatch.DisableAlarmActionsInput) (*cloudwatch.DisableAlarmActionsOutput, error)
EnableAlarmActionsRequest(*cloudwatch.EnableAlarmActionsInput) (*aws.Request, *cloudwatch.EnableAlarmActionsOutput)
EnableAlarmActionsRequest(*cloudwatch.EnableAlarmActionsInput) (*request.Request, *cloudwatch.EnableAlarmActionsOutput)
EnableAlarmActions(*cloudwatch.EnableAlarmActionsInput) (*cloudwatch.EnableAlarmActionsOutput, error)
GetMetricStatisticsRequest(*cloudwatch.GetMetricStatisticsInput) (*aws.Request, *cloudwatch.GetMetricStatisticsOutput)
GetMetricStatisticsRequest(*cloudwatch.GetMetricStatisticsInput) (*request.Request, *cloudwatch.GetMetricStatisticsOutput)
GetMetricStatistics(*cloudwatch.GetMetricStatisticsInput) (*cloudwatch.GetMetricStatisticsOutput, error)
ListMetricsRequest(*cloudwatch.ListMetricsInput) (*aws.Request, *cloudwatch.ListMetricsOutput)
ListMetricsRequest(*cloudwatch.ListMetricsInput) (*request.Request, *cloudwatch.ListMetricsOutput)
ListMetrics(*cloudwatch.ListMetricsInput) (*cloudwatch.ListMetricsOutput, error)
ListMetricsPages(*cloudwatch.ListMetricsInput, func(*cloudwatch.ListMetricsOutput, bool) bool) error
PutMetricAlarmRequest(*cloudwatch.PutMetricAlarmInput) (*aws.Request, *cloudwatch.PutMetricAlarmOutput)
PutMetricAlarmRequest(*cloudwatch.PutMetricAlarmInput) (*request.Request, *cloudwatch.PutMetricAlarmOutput)
PutMetricAlarm(*cloudwatch.PutMetricAlarmInput) (*cloudwatch.PutMetricAlarmOutput, error)
PutMetricDataRequest(*cloudwatch.PutMetricDataInput) (*aws.Request, *cloudwatch.PutMetricDataOutput)
PutMetricDataRequest(*cloudwatch.PutMetricDataInput) (*request.Request, *cloudwatch.PutMetricDataOutput)
PutMetricData(*cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error)
SetAlarmStateRequest(*cloudwatch.SetAlarmStateInput) (*aws.Request, *cloudwatch.SetAlarmStateOutput)
SetAlarmStateRequest(*cloudwatch.SetAlarmStateInput) (*request.Request, *cloudwatch.SetAlarmStateOutput)
SetAlarmState(*cloudwatch.SetAlarmStateInput) (*cloudwatch.SetAlarmStateOutput, error)
}
var _ CloudWatchAPI = (*cloudwatch.CloudWatch)(nil)

View File

@@ -1,15 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package cloudwatchiface_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
"github.com/stretchr/testify/assert"
)
func TestInterface(t *testing.T) {
assert.Implements(t, (*cloudwatchiface.CloudWatchAPI)(nil), cloudwatch.New(nil))
}

View File

@@ -8,8 +8,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
@@ -17,7 +16,7 @@ var _ time.Duration
var _ bytes.Buffer
func ExampleCloudWatch_DeleteAlarms() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.DeleteAlarmsInput{
AlarmNames: []*string{ // Required
@@ -28,26 +27,18 @@ func ExampleCloudWatch_DeleteAlarms() {
resp, err := svc.DeleteAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarmHistory() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.DescribeAlarmHistoryInput{
AlarmName: aws.String("AlarmName"),
@@ -60,26 +51,18 @@ func ExampleCloudWatch_DescribeAlarmHistory() {
resp, err := svc.DescribeAlarmHistory(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarms() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.DescribeAlarmsInput{
ActionPrefix: aws.String("ActionPrefix"),
@@ -95,26 +78,18 @@ func ExampleCloudWatch_DescribeAlarms() {
resp, err := svc.DescribeAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DescribeAlarmsForMetric() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.DescribeAlarmsForMetricInput{
MetricName: aws.String("MetricName"), // Required
@@ -133,26 +108,18 @@ func ExampleCloudWatch_DescribeAlarmsForMetric() {
resp, err := svc.DescribeAlarmsForMetric(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_DisableAlarmActions() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.DisableAlarmActionsInput{
AlarmNames: []*string{ // Required
@@ -163,26 +130,18 @@ func ExampleCloudWatch_DisableAlarmActions() {
resp, err := svc.DisableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_EnableAlarmActions() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.EnableAlarmActionsInput{
AlarmNames: []*string{ // Required
@@ -193,26 +152,18 @@ func ExampleCloudWatch_EnableAlarmActions() {
resp, err := svc.EnableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_GetMetricStatistics() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.GetMetricStatisticsInput{
EndTime: aws.Time(time.Now()), // Required
@@ -236,26 +187,18 @@ func ExampleCloudWatch_GetMetricStatistics() {
resp, err := svc.GetMetricStatistics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_ListMetrics() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.ListMetricsInput{
Dimensions: []*cloudwatch.DimensionFilter{
@@ -272,26 +215,18 @@ func ExampleCloudWatch_ListMetrics() {
resp, err := svc.ListMetrics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_PutMetricAlarm() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.PutMetricAlarmInput{
AlarmName: aws.String("AlarmName"), // Required
@@ -328,26 +263,18 @@ func ExampleCloudWatch_PutMetricAlarm() {
resp, err := svc.PutMetricAlarm(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_PutMetricData() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.PutMetricDataInput{
MetricData: []*cloudwatch.MetricDatum{ // Required
@@ -377,26 +304,18 @@ func ExampleCloudWatch_PutMetricData() {
resp, err := svc.PutMetricData(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}
func ExampleCloudWatch_SetAlarmState() {
svc := cloudwatch.New(nil)
svc := cloudwatch.New(session.New())
params := &cloudwatch.SetAlarmStateInput{
AlarmName: aws.String("AlarmName"), // Required
@@ -407,20 +326,12 @@ func ExampleCloudWatch_SetAlarmState() {
resp, err := svc.SetAlarmState(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
fmt.Println(err.Error())
return
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
fmt.Println(resp)
}

View File

@@ -4,8 +4,11 @@ package cloudwatch
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query"
"github.com/aws/aws-sdk-go/private/signer/v4"
)
// This is the Amazon CloudWatch API Reference. This guide provides detailed
@@ -48,44 +51,70 @@ import (
// Center (http://aws.amazon.com/php/) AWS Python Developer Center (http://aws.amazon.com/python/)
// AWS Ruby Developer Center (http://aws.amazon.com/ruby/) AWS Windows and .NET
// Developer Center (http://aws.amazon.com/net/)
//The service client's operations are safe to be used concurrently.
// It is not safe to mutate any of the client's properties though.
type CloudWatch struct {
*aws.Service
*client.Client
}
// Used for custom service initialization logic
var initService func(*aws.Service)
// Used for custom client initialization logic
var initClient func(*client.Client)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
var initRequest func(*request.Request)
// New returns a new CloudWatch client.
func New(config *aws.Config) *CloudWatch {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "monitoring",
APIVersion: "2010-08-01",
// A ServiceName is the name of the service the client will make API calls to.
const ServiceName = "monitoring"
// New creates a new instance of the CloudWatch client with a session.
// If additional configuration is needed for the client instance use the optional
// aws.Config parameter to add your extra config.
//
// Example:
// // Create a CloudWatch client from just a session.
// svc := cloudwatch.New(mySession)
//
// // Create a CloudWatch client with additional configuration
// svc := cloudwatch.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *CloudWatch {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *CloudWatch {
svc := &CloudWatch{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2010-08-01",
},
handlers,
),
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(query.Build)
service.Handlers.Unmarshal.PushBack(query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(query.UnmarshalError)
svc.Handlers.Sign.PushBack(v4.Sign)
svc.Handlers.Build.PushBack(query.Build)
svc.Handlers.Unmarshal.PushBack(query.Unmarshal)
svc.Handlers.UnmarshalMeta.PushBack(query.UnmarshalMeta)
svc.Handlers.UnmarshalError.PushBack(query.UnmarshalError)
// Run custom service initialization if present
if initService != nil {
initService(service)
// Run custom client initialization if present
if initClient != nil {
initClient(svc.Client)
}
return &CloudWatch{service}
return svc
}
// newRequest creates a new request for a CloudWatch operation and runs any
// custom request initialization.
func (c *CloudWatch) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
func (c *CloudWatch) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {

File diff suppressed because it is too large Load Diff

View File

@@ -5,53 +5,51 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
func init() {
initRequest = func(r *aws.Request) {
initRequest = func(r *request.Request) {
if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter
r.Handlers.Build.PushFront(fillPresignedURL)
}
}
}
func fillPresignedURL(r *aws.Request) {
func fillPresignedURL(r *request.Request) {
if !r.ParamsFilled() {
return
}
params := r.Params.(*CopySnapshotInput)
origParams := r.Params.(*CopySnapshotInput)
// Stop if PresignedURL/DestinationRegion is set
if params.PresignedURL != nil || params.DestinationRegion != nil {
if origParams.PresignedUrl != nil || origParams.DestinationRegion != nil {
return
}
// First generate a copy of parameters
r.Params = awsutil.CopyOf(r.Params)
params = r.Params.(*CopySnapshotInput)
origParams.DestinationRegion = r.Config.Region
newParams := awsutil.CopyOf(r.Params).(*CopySnapshotInput)
// Set destination region. Avoids infinite handler loop.
// Also needed to sign sub-request.
params.DestinationRegion = r.Service.Config.Region
// Create a new client pointing at source region.
// We will use this to presign the CopySnapshot request against
// the source region
config := r.Service.Config.Copy().
// Create a new request based on the existing request. We will use this to
// presign the CopySnapshot request against the source region.
cfg := r.Config.Copy(aws.NewConfig().
WithEndpoint("").
WithRegion(*params.SourceRegion)
WithRegion(aws.StringValue(origParams.SourceRegion)))
client := New(config)
clientInfo := r.ClientInfo
clientInfo.Endpoint, clientInfo.SigningRegion = endpoints.EndpointForRegion(
clientInfo.ServiceName, aws.StringValue(cfg.Region), aws.BoolValue(cfg.DisableSSL))
// Presign a CopySnapshot request with modified params
req, _ := client.CopySnapshotRequest(params)
url, err := req.Presign(300 * time.Second) // 5 minutes should be enough.
if err != nil { // bubble error back up to original request
req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data)
url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough.
if err != nil { // bubble error back up to original request
r.Error = err
return
}
// We have our URL, set it on params
params.PresignedURL = &url
origParams.PresignedUrl = &url
}

View File

@@ -6,15 +6,13 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestCopySnapshotPresignedURL(t *testing.T) {
svc := ec2.New(&aws.Config{Region: aws.String("us-west-2")})
svc := ec2.New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
assert.NotPanics(t, func() {
// Doesn't panic on nil input
@@ -24,13 +22,14 @@ func TestCopySnapshotPresignedURL(t *testing.T) {
req, _ := svc.CopySnapshotRequest(&ec2.CopySnapshotInput{
SourceRegion: aws.String("us-west-1"),
SourceSnapshotID: aws.String("snap-id"),
SourceSnapshotId: aws.String("snap-id"),
})
req.Sign()
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
q, _ := url.ParseQuery(string(b))
url, _ := url.QueryUnescape(q.Get("PresignedUrl"))
u, _ := url.QueryUnescape(q.Get("PresignedUrl"))
assert.Equal(t, "us-west-2", q.Get("DestinationRegion"))
assert.Regexp(t, `^https://ec2\.us-west-1\.amazon.+&DestinationRegion=us-west-2`, url)
assert.Equal(t, "us-west-1", q.Get("SourceRegion"))
assert.Regexp(t, `^https://ec2\.us-west-1\.amazonaws\.com/.+&DestinationRegion=us-west-2`, u)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package ec2iface_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/stretchr/testify/assert"
)
func TestInterface(t *testing.T) {
assert.Implements(t, (*ec2iface.EC2API)(nil), ec2.New(nil))
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,52 +4,81 @@ package ec2
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/ec2query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/ec2query"
"github.com/aws/aws-sdk-go/private/signer/v4"
)
// Amazon Elastic Compute Cloud (Amazon EC2) provides resizable computing capacity
// in the Amazon Web Services (AWS) cloud. Using Amazon EC2 eliminates your
// need to invest in hardware up front, so you can develop and deploy applications
// faster.
//The service client's operations are safe to be used concurrently.
// It is not safe to mutate any of the client's properties though.
type EC2 struct {
*aws.Service
*client.Client
}
// Used for custom service initialization logic
var initService func(*aws.Service)
// Used for custom client initialization logic
var initClient func(*client.Client)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
var initRequest func(*request.Request)
// New returns a new EC2 client.
func New(config *aws.Config) *EC2 {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "ec2",
APIVersion: "2015-04-15",
// A ServiceName is the name of the service the client will make API calls to.
const ServiceName = "ec2"
// New creates a new instance of the EC2 client with a session.
// If additional configuration is needed for the client instance use the optional
// aws.Config parameter to add your extra config.
//
// Example:
// // Create a EC2 client from just a session.
// svc := ec2.New(mySession)
//
// // Create a EC2 client with additional configuration
// svc := ec2.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2 {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *EC2 {
svc := &EC2{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2015-10-01",
},
handlers,
),
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(ec2query.Build)
service.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
svc.Handlers.Sign.PushBack(v4.Sign)
svc.Handlers.Build.PushBack(ec2query.Build)
svc.Handlers.Unmarshal.PushBack(ec2query.Unmarshal)
svc.Handlers.UnmarshalMeta.PushBack(ec2query.UnmarshalMeta)
svc.Handlers.UnmarshalError.PushBack(ec2query.UnmarshalError)
// Run custom service initialization if present
if initService != nil {
initService(service)
// Run custom client initialization if present
if initClient != nil {
initClient(svc.Client)
}
return &EC2{service}
return svc
}
// newRequest creates a new request for a EC2 operation and runs any
// custom request initialization.
func (c *EC2) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
func (c *EC2) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {

View File

@@ -0,0 +1,761 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package ec2
import (
"github.com/aws/aws-sdk-go/private/waiter"
)
func (c *EC2) WaitUntilBundleTaskComplete(input *DescribeBundleTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeBundleTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "BundleTasks[].State",
Expected: "complete",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "BundleTasks[].State",
Expected: "failed",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilConversionTaskCancelled(input *DescribeConversionTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeConversionTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "ConversionTasks[].State",
Expected: "cancelled",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilConversionTaskCompleted(input *DescribeConversionTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeConversionTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "ConversionTasks[].State",
Expected: "completed",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "ConversionTasks[].State",
Expected: "cancelled",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "ConversionTasks[].State",
Expected: "cancelling",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilConversionTaskDeleted(input *DescribeConversionTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeConversionTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "ConversionTasks[].State",
Expected: "deleted",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilCustomerGatewayAvailable(input *DescribeCustomerGatewaysInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeCustomerGateways",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "CustomerGateways[].State",
Expected: "available",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "CustomerGateways[].State",
Expected: "deleted",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "CustomerGateways[].State",
Expected: "deleting",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilExportTaskCancelled(input *DescribeExportTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeExportTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "ExportTasks[].State",
Expected: "cancelled",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilExportTaskCompleted(input *DescribeExportTasksInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeExportTasks",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "ExportTasks[].State",
Expected: "completed",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilImageAvailable(input *DescribeImagesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeImages",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Images[].State",
Expected: "available",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Images[].State",
Expected: "failed",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilInstanceExists(input *DescribeInstancesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstances",
Delay: 5,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "status",
Argument: "",
Expected: 200,
},
{
State: "retry",
Matcher: "error",
Argument: "",
Expected: "InvalidInstanceIDNotFound",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilInstanceRunning(input *DescribeInstancesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstances",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Reservations[].Instances[].State.Name",
Expected: "running",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "shutting-down",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "terminated",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "stopping",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilInstanceStatusOk(input *DescribeInstanceStatusInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstanceStatus",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "InstanceStatuses[].InstanceStatus.Status",
Expected: "ok",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilInstanceStopped(input *DescribeInstancesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstances",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Reservations[].Instances[].State.Name",
Expected: "stopped",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "pending",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "terminated",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilInstanceTerminated(input *DescribeInstancesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstances",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Reservations[].Instances[].State.Name",
Expected: "terminated",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "pending",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Reservations[].Instances[].State.Name",
Expected: "stopping",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilKeyPairExists(input *DescribeKeyPairsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeKeyPairs",
Delay: 5,
MaxAttempts: 6,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "length(KeyPairs[].KeyName) > `0`",
Expected: true,
},
{
State: "retry",
Matcher: "error",
Argument: "",
Expected: "InvalidKeyPairNotFound",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilNetworkInterfaceAvailable(input *DescribeNetworkInterfacesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeNetworkInterfaces",
Delay: 20,
MaxAttempts: 10,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "NetworkInterfaces[].Status",
Expected: "available",
},
{
State: "failure",
Matcher: "error",
Argument: "",
Expected: "InvalidNetworkInterfaceIDNotFound",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilPasswordDataAvailable(input *GetPasswordDataInput) error {
waiterCfg := waiter.Config{
Operation: "GetPasswordData",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "path",
Argument: "length(PasswordData) > `0`",
Expected: true,
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilSnapshotCompleted(input *DescribeSnapshotsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeSnapshots",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Snapshots[].State",
Expected: "completed",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilSpotInstanceRequestFulfilled(input *DescribeSpotInstanceRequestsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeSpotInstanceRequests",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "SpotInstanceRequests[].Status.Code",
Expected: "fulfilled",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "SpotInstanceRequests[].Status.Code",
Expected: "schedule-expired",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "SpotInstanceRequests[].Status.Code",
Expected: "canceled-before-fulfillment",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "SpotInstanceRequests[].Status.Code",
Expected: "bad-parameters",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "SpotInstanceRequests[].Status.Code",
Expected: "system-error",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilSubnetAvailable(input *DescribeSubnetsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeSubnets",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Subnets[].State",
Expected: "available",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilSystemStatusOk(input *DescribeInstanceStatusInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeInstanceStatus",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "InstanceStatuses[].SystemStatus.Status",
Expected: "ok",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVolumeAvailable(input *DescribeVolumesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVolumes",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Volumes[].State",
Expected: "available",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Volumes[].State",
Expected: "deleted",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVolumeDeleted(input *DescribeVolumesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVolumes",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Volumes[].State",
Expected: "deleted",
},
{
State: "success",
Matcher: "error",
Argument: "",
Expected: "InvalidVolumeNotFound",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVolumeInUse(input *DescribeVolumesInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVolumes",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Volumes[].State",
Expected: "in-use",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "Volumes[].State",
Expected: "deleted",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVpcAvailable(input *DescribeVpcsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVpcs",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "Vpcs[].State",
Expected: "available",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVpnConnectionAvailable(input *DescribeVpnConnectionsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVpnConnections",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "VpnConnections[].State",
Expected: "available",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "VpnConnections[].State",
Expected: "deleting",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "VpnConnections[].State",
Expected: "deleted",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}
func (c *EC2) WaitUntilVpnConnectionDeleted(input *DescribeVpnConnectionsInput) error {
waiterCfg := waiter.Config{
Operation: "DescribeVpnConnections",
Delay: 15,
MaxAttempts: 40,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "VpnConnections[].State",
Expected: "deleted",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "VpnConnections[].State",
Expected: "pending",
},
},
}
w := waiter.Waiter{
Client: c,
Input: input,
Config: waiterCfg,
}
return w.Wait()
}

View File

@@ -0,0 +1,4 @@
testdata/conf_out.ini
ini.sublime-project
ini.sublime-workspace
testdata/conf_reflect.ini

191
Godeps/_workspace/src/github.com/go-ini/ini/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the Work and such
Derivative Works in Source or Object form.
3. Grant of Patent License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have
made, use, offer to sell, sell, import, and otherwise transfer the Work, where
such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination
of their Contribution(s) with the Work to which such Contribution(s) was
submitted. If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this License
for that Work shall terminate as of the date such litigation is filed.
4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof
in any medium, with or without modifications, and in Source or Object form,
provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of
this License; and
You must cause any modified files to carry prominent notices stating that You
changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute,
all copyright, patent, trademark, and attribution notices from the Source form
of the Work, excluding those notices that do not pertain to any part of the
Derivative Works; and
If the Work includes a "NOTICE" text file as part of its distribution, then any
Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted
for inclusion in the Work by You to the Licensor shall be under the terms and
conditions of this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify the terms of
any separate license agreement you may have executed with Licensor regarding
such Contributions.
6. Trademarks.
This License does not grant permission to use the trade names, trademarks,
service marks, or product names of the Licensor, except as required for
reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty.
Unless required by applicable law or agreed to in writing, Licensor provides the
Work (and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
including, without limitation, any warranties or conditions of TITLE,
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or
redistributing the Work and assume any risks associated with Your exercise of
permissions under this License.
8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence),
contract, or otherwise, unless required by applicable law (such as deliberate
and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental,
or consequential damages of any character arising as a result of this License or
out of the use or inability to use the Work (including but not limited to
damages for loss of goodwill, work stoppage, computer failure or malfunction, or
any and all other commercial damages or losses), even if such Contributor has
been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability.
While redistributing the Work or Derivative Works thereof, You may choose to
offer, and charge a fee for, acceptance of support, warranty, indemnity, or
other liability obligations and/or rights consistent with this License. However,
in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You
agree to indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

560
Godeps/_workspace/src/github.com/go-ini/ini/README.md generated vendored Normal file
View File

@@ -0,0 +1,560 @@
ini [![Build Status](https://drone.io/github.com/go-ini/ini/status.png)](https://drone.io/github.com/go-ini/ini/latest) [![](http://gocover.io/_badge/github.com/go-ini/ini)](http://gocover.io/github.com/go-ini/ini)
===
![](https://avatars0.githubusercontent.com/u/10216035?v=3&s=200)
Package ini provides INI file read and write functionality in Go.
[简体中文](README_ZH.md)
## Feature
- Load multiple data sources(`[]byte` or file) with overwrites.
- Read with recursion values.
- Read with parent-child sections.
- Read with auto-increment key names.
- Read with multiple-line values.
- Read with tons of helper methods.
- Read and convert values to Go types.
- Read and **WRITE** comments of sections and keys.
- Manipulate sections, keys and comments with ease.
- Keep sections and keys in order as you parse and save.
## Installation
go get gopkg.in/ini.v1
## Getting Started
### Loading from data sources
A **Data Source** is either raw data in type `[]byte` or a file name with type `string` and you can load **as many as** data sources you want. Passing other types will simply return an error.
```go
cfg, err := ini.Load([]byte("raw data"), "filename")
```
Or start with an empty object:
```go
cfg := ini.Empty()
```
When you cannot decide how many data sources to load at the beginning, you still able to **Append()** them later.
```go
err := cfg.Append("other file", []byte("other raw data"))
```
### Working with sections
To get a section, you would need to:
```go
section, err := cfg.GetSection("section name")
```
For a shortcut for default section, just give an empty string as name:
```go
section, err := cfg.GetSection("")
```
When you're pretty sure the section exists, following code could make your life easier:
```go
section := cfg.Section("")
```
What happens when the section somehow does not exist? Don't panic, it automatically creates and returns a new section to you.
To create a new section:
```go
err := cfg.NewSection("new section")
```
To get a list of sections or section names:
```go
sections := cfg.Sections()
names := cfg.SectionStrings()
```
### Working with keys
To get a key under a section:
```go
key, err := cfg.Section("").GetKey("key name")
```
Same rule applies to key operations:
```go
key := cfg.Section("").Key("key name")
```
To create a new key:
```go
err := cfg.Section("").NewKey("name", "value")
```
To get a list of keys or key names:
```go
keys := cfg.Section("").Keys()
names := cfg.Section("").KeyStrings()
```
To get a clone hash of keys and corresponding values:
```go
hash := cfg.GetSection("").KeysHash()
```
### Working with values
To get a string value:
```go
val := cfg.Section("").Key("key name").String()
```
To validate key value on the fly:
```go
val := cfg.Section("").Key("key name").Validate(func(in string) string {
if len(in) == 0 {
return "default"
}
return in
})
```
To get value with types:
```go
// For boolean values:
// true when value is: 1, t, T, TRUE, true, True, YES, yes, Yes, ON, on, On
// false when value is: 0, f, F, FALSE, false, False, NO, no, No, OFF, off, Off
v, err = cfg.Section("").Key("BOOL").Bool()
v, err = cfg.Section("").Key("FLOAT64").Float64()
v, err = cfg.Section("").Key("INT").Int()
v, err = cfg.Section("").Key("INT64").Int64()
v, err = cfg.Section("").Key("UINT").Uint()
v, err = cfg.Section("").Key("UINT64").Uint64()
v, err = cfg.Section("").Key("TIME").TimeFormat(time.RFC3339)
v, err = cfg.Section("").Key("TIME").Time() // RFC3339
v = cfg.Section("").Key("BOOL").MustBool()
v = cfg.Section("").Key("FLOAT64").MustFloat64()
v = cfg.Section("").Key("INT").MustInt()
v = cfg.Section("").Key("INT64").MustInt64()
v = cfg.Section("").Key("UINT").MustUint()
v = cfg.Section("").Key("UINT64").MustUint64()
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339)
v = cfg.Section("").Key("TIME").MustTime() // RFC3339
// Methods start with Must also accept one argument for default value
// when key not found or fail to parse value to given type.
// Except method MustString, which you have to pass a default value.
v = cfg.Section("").Key("String").MustString("default")
v = cfg.Section("").Key("BOOL").MustBool(true)
v = cfg.Section("").Key("FLOAT64").MustFloat64(1.25)
v = cfg.Section("").Key("INT").MustInt(10)
v = cfg.Section("").Key("INT64").MustInt64(99)
v = cfg.Section("").Key("UINT").MustUint(3)
v = cfg.Section("").Key("UINT64").MustUint64(6)
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339, time.Now())
v = cfg.Section("").Key("TIME").MustTime(time.Now()) // RFC3339
```
What if my value is three-line long?
```ini
[advance]
ADDRESS = """404 road,
NotFound, State, 5000
Earth"""
```
Not a problem!
```go
cfg.Section("advance").Key("ADDRESS").String()
/* --- start ---
404 road,
NotFound, State, 5000
Earth
------ end --- */
```
That's cool, how about continuation lines?
```ini
[advance]
two_lines = how about \
continuation lines?
lots_of_lines = 1 \
2 \
3 \
4
```
Piece of cake!
```go
cfg.Section("advance").Key("two_lines").String() // how about continuation lines?
cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4
```
Note that single quotes around values will be stripped:
```ini
foo = "some value" // foo: some value
bar = 'some value' // bar: some value
```
That's all? Hmm, no.
#### Helper methods of working with values
To get value with given candidates:
```go
v = cfg.Section("").Key("STRING").In("default", []string{"str", "arr", "types"})
v = cfg.Section("").Key("FLOAT64").InFloat64(1.1, []float64{1.25, 2.5, 3.75})
v = cfg.Section("").Key("INT").InInt(5, []int{10, 20, 30})
v = cfg.Section("").Key("INT64").InInt64(10, []int64{10, 20, 30})
v = cfg.Section("").Key("UINT").InUint(4, []int{3, 6, 9})
v = cfg.Section("").Key("UINT64").InUint64(8, []int64{3, 6, 9})
v = cfg.Section("").Key("TIME").InTimeFormat(time.RFC3339, time.Now(), []time.Time{time1, time2, time3})
v = cfg.Section("").Key("TIME").InTime(time.Now(), []time.Time{time1, time2, time3}) // RFC3339
```
Default value will be presented if value of key is not in candidates you given, and default value does not need be one of candidates.
To validate value in a given range:
```go
vals = cfg.Section("").Key("FLOAT64").RangeFloat64(0.0, 1.1, 2.2)
vals = cfg.Section("").Key("INT").RangeInt(0, 10, 20)
vals = cfg.Section("").Key("INT64").RangeInt64(0, 10, 20)
vals = cfg.Section("").Key("UINT").RangeUint(0, 3, 9)
vals = cfg.Section("").Key("UINT64").RangeUint64(0, 3, 9)
vals = cfg.Section("").Key("TIME").RangeTimeFormat(time.RFC3339, time.Now(), minTime, maxTime)
vals = cfg.Section("").Key("TIME").RangeTime(time.Now(), minTime, maxTime) // RFC3339
```
To auto-split value into slice:
```go
vals = cfg.Section("").Key("STRINGS").Strings(",")
vals = cfg.Section("").Key("FLOAT64S").Float64s(",")
vals = cfg.Section("").Key("INTS").Ints(",")
vals = cfg.Section("").Key("INT64S").Int64s(",")
vals = cfg.Section("").Key("UINTS").Uints(",")
vals = cfg.Section("").Key("UINT64S").Uint64s(",")
vals = cfg.Section("").Key("TIMES").Times(",")
```
### Save your configuration
Finally, it's time to save your configuration to somewhere.
A typical way to save configuration is writing it to a file:
```go
// ...
err = cfg.SaveTo("my.ini")
err = cfg.SaveToIndent("my.ini", "\t")
```
Another way to save is writing to a `io.Writer` interface:
```go
// ...
cfg.WriteTo(writer)
cfg.WriteToIndent(writer, "\t")
```
## Advanced Usage
### Recursive Values
For all value of keys, there is a special syntax `%(<name>)s`, where `<name>` is the key name in same section or default section, and `%(<name>)s` will be replaced by corresponding value(empty string if key not found). You can use this syntax at most 99 level of recursions.
```ini
NAME = ini
[author]
NAME = Unknwon
GITHUB = https://github.com/%(NAME)s
[package]
FULL_NAME = github.com/go-ini/%(NAME)s
```
```go
cfg.Section("author").Key("GITHUB").String() // https://github.com/Unknwon
cfg.Section("package").Key("FULL_NAME").String() // github.com/go-ini/ini
```
### Parent-child Sections
You can use `.` in section name to indicate parent-child relationship between two or more sections. If the key not found in the child section, library will try again on its parent section until there is no parent section.
```ini
NAME = ini
VERSION = v1
IMPORT_PATH = gopkg.in/%(NAME)s.%(VERSION)s
[package]
CLONE_URL = https://%(IMPORT_PATH)s
[package.sub]
```
```go
cfg.Section("package.sub").Key("CLONE_URL").String() // https://gopkg.in/ini.v1
```
### Auto-increment Key Names
If key name is `-` in data source, then it would be seen as special syntax for auto-increment key name start from 1, and every section is independent on counter.
```ini
[features]
-: Support read/write comments of keys and sections
-: Support auto-increment of key names
-: Support load multiple files to overwrite key values
```
```go
cfg.Section("features").KeyStrings() // []{"#1", "#2", "#3"}
```
### Map To Struct
Want more objective way to play with INI? Cool.
```ini
Name = Unknwon
age = 21
Male = true
Born = 1993-01-01T20:17:05Z
[Note]
Content = Hi is a good man!
Cities = HangZhou, Boston
```
```go
type Note struct {
Content string
Cities []string
}
type Person struct {
Name string
Age int `ini:"age"`
Male bool
Born time.Time
Note
Created time.Time `ini:"-"`
}
func main() {
cfg, err := ini.Load("path/to/ini")
// ...
p := new(Person)
err = cfg.MapTo(p)
// ...
// Things can be simpler.
err = ini.MapTo(p, "path/to/ini")
// ...
// Just map a section? Fine.
n := new(Note)
err = cfg.Section("Note").MapTo(n)
// ...
}
```
Can I have default value for field? Absolutely.
Assign it before you map to struct. It will keep the value as it is if the key is not presented or got wrong type.
```go
// ...
p := &Person{
Name: "Joe",
}
// ...
```
It's really cool, but what's the point if you can't give me my file back from struct?
### Reflect From Struct
Why not?
```go
type Embeded struct {
Dates []time.Time `delim:"|"`
Places []string
None []int
}
type Author struct {
Name string `ini:"NAME"`
Male bool
Age int
GPA float64
NeverMind string `ini:"-"`
*Embeded
}
func main() {
a := &Author{"Unknwon", true, 21, 2.8, "",
&Embeded{
[]time.Time{time.Now(), time.Now()},
[]string{"HangZhou", "Boston"},
[]int{},
}}
cfg := ini.Empty()
err = ini.ReflectFrom(cfg, a)
// ...
}
```
So, what do I get?
```ini
NAME = Unknwon
Male = true
Age = 21
GPA = 2.8
[Embeded]
Dates = 2015-08-07T22:14:22+08:00|2015-08-07T22:14:22+08:00
Places = HangZhou,Boston
None =
```
#### Name Mapper
To save your time and make your code cleaner, this library supports [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) between struct field and actual section and key name.
There are 2 built-in name mappers:
- `AllCapsUnderscore`: it converts to format `ALL_CAPS_UNDERSCORE` then match section or key.
- `TitleUnderscore`: it converts to format `title_underscore` then match section or key.
To use them:
```go
type Info struct {
PackageName string
}
func main() {
err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("packag_name=ini"))
// ...
cfg, err := ini.Load([]byte("PACKAGE_NAME=ini"))
// ...
info := new(Info)
cfg.NameMapper = ini.AllCapsUnderscore
err = cfg.MapTo(info)
// ...
}
```
Same rules of name mapper apply to `ini.ReflectFromWithMapper` function.
#### Other Notes On Map/Reflect
Any embedded struct is treated as a section by default, and there is no automatic parent-child relations in map/reflect feature:
```go
type Child struct {
Age string
}
type Parent struct {
Name string
Child
}
type Config struct {
City string
Parent
}
```
Example configuration:
```ini
City = Boston
[Parent]
Name = Unknwon
[Child]
Age = 21
```
What if, yes, I'm paranoid, I want embedded struct to be in the same section. Well, all roads lead to Rome.
```go
type Child struct {
Age string
}
type Parent struct {
Name string
Child `ini:"Parent"`
}
type Config struct {
City string
Parent
}
```
Example configuration:
```ini
City = Boston
[Parent]
Name = Unknwon
Age = 21
```
## Getting Help
- [API Documentation](https://gowalker.org/gopkg.in/ini.v1)
- [File An Issue](https://github.com/go-ini/ini/issues/new)
## FAQs
### What does `BlockMode` field do?
By default, library lets you read and write values so we need a locker to make sure your data is safe. But in cases that you are very sure about only reading data through the library, you can set `cfg.BlockMode = false` to speed up read operations about **50-70%** faster.
### Why another INI library?
Many people are using my another INI library [goconfig](https://github.com/Unknwon/goconfig), so the reason for this one is I would like to make more Go style code. Also when you set `cfg.BlockMode = false`, this one is about **10-30%** faster.
To make those changes I have to confirm API broken, so it's safer to keep it in another place and start using `gopkg.in` to version my package at this time.(PS: shorter import path)
## License
This project is under Apache v2 License. See the [LICENSE](LICENSE) file for the full license text.

View File

@@ -0,0 +1,547 @@
本包提供了 Go 语言中读写 INI 文件的功能。
## 功能特性
- 支持覆盖加载多个数据源(`[]byte` 或文件)
- 支持递归读取键值
- 支持读取父子分区
- 支持读取自增键名
- 支持读取多行的键值
- 支持大量辅助方法
- 支持在读取时直接转换为 Go 语言类型
- 支持读取和 **写入** 分区和键的注释
- 轻松操作分区、键值和注释
- 在保存文件时分区和键值会保持原有的顺序
## 下载安装
go get gopkg.in/ini.v1
## 开始使用
### 从数据源加载
一个 **数据源** 可以是 `[]byte` 类型的原始数据,或 `string` 类型的文件路径。您可以加载 **任意多个** 数据源。如果您传递其它类型的数据源,则会直接返回错误。
```go
cfg, err := ini.Load([]byte("raw data"), "filename")
```
或者从一个空白的文件开始:
```go
cfg := ini.Empty()
```
当您在一开始无法决定需要加载哪些数据源时,仍可以使用 **Append()** 在需要的时候加载它们。
```go
err := cfg.Append("other file", []byte("other raw data"))
```
### 操作分区Section
获取指定分区:
```go
section, err := cfg.GetSection("section name")
```
如果您想要获取默认分区,则可以用空字符串代替分区名:
```go
section, err := cfg.GetSection("")
```
当您非常确定某个分区是存在的,可以使用以下简便方法:
```go
section := cfg.Section("")
```
如果不小心判断错了,要获取的分区其实是不存在的,那会发生什么呢?没事的,它会自动创建并返回一个对应的分区对象给您。
创建一个分区:
```go
err := cfg.NewSection("new section")
```
获取所有分区对象或名称:
```go
sections := cfg.Sections()
names := cfg.SectionStrings()
```
### 操作键Key
获取某个分区下的键:
```go
key, err := cfg.Section("").GetKey("key name")
```
和分区一样,您也可以直接获取键而忽略错误处理:
```go
key := cfg.Section("").Key("key name")
```
创建一个新的键:
```go
err := cfg.Section("").NewKey("name", "value")
```
获取分区下的所有键或键名:
```go
keys := cfg.Section("").Keys()
names := cfg.Section("").KeyStrings()
```
获取分区下的所有键值对的克隆:
```go
hash := cfg.GetSection("").KeysHash()
```
### 操作键值Value
获取一个类型为字符串string的值
```go
val := cfg.Section("").Key("key name").String()
```
获取值的同时通过自定义函数进行处理验证:
```go
val := cfg.Section("").Key("key name").Validate(func(in string) string {
if len(in) == 0 {
return "default"
}
return in
})
```
获取其它类型的值:
```go
// 布尔值的规则:
// true 当值为1, t, T, TRUE, true, True, YES, yes, Yes, ON, on, On
// false 当值为0, f, F, FALSE, false, False, NO, no, No, OFF, off, Off
v, err = cfg.Section("").Key("BOOL").Bool()
v, err = cfg.Section("").Key("FLOAT64").Float64()
v, err = cfg.Section("").Key("INT").Int()
v, err = cfg.Section("").Key("INT64").Int64()
v, err = cfg.Section("").Key("UINT").Uint()
v, err = cfg.Section("").Key("UINT64").Uint64()
v, err = cfg.Section("").Key("TIME").TimeFormat(time.RFC3339)
v, err = cfg.Section("").Key("TIME").Time() // RFC3339
v = cfg.Section("").Key("BOOL").MustBool()
v = cfg.Section("").Key("FLOAT64").MustFloat64()
v = cfg.Section("").Key("INT").MustInt()
v = cfg.Section("").Key("INT64").MustInt64()
v = cfg.Section("").Key("UINT").MustUint()
v = cfg.Section("").Key("UINT64").MustUint64()
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339)
v = cfg.Section("").Key("TIME").MustTime() // RFC3339
// 由 Must 开头的方法名允许接收一个相同类型的参数来作为默认值,
// 当键不存在或者转换失败时,则会直接返回该默认值。
// 但是MustString 方法必须传递一个默认值。
v = cfg.Seciont("").Key("String").MustString("default")
v = cfg.Section("").Key("BOOL").MustBool(true)
v = cfg.Section("").Key("FLOAT64").MustFloat64(1.25)
v = cfg.Section("").Key("INT").MustInt(10)
v = cfg.Section("").Key("INT64").MustInt64(99)
v = cfg.Section("").Key("UINT").MustUint(3)
v = cfg.Section("").Key("UINT64").MustUint64(6)
v = cfg.Section("").Key("TIME").MustTimeFormat(time.RFC3339, time.Now())
v = cfg.Section("").Key("TIME").MustTime(time.Now()) // RFC3339
```
如果我的值有好多行怎么办?
```ini
[advance]
ADDRESS = """404 road,
NotFound, State, 5000
Earth"""
```
嗯哼?小 case
```go
cfg.Section("advance").Key("ADDRESS").String()
/* --- start ---
404 road,
NotFound, State, 5000
Earth
------ end --- */
```
赞爆了!那要是我属于一行的内容写不下想要写到第二行怎么办?
```ini
[advance]
two_lines = how about \
continuation lines?
lots_of_lines = 1 \
2 \
3 \
4
```
简直是小菜一碟!
```go
cfg.Section("advance").Key("two_lines").String() // how about continuation lines?
cfg.Section("advance").Key("lots_of_lines").String() // 1 2 3 4
```
需要注意的是,值两侧的单引号会被自动剔除:
```ini
foo = "some value" // foo: some value
bar = 'some value' // bar: some value
```
这就是全部了?哈哈,当然不是。
#### 操作键值的辅助方法
获取键值时设定候选值:
```go
v = cfg.Section("").Key("STRING").In("default", []string{"str", "arr", "types"})
v = cfg.Section("").Key("FLOAT64").InFloat64(1.1, []float64{1.25, 2.5, 3.75})
v = cfg.Section("").Key("INT").InInt(5, []int{10, 20, 30})
v = cfg.Section("").Key("INT64").InInt64(10, []int64{10, 20, 30})
v = cfg.Section("").Key("UINT").InUint(4, []int{3, 6, 9})
v = cfg.Section("").Key("UINT64").InUint64(8, []int64{3, 6, 9})
v = cfg.Section("").Key("TIME").InTimeFormat(time.RFC3339, time.Now(), []time.Time{time1, time2, time3})
v = cfg.Section("").Key("TIME").InTime(time.Now(), []time.Time{time1, time2, time3}) // RFC3339
```
如果获取到的值不是候选值的任意一个,则会返回默认值,而默认值不需要是候选值中的一员。
验证获取的值是否在指定范围内:
```go
vals = cfg.Section("").Key("FLOAT64").RangeFloat64(0.0, 1.1, 2.2)
vals = cfg.Section("").Key("INT").RangeInt(0, 10, 20)
vals = cfg.Section("").Key("INT64").RangeInt64(0, 10, 20)
vals = cfg.Section("").Key("UINT").RangeUint(0, 3, 9)
vals = cfg.Section("").Key("UINT64").RangeUint64(0, 3, 9)
vals = cfg.Section("").Key("TIME").RangeTimeFormat(time.RFC3339, time.Now(), minTime, maxTime)
vals = cfg.Section("").Key("TIME").RangeTime(time.Now(), minTime, maxTime) // RFC3339
```
自动分割键值为切片slice
```go
vals = cfg.Section("").Key("STRINGS").Strings(",")
vals = cfg.Section("").Key("FLOAT64S").Float64s(",")
vals = cfg.Section("").Key("INTS").Ints(",")
vals = cfg.Section("").Key("INT64S").Int64s(",")
vals = cfg.Section("").Key("UINTS").Uints(",")
vals = cfg.Section("").Key("UINT64S").Uint64s(",")
vals = cfg.Section("").Key("TIMES").Times(",")
```
### 保存配置
终于到了这个时刻,是时候保存一下配置了。
比较原始的做法是输出配置到某个文件:
```go
// ...
err = cfg.SaveTo("my.ini")
err = cfg.SaveToIndent("my.ini", "\t")
```
另一个比较高级的做法是写入到任何实现 `io.Writer` 接口的对象中:
```go
// ...
cfg.WriteTo(writer)
cfg.WriteToIndent(writer, "\t")
```
### 高级用法
#### 递归读取键值
在获取所有键值的过程中,特殊语法 `%(<name>)s` 会被应用,其中 `<name>` 可以是相同分区或者默认分区下的键名。字符串 `%(<name>)s` 会被相应的键值所替代,如果指定的键不存在,则会用空字符串替代。您可以最多使用 99 层的递归嵌套。
```ini
NAME = ini
[author]
NAME = Unknwon
GITHUB = https://github.com/%(NAME)s
[package]
FULL_NAME = github.com/go-ini/%(NAME)s
```
```go
cfg.Section("author").Key("GITHUB").String() // https://github.com/Unknwon
cfg.Section("package").Key("FULL_NAME").String() // github.com/go-ini/ini
```
#### 读取父子分区
您可以在分区名称中使用 `.` 来表示两个或多个分区之间的父子关系。如果某个键在子分区中不存在,则会去它的父分区中再次寻找,直到没有父分区为止。
```ini
NAME = ini
VERSION = v1
IMPORT_PATH = gopkg.in/%(NAME)s.%(VERSION)s
[package]
CLONE_URL = https://%(IMPORT_PATH)s
[package.sub]
```
```go
cfg.Section("package.sub").Key("CLONE_URL").String() // https://gopkg.in/ini.v1
```
#### 读取自增键名
如果数据源中的键名为 `-`,则认为该键使用了自增键名的特殊语法。计数器从 1 开始,并且分区之间是相互独立的。
```ini
[features]
-: Support read/write comments of keys and sections
-: Support auto-increment of key names
-: Support load multiple files to overwrite key values
```
```go
cfg.Section("features").KeyStrings() // []{"#1", "#2", "#3"}
```
### 映射到结构
想要使用更加面向对象的方式玩转 INI 吗?好主意。
```ini
Name = Unknwon
age = 21
Male = true
Born = 1993-01-01T20:17:05Z
[Note]
Content = Hi is a good man!
Cities = HangZhou, Boston
```
```go
type Note struct {
Content string
Cities []string
}
type Person struct {
Name string
Age int `ini:"age"`
Male bool
Born time.Time
Note
Created time.Time `ini:"-"`
}
func main() {
cfg, err := ini.Load("path/to/ini")
// ...
p := new(Person)
err = cfg.MapTo(p)
// ...
// 一切竟可以如此的简单。
err = ini.MapTo(p, "path/to/ini")
// ...
// 嗯哼?只需要映射一个分区吗?
n := new(Note)
err = cfg.Section("Note").MapTo(n)
// ...
}
```
结构的字段怎么设置默认值呢?很简单,只要在映射之前对指定字段进行赋值就可以了。如果键未找到或者类型错误,该值不会发生改变。
```go
// ...
p := &Person{
Name: "Joe",
}
// ...
```
这样玩 INI 真的好酷啊!然而,如果不能还给我原来的配置文件,有什么卵用?
### 从结构反射
可是,我有说不能吗?
```go
type Embeded struct {
Dates []time.Time `delim:"|"`
Places []string
None []int
}
type Author struct {
Name string `ini:"NAME"`
Male bool
Age int
GPA float64
NeverMind string `ini:"-"`
*Embeded
}
func main() {
a := &Author{"Unknwon", true, 21, 2.8, "",
&Embeded{
[]time.Time{time.Now(), time.Now()},
[]string{"HangZhou", "Boston"},
[]int{},
}}
cfg := ini.Empty()
err = ini.ReflectFrom(cfg, a)
// ...
}
```
瞧瞧,奇迹发生了。
```ini
NAME = Unknwon
Male = true
Age = 21
GPA = 2.8
[Embeded]
Dates = 2015-08-07T22:14:22+08:00|2015-08-07T22:14:22+08:00
Places = HangZhou,Boston
None =
```
#### 名称映射器Name Mapper
为了节省您的时间并简化代码,本库支持类型为 [`NameMapper`](https://gowalker.org/gopkg.in/ini.v1#NameMapper) 的名称映射器,该映射器负责结构字段名与分区名和键名之间的映射。
目前有 2 款内置的映射器:
- `AllCapsUnderscore`:该映射器将字段名转换至格式 `ALL_CAPS_UNDERSCORE` 后再去匹配分区名和键名。
- `TitleUnderscore`:该映射器将字段名转换至格式 `title_underscore` 后再去匹配分区名和键名。
使用方法:
```go
type Info struct{
PackageName string
}
func main() {
err = ini.MapToWithMapper(&Info{}, ini.TitleUnderscore, []byte("packag_name=ini"))
// ...
cfg, err := ini.Load([]byte("PACKAGE_NAME=ini"))
// ...
info := new(Info)
cfg.NameMapper = ini.AllCapsUnderscore
err = cfg.MapTo(info)
// ...
}
```
使用函数 `ini.ReflectFromWithMapper` 时也可应用相同的规则。
#### 映射/反射的其它说明
任何嵌入的结构都会被默认认作一个不同的分区,并且不会自动产生所谓的父子分区关联:
```go
type Child struct {
Age string
}
type Parent struct {
Name string
Child
}
type Config struct {
City string
Parent
}
```
示例配置文件:
```ini
City = Boston
[Parent]
Name = Unknwon
[Child]
Age = 21
```
很好,但是,我就是要嵌入结构也在同一个分区。好吧,你爹是李刚!
```go
type Child struct {
Age string
}
type Parent struct {
Name string
Child `ini:"Parent"`
}
type Config struct {
City string
Parent
}
```
示例配置文件:
```ini
City = Boston
[Parent]
Name = Unknwon
Age = 21
```
## 获取帮助
- [API 文档](https://gowalker.org/gopkg.in/ini.v1)
- [创建工单](https://github.com/go-ini/ini/issues/new)
## 常见问题
### 字段 `BlockMode` 是什么?
默认情况下,本库会在您进行读写操作时采用锁机制来确保数据时间。但在某些情况下,您非常确定只进行读操作。此时,您可以通过设置 `cfg.BlockMode = false` 来将读操作提升大约 **50-70%** 的性能。
### 为什么要写另一个 INI 解析库?
许多人都在使用我的 [goconfig](https://github.com/Unknwon/goconfig) 来完成对 INI 文件的操作,但我希望使用更加 Go 风格的代码。并且当您设置 `cfg.BlockMode = false` 时,会有大约 **10-30%** 的性能提升。
为了做出这些改变,我必须对 API 进行破坏,所以新开一个仓库是最安全的做法。除此之外,本库直接使用 `gopkg.in` 来进行版本化发布。(其实真相是导入路径更短了)

1226
Godeps/_workspace/src/github.com/go-ini/ini/ini.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More