Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Mike Kobyakov
2015-09-17 23:07:52 -07:00
607 changed files with 71935 additions and 6257 deletions

View File

@@ -9,7 +9,7 @@ watch_dirs = [
"$WORKDIR/public/views",
"$WORKDIR/conf",
]
watch_exts = [".go", ".ini"]
watch_exts = [".go", ".ini", ".toml", ".html"]
build_delay = 1500
cmds = [
["go", "build", "-o", "./bin/grafana-server"],

12
.editorconfig Normal file
View File

@@ -0,0 +1,12 @@
# http://editorconfig.org
root = true
[*]
indent_style = space
indent_size = 2
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.md]
trim_trailing_whitespace = false

5
.gitignore vendored
View File

@@ -3,6 +3,8 @@ coverage/
.aws-config.json
awsconfig
/dist
/emails/dist
/public_gen
/tmp
docs/AWS_S3_BUCKET
@@ -27,4 +29,5 @@ public/css/*.min.css
conf/custom.ini
fig.yml
profile.cov
grafana
tsconfig.json

View File

@@ -23,7 +23,7 @@
"laxcomma": true,
"sub": true,
"unused": true,
"maxdepth": 5,
"maxdepth": 6,
"maxlen": 140,
"globals": {
@@ -32,4 +32,4 @@
"Chromath": false,
"setImmediate": true
}
}
}

View File

@@ -1,4 +1,86 @@
# 2.1.0 (unreleased - master branch)
# 2.5 (unreleased)
** New Feature: Mix data sources **
A built in data source is now available named `-- Mixed --`, When picked in the metrics tab,
it allows you to add queries of differnet data source types & instances to the same graph/panel!
[Issue #436](https://github.com/grafana/grafana/issues/436)
** New Feature: Elasticsearch Metrics Query Editor and Viz Support **
- Feature rich query editor and processing features enables you to issues all kind of metric queries to Elasticsearch
- See [Issue #1034](https://github.com/grafana/grafana/issues/1034) for more info.
** User Onboarding **
- Org admin can now send email invites (or invite links) to people who are not yet Grafana users
- Sign up flow now supports email verification (if enabled)
- See [Issue #2353](https://github.com/grafana/grafana/issues/2354) for more info.
** Other new Features && Enhancements**
- [Pull #2720](https://github.com/grafana/grafana/pull/2720). Admin: Initial basic quota support (per Org)
- [Issue #2577](https://github.com/grafana/grafana/issues/2577). Panel: Resize handles in panel bottom right corners for easy width and height change
- [Issue #2457](https://github.com/grafana/grafana/issues/2457). Admin: admin page for all grafana organizations (list / edit view)
- [Issue #1186](https://github.com/grafana/grafana/issues/1186). Time Picker: New option `today`, will set time range from midnight to now
- [Issue #1186](https://github.com/grafana/grafana/issues/1186). Time Picker: New option `today`, will set time range from midnight to now
- [Issue #2647](https://github.com/grafana/grafana/issues/2647). InfluxDB: You can now set group by time interval on each query
- [Issue #2599](https://github.com/grafana/grafana/issues/2599). InfluxDB: Improved alias support, you can now use the `AS` clause for each select statement
- [Issue #2708](https://github.com/grafana/grafana/issues/2708). InfluxDB: You can now set math expression for select clauses.
**Fixes**
- [Issue #2413](https://github.com/grafana/grafana/issues/2413). InfluxDB 0.9: Fix for handling empty series object in response from influxdb
- [Issue #2574](https://github.com/grafana/grafana/issues/2574). Snapshot: Fix for snapshot with expire 7 days option, 7 days option not correct, was 7 hours
- [Issue #2568](https://github.com/grafana/grafana/issues/2568). AuthProxy: Fix for server side rendering of panel when using auth proxy
- [Issue #2490](https://github.com/grafana/grafana/issues/2490). Graphite: Dashboard import was broken in 2.1 and 2.1.1, working now
- [Issue #2565](https://github.com/grafana/grafana/issues/2565). TimePicker: Fix for when you applied custom time range it did not refreh dashboard
- [Issue #2563](https://github.com/grafana/grafana/issues/2563). Annotations: Fixed issue when html sanitizer failes for title to annotation body, now fallbacks to html escaping title and text
- [Issue #2564](https://github.com/grafana/grafana/issues/2564). Templating: Another atempt at fixing #2534 (Init multi value template var used in repeat panel from url)
- [Issue #2620](https://github.com/grafana/grafana/issues/2620). Graph: multi series tooltip did no highlight correct point when stacking was enabled and series were of different resolution
- [Issue #2636](https://github.com/grafana/grafana/issues/2636). InfluxDB: Do no show template vars in dropdown for tag keys and group by keys
- [Issue #2604](https://github.com/grafana/grafana/issues/2604). InfluxDB: More alias options, can now use `$[0-9]` syntax to reference part of a measurement name (seperated by dots)
**Breaking Changes**
- Notice to makers/users of custom data sources, there is a minor breaking change in 2.2 that
require an update to custom data sources for them to work in 2.2. [Read this doc](https://github.com/grafana/grafana/tree/master/docs/sources/datasources/plugin_api.md) for more on the
data source api change.
- The duplicate query function used in data source editors is changed, and moveMetricQuery function was renamed
**Tech (Note for devs)**
Started using Typescript (transpiled to ES5), uncompiled typescript files and less files are in public folder (in source tree)
This folder is never modified by build steps. Compiled css and javascript files are put in public_gen, all other files
that do not undergo transformation are just copied from public to public_gen, it is public_gen that is used by grafana-server
if it is found.
Grunt & Watch tasks:
- `grunt` : default task, will remove public_gen, copy over all files from public, do less & typescript compilation
- `grunt watch`: will watch for changes to less, and typescript files and compile them to public_gen, and for other files it will just copy them to public_gen
2.1.3 (2015-08-24)
**Fixes**
- [Issue #2580](https://github.com/grafana/grafana/issues/2580). Packaging: ldap.toml was not marked as config file and could be overwritten in upgrade
- [Issue #2564](https://github.com/grafana/grafana/issues/2564). Templating: Another atempt at fixing #2534 (Init multi value template var used in repeat panel from url)
# 2.1.2 (2015-08-20)
**Fixes**
- [Issue #2558](https://github.com/grafana/grafana/issues/2558). DragDrop: Fix for broken drag drop behavior
- [Issue #2534](https://github.com/grafana/grafana/issues/2534). Templating: fix for setting template variable value via url and having repeated panels or rows
# 2.1.1 (2015-08-11)
**Fixes**
- [Issue #2443](https://github.com/grafana/grafana/issues/2443). Templating: Fix for buggy repeat row behavior when combined with with repeat panel due to recent change before 2.1 release
- [Issue #2442](https://github.com/grafana/grafana/issues/2442). Templating: Fix text panel when using template variables in text in in repeated panel
- [Issue #2446](https://github.com/grafana/grafana/issues/2446). InfluxDB: Fix for using template vars inside alias field (InfluxDB 0.9)
- [Issue #2460](https://github.com/grafana/grafana/issues/2460). SinglestatPanel: Fix to handle series with no data points
- [Issue #2461](https://github.com/grafana/grafana/issues/2461). LDAP: Fix for ldap users with empty email address
- [Issue #2484](https://github.com/grafana/grafana/issues/2484). Graphite: Fix bug when using series ref (#A-Z) and referenced series is hidden in query editor.
- [Issue #1896](https://github.com/grafana/grafana/issues/1896). Postgres: Dashboard search is now case insensitive when using Postgres
**Enhancements**
- [Issue #2477](https://github.com/grafana/grafana/issues/2477). InfluxDB(0.9): Added more condition operators (`<`, `>`, `<>`, `!~`), thx @thuck
- [Issue #2483](https://github.com/grafana/grafana/issues/2484). InfluxDB(0.9): Use $col as option in alias patterns, thx @thuck
# 2.1.0 (2015-08-04)
**Data sources**
- [Issue #1525](https://github.com/grafana/grafana/issues/1525). InfluxDB: Full support for InfluxDB 0.9 with new adapted query editor
@@ -103,6 +185,10 @@
# 2.0.0-Beta1 (2015-03-30)
**Important Note**
Grafana 2.x is fundamentally different from 1.x; it now ships with an integrated backend server. Please read the [Documentation](http://docs.grafana.org) for more detailed about this SIGNIFCANT change to Grafana
**New features**
- [Issue #1623](https://github.com/grafana/grafana/issues/1623). Share Dashboard: Dashboard snapshot sharing (dash and data snapshot), save to local or save to public snapshot dashboard snapshots.raintank.io site
- [Issue #1622](https://github.com/grafana/grafana/issues/1622). Share Panel: The share modal now has an embed option, gives you an iframe that you can use to embedd a single graph on another web site

58
Godeps/Godeps.json generated
View File

@@ -5,6 +5,11 @@
"./pkg/..."
],
"Deps": [
{
"ImportPath": "github.com/BurntSushi/toml",
"Comment": "v0.1.0-21-g056c9bc",
"Rev": "056c9bc7be7190eaa7715723883caffa5f8fa3e4"
},
{
"ImportPath": "github.com/Unknwon/com",
"Rev": "d9bcf409c8a368d06c9b347705c381e7c12d54df"
@@ -13,6 +18,50 @@
"ImportPath": "github.com/Unknwon/macaron",
"Rev": "93de4f3fad97bf246b838f828e2348f46f21f20a"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/cloudwatch",
"Comment": "v0.7.3",
"Rev": "bed164a424e75154a40550c04c313ef51a7bb275"
},
{
"ImportPath": "github.com/davecgh/go-spew/spew",
"Rev": "2df174808ee097f90d259e432cc04442cf60be21"
},
{
"ImportPath": "github.com/go-ldap/ldap",
"Comment": "v1-19-g83e6542",
"Rev": "83e65426fd1c06626e88aa8a085e5bfed0208e29"
},
{
"ImportPath": "github.com/go-sql-driver/mysql",
"Comment": "v1.2-26-g9543750",
@@ -65,6 +114,10 @@
"ImportPath": "github.com/streadway/amqp",
"Rev": "150b7f24d6ad507e6026c13d85ce1f1391ac7400"
},
{
"ImportPath": "github.com/vaughan0/go-ini",
"Rev": "a98ad7ee00ec53921f08832bc06ecf7fd600e6a1"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "972f0c5fbe4ae29e666c3f78c3ed42ae7a448b0a"
@@ -73,6 +126,11 @@
"ImportPath": "golang.org/x/oauth2",
"Rev": "c58fcf0ffc1c772aa2e1ee4894bc19f2649263b2"
},
{
"ImportPath": "gopkg.in/asn1-ber.v1",
"Comment": "v1",
"Rev": "9eae18c3681ae3d3c677ac2b80a8fe57de45fc09"
},
{
"ImportPath": "gopkg.in/bufio.v1",
"Comment": "v1",

View File

@@ -0,0 +1,5 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

View File

@@ -0,0 +1,12 @@
language: go
go:
- 1.1
- 1.2
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View File

@@ -0,0 +1,3 @@
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)

View File

@@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@@ -0,0 +1,19 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View File

@@ -0,0 +1,220 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/mojombo/toml
Compatible with TOML version
[v0.2.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.2.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build status](https://api.travis-ci.org/BurntSushi/toml.png)](https://travis-ci.org/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@@ -0,0 +1,14 @@
# Implements the TOML test suite interface
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for my
[toml parser written in Go](https://github.com/BurntSushi/toml).
In particular, it maps TOML data on `stdin` to a JSON format on `stdout`.
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View File

@@ -0,0 +1,90 @@
// Command toml-test-decoder satisfies the toml-test interface for testing
// TOML decoders. Namely, it accepts TOML on stdin and outputs JSON on stdout.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path"
"time"
"github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < toml-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if _, err := toml.DecodeReader(os.Stdin, &tmp); err != nil {
log.Fatalf("Error decoding TOML: %s", err)
}
typedTmp := translate(tmp)
if err := json.NewEncoder(os.Stdout).Encode(typedTmp); err != nil {
log.Fatalf("Error encoding JSON: %s", err)
}
}
func translate(tomlData interface{}) interface{} {
switch orig := tomlData.(type) {
case map[string]interface{}:
typed := make(map[string]interface{}, len(orig))
for k, v := range orig {
typed[k] = translate(v)
}
return typed
case []map[string]interface{}:
typed := make([]map[string]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v).(map[string]interface{})
}
return typed
case []interface{}:
typed := make([]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v)
}
// We don't really need to tag arrays, but let's be future proof.
// (If TOML ever supports tuples, we'll need this.)
return tag("array", typed)
case time.Time:
return tag("datetime", orig.Format("2006-01-02T15:04:05Z"))
case bool:
return tag("bool", fmt.Sprintf("%v", orig))
case int64:
return tag("integer", fmt.Sprintf("%d", orig))
case float64:
return tag("float", fmt.Sprintf("%v", orig))
case string:
return tag("string", orig)
}
panic(fmt.Sprintf("Unknown type: %T", tomlData))
}
func tag(typeName string, data interface{}) map[string]interface{} {
return map[string]interface{}{
"type": typeName,
"value": data,
}
}

View File

@@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@@ -0,0 +1,14 @@
# Implements the TOML test suite interface for TOML encoders
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for the
[TOML encoder](https://github.com/BurntSushi/toml).
In particular, it maps JSON data on `stdin` to a TOML format on `stdout`.
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View File

@@ -0,0 +1,131 @@
// Command toml-test-encoder satisfies the toml-test interface for testing
// TOML encoders. Namely, it accepts JSON on stdin and outputs TOML on stdout.
package main
import (
"encoding/json"
"flag"
"log"
"os"
"path"
"strconv"
"time"
"github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < json-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if err := json.NewDecoder(os.Stdin).Decode(&tmp); err != nil {
log.Fatalf("Error decoding JSON: %s", err)
}
tomlData := translate(tmp)
if err := toml.NewEncoder(os.Stdout).Encode(tomlData); err != nil {
log.Fatalf("Error encoding TOML: %s", err)
}
}
func translate(typedJson interface{}) interface{} {
switch v := typedJson.(type) {
case map[string]interface{}:
if len(v) == 2 && in("type", v) && in("value", v) {
return untag(v)
}
m := make(map[string]interface{}, len(v))
for k, v2 := range v {
m[k] = translate(v2)
}
return m
case []interface{}:
tabArray := make([]map[string]interface{}, len(v))
for i := range v {
if m, ok := translate(v[i]).(map[string]interface{}); ok {
tabArray[i] = m
} else {
log.Fatalf("JSON arrays may only contain objects. This " +
"corresponds to only tables being allowed in " +
"TOML table arrays.")
}
}
return tabArray
}
log.Fatalf("Unrecognized JSON format '%T'.", typedJson)
panic("unreachable")
}
func untag(typed map[string]interface{}) interface{} {
t := typed["type"].(string)
v := typed["value"]
switch t {
case "string":
return v.(string)
case "integer":
v := v.(string)
n, err := strconv.Atoi(v)
if err != nil {
log.Fatalf("Could not parse '%s' as integer: %s", v, err)
}
return n
case "float":
v := v.(string)
f, err := strconv.ParseFloat(v, 64)
if err != nil {
log.Fatalf("Could not parse '%s' as float64: %s", v, err)
}
return f
case "datetime":
v := v.(string)
t, err := time.Parse("2006-01-02T15:04:05Z", v)
if err != nil {
log.Fatalf("Could not parse '%s' as a datetime: %s", v, err)
}
return t
case "bool":
v := v.(string)
switch v {
case "true":
return true
case "false":
return false
}
log.Fatalf("Could not parse '%s' as a boolean.", v)
case "array":
v := v.([]interface{})
array := make([]interface{}, len(v))
for i := range v {
if m, ok := v[i].(map[string]interface{}); ok {
array[i] = untag(m)
} else {
log.Fatalf("Arrays may only contain other arrays or "+
"primitive values, but found a '%T'.", m)
}
}
return array
}
log.Fatalf("Unrecognized tag type '%s'.", t)
panic("unreachable")
}
func in(key string, m map[string]interface{}) bool {
_, ok := m[key]
return ok
}

View File

@@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@@ -0,0 +1,22 @@
# TOML Validator
If Go is installed, it's simple to try it out:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
You can see the types of every key in a TOML file with:
```bash
tomlv -types some-toml-file.toml
```
At the moment, only one error message is reported at a time. Error messages
include line numbers. No output means that the files given are valid TOML, or
there is a bug in `tomlv`.
Compatible with TOML version
[v0.1.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.1.0.md)

View File

@@ -0,0 +1,61 @@
// Command tomlv validates TOML documents and prints each key's type.
package main
import (
"flag"
"fmt"
"log"
"os"
"path"
"strings"
"text/tabwriter"
"github.com/BurntSushi/toml"
)
var (
flagTypes = false
)
func init() {
log.SetFlags(0)
flag.BoolVar(&flagTypes, "types", flagTypes,
"When set, the types of every defined key will be shown.")
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s toml-file [ toml-file ... ]\n",
path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() < 1 {
flag.Usage()
}
for _, f := range flag.Args() {
var tmp interface{}
md, err := toml.DecodeFile(f, &tmp)
if err != nil {
log.Fatalf("Error in '%s': %s", f, err)
}
if flagTypes {
printTypes(md)
}
}
}
func printTypes(md toml.MetaData) {
tabw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
for _, key := range md.Keys() {
fmt.Fprintf(tabw, "%s%s\t%s\n",
strings.Repeat(" ", len(key)-1), key, md.Type(key...))
}
tabw.Flush()
}

View File

@@ -0,0 +1,492 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
var e = fmt.Errorf
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, rvalue(v))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return mismatch(rv, "map", mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s",
rv.Type().String(), f.name, err)
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen))
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanAddr() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
}

View File

@@ -0,0 +1,122 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
}
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

View File

@@ -0,0 +1,950 @@
package toml
import (
"fmt"
"log"
"reflect"
"testing"
"time"
)
func init() {
log.SetFlags(0)
}
func TestDecodeSimple(t *testing.T) {
var testSimple = `
age = 250
andrew = "gallant"
kait = "brady"
now = 1987-07-05T05:45:00Z
yesOrNo = true
pi = 3.14
colors = [
["red", "green", "blue"],
["cyan", "magenta", "yellow", "black"],
]
[My.Cats]
plato = "cat 1"
cauchy = "cat 2"
`
type cats struct {
Plato string
Cauchy string
}
type simple struct {
Age int
Colors [][]string
Pi float64
YesOrNo bool
Now time.Time
Andrew string
Kait string
My map[string]cats
}
var val simple
_, err := Decode(testSimple, &val)
if err != nil {
t.Fatal(err)
}
now, err := time.Parse("2006-01-02T15:04:05", "1987-07-05T05:45:00")
if err != nil {
panic(err)
}
var answer = simple{
Age: 250,
Andrew: "gallant",
Kait: "brady",
Now: now,
YesOrNo: true,
Pi: 3.14,
Colors: [][]string{
{"red", "green", "blue"},
{"cyan", "magenta", "yellow", "black"},
},
My: map[string]cats{
"Cats": cats{Plato: "cat 1", Cauchy: "cat 2"},
},
}
if !reflect.DeepEqual(val, answer) {
t.Fatalf("Expected\n-----\n%#v\n-----\nbut got\n-----\n%#v\n",
answer, val)
}
}
func TestDecodeEmbedded(t *testing.T) {
type Dog struct{ Name string }
type Age int
tests := map[string]struct {
input string
decodeInto interface{}
wantDecoded interface{}
}{
"embedded struct": {
input: `Name = "milton"`,
decodeInto: &struct{ Dog }{},
wantDecoded: &struct{ Dog }{Dog{"milton"}},
},
"embedded non-nil pointer to struct": {
input: `Name = "milton"`,
decodeInto: &struct{ *Dog }{},
wantDecoded: &struct{ *Dog }{&Dog{"milton"}},
},
"embedded nil pointer to struct": {
input: ``,
decodeInto: &struct{ *Dog }{},
wantDecoded: &struct{ *Dog }{nil},
},
"embedded int": {
input: `Age = -5`,
decodeInto: &struct{ Age }{},
wantDecoded: &struct{ Age }{-5},
},
}
for label, test := range tests {
_, err := Decode(test.input, test.decodeInto)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(test.wantDecoded, test.decodeInto) {
t.Errorf("%s: want decoded == %+v, got %+v",
label, test.wantDecoded, test.decodeInto)
}
}
}
func TestTableArrays(t *testing.T) {
var tomlTableArrays = `
[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
type Song struct {
Name string
}
type Album struct {
Name string
Songs []Song
}
type Music struct {
Albums []Album
}
expected := Music{[]Album{
{"Born to Run", []Song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA", []Song{{"Glory Days"}, {"Dancing in the Dark"}}},
}}
var got Music
if _, err := Decode(tomlTableArrays, &got); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
// Case insensitive matching tests.
// A bit more comprehensive than needed given the current implementation,
// but implementations change.
// Probably still missing demonstrations of some ugly corner cases regarding
// case insensitive matching and multiple fields.
func TestCase(t *testing.T) {
var caseToml = `
tOpString = "string"
tOpInt = 1
tOpFloat = 1.1
tOpBool = true
tOpdate = 2006-01-02T15:04:05Z
tOparray = [ "array" ]
Match = "i should be in Match only"
MatcH = "i should be in MatcH only"
once = "just once"
[nEst.eD]
nEstedString = "another string"
`
type InsensitiveEd struct {
NestedString string
}
type InsensitiveNest struct {
Ed InsensitiveEd
}
type Insensitive struct {
TopString string
TopInt int
TopFloat float64
TopBool bool
TopDate time.Time
TopArray []string
Match string
MatcH string
Once string
OncE string
Nest InsensitiveNest
}
tme, err := time.Parse(time.RFC3339, time.RFC3339[:len(time.RFC3339)-5])
if err != nil {
panic(err)
}
expected := Insensitive{
TopString: "string",
TopInt: 1,
TopFloat: 1.1,
TopBool: true,
TopDate: tme,
TopArray: []string{"array"},
MatcH: "i should be in MatcH only",
Match: "i should be in Match only",
Once: "just once",
OncE: "",
Nest: InsensitiveNest{
Ed: InsensitiveEd{NestedString: "another string"},
},
}
var got Insensitive
if _, err := Decode(caseToml, &got); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
func TestPointers(t *testing.T) {
type Object struct {
Type string
Description string
}
type Dict struct {
NamedObject map[string]*Object
BaseObject *Object
Strptr *string
Strptrs []*string
}
s1, s2, s3 := "blah", "abc", "def"
expected := &Dict{
Strptr: &s1,
Strptrs: []*string{&s2, &s3},
NamedObject: map[string]*Object{
"foo": {"FOO", "fooooo!!!"},
"bar": {"BAR", "ba-ba-ba-ba-barrrr!!!"},
},
BaseObject: &Object{"BASE", "da base"},
}
ex1 := `
Strptr = "blah"
Strptrs = ["abc", "def"]
[NamedObject.foo]
Type = "FOO"
Description = "fooooo!!!"
[NamedObject.bar]
Type = "BAR"
Description = "ba-ba-ba-ba-barrrr!!!"
[BaseObject]
Type = "BASE"
Description = "da base"
`
dict := new(Dict)
_, err := Decode(ex1, dict)
if err != nil {
t.Errorf("Decode error: %v", err)
}
if !reflect.DeepEqual(expected, dict) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, dict)
}
}
type sphere struct {
Center [3]float64
Radius float64
}
func TestDecodeSimpleArray(t *testing.T) {
var s1 sphere
if _, err := Decode(`center = [0.0, 1.5, 0.0]`, &s1); err != nil {
t.Fatal(err)
}
}
func TestDecodeArrayWrongSize(t *testing.T) {
var s1 sphere
if _, err := Decode(`center = [0.1, 2.3]`, &s1); err == nil {
t.Fatal("Expected array type mismatch error")
}
}
func TestDecodeLargeIntoSmallInt(t *testing.T) {
type table struct {
Value int8
}
var tab table
if _, err := Decode(`value = 500`, &tab); err == nil {
t.Fatal("Expected integer out-of-bounds error.")
}
}
func TestDecodeSizedInts(t *testing.T) {
type table struct {
U8 uint8
U16 uint16
U32 uint32
U64 uint64
U uint
I8 int8
I16 int16
I32 int32
I64 int64
I int
}
answer := table{1, 1, 1, 1, 1, -1, -1, -1, -1, -1}
toml := `
u8 = 1
u16 = 1
u32 = 1
u64 = 1
u = 1
i8 = -1
i16 = -1
i32 = -1
i64 = -1
i = -1
`
var tab table
if _, err := Decode(toml, &tab); err != nil {
t.Fatal(err.Error())
}
if answer != tab {
t.Fatalf("Expected %#v but got %#v", answer, tab)
}
}
func TestUnmarshaler(t *testing.T) {
var tomlBlob = `
[dishes.hamboogie]
name = "Hamboogie with fries"
price = 10.99
[[dishes.hamboogie.ingredients]]
name = "Bread Bun"
[[dishes.hamboogie.ingredients]]
name = "Lettuce"
[[dishes.hamboogie.ingredients]]
name = "Real Beef Patty"
[[dishes.hamboogie.ingredients]]
name = "Tomato"
[dishes.eggsalad]
name = "Egg Salad with rice"
price = 3.99
[[dishes.eggsalad.ingredients]]
name = "Egg"
[[dishes.eggsalad.ingredients]]
name = "Mayo"
[[dishes.eggsalad.ingredients]]
name = "Rice"
`
m := &menu{}
if _, err := Decode(tomlBlob, m); err != nil {
log.Fatal(err)
}
if len(m.Dishes) != 2 {
t.Log("two dishes should be loaded with UnmarshalTOML()")
t.Errorf("expected %d but got %d", 2, len(m.Dishes))
}
eggSalad := m.Dishes["eggsalad"]
if _, ok := interface{}(eggSalad).(dish); !ok {
t.Errorf("expected a dish")
}
if eggSalad.Name != "Egg Salad with rice" {
t.Errorf("expected the dish to be named 'Egg Salad with rice'")
}
if len(eggSalad.Ingredients) != 3 {
t.Log("dish should be loaded with UnmarshalTOML()")
t.Errorf("expected %d but got %d", 3, len(eggSalad.Ingredients))
}
found := false
for _, i := range eggSalad.Ingredients {
if i.Name == "Rice" {
found = true
break
}
}
if !found {
t.Error("Rice was not loaded in UnmarshalTOML()")
}
// test on a value - must be passed as *
o := menu{}
if _, err := Decode(tomlBlob, &o); err != nil {
log.Fatal(err)
}
}
type menu struct {
Dishes map[string]dish
}
func (m *menu) UnmarshalTOML(p interface{}) error {
m.Dishes = make(map[string]dish)
data, _ := p.(map[string]interface{})
dishes := data["dishes"].(map[string]interface{})
for n, v := range dishes {
if d, ok := v.(map[string]interface{}); ok {
nd := dish{}
nd.UnmarshalTOML(d)
m.Dishes[n] = nd
} else {
return fmt.Errorf("not a dish")
}
}
return nil
}
type dish struct {
Name string
Price float32
Ingredients []ingredient
}
func (d *dish) UnmarshalTOML(p interface{}) error {
data, _ := p.(map[string]interface{})
d.Name, _ = data["name"].(string)
d.Price, _ = data["price"].(float32)
ingredients, _ := data["ingredients"].([]map[string]interface{})
for _, e := range ingredients {
n, _ := interface{}(e).(map[string]interface{})
name, _ := n["name"].(string)
i := ingredient{name}
d.Ingredients = append(d.Ingredients, i)
}
return nil
}
type ingredient struct {
Name string
}
func ExampleMetaData_PrimitiveDecode() {
var md MetaData
var err error
var tomlBlob = `
ranking = ["Springsteen", "J Geils"]
[bands.Springsteen]
started = 1973
albums = ["Greetings", "WIESS", "Born to Run", "Darkness"]
[bands."J Geils"]
started = 1970
albums = ["The J. Geils Band", "Full House", "Blow Your Face Out"]
`
type band struct {
Started int
Albums []string
}
type classics struct {
Ranking []string
Bands map[string]Primitive
}
// Do the initial decode. Reflection is delayed on Primitive values.
var music classics
if md, err = Decode(tomlBlob, &music); err != nil {
log.Fatal(err)
}
// MetaData still includes information on Primitive values.
fmt.Printf("Is `bands.Springsteen` defined? %v\n",
md.IsDefined("bands", "Springsteen"))
// Decode primitive data into Go values.
for _, artist := range music.Ranking {
// A band is a primitive value, so we need to decode it to get a
// real `band` value.
primValue := music.Bands[artist]
var aBand band
if err = md.PrimitiveDecode(primValue, &aBand); err != nil {
log.Fatal(err)
}
fmt.Printf("%s started in %d.\n", artist, aBand.Started)
}
// Check to see if there were any fields left undecoded.
// Note that this won't be empty before decoding the Primitive value!
fmt.Printf("Undecoded: %q\n", md.Undecoded())
// Output:
// Is `bands.Springsteen` defined? true
// Springsteen started in 1973.
// J Geils started in 1970.
// Undecoded: []
}
func ExampleDecode() {
var tomlBlob = `
# Some comments.
[alpha]
ip = "10.0.0.1"
[alpha.config]
Ports = [ 8001, 8002 ]
Location = "Toronto"
Created = 1987-07-05T05:45:00Z
[beta]
ip = "10.0.0.2"
[beta.config]
Ports = [ 9001, 9002 ]
Location = "New Jersey"
Created = 1887-01-05T05:55:00Z
`
type serverConfig struct {
Ports []int
Location string
Created time.Time
}
type server struct {
IP string `toml:"ip"`
Config serverConfig `toml:"config"`
}
type servers map[string]server
var config servers
if _, err := Decode(tomlBlob, &config); err != nil {
log.Fatal(err)
}
for _, name := range []string{"alpha", "beta"} {
s := config[name]
fmt.Printf("Server: %s (ip: %s) in %s created on %s\n",
name, s.IP, s.Config.Location,
s.Config.Created.Format("2006-01-02"))
fmt.Printf("Ports: %v\n", s.Config.Ports)
}
// Output:
// Server: alpha (ip: 10.0.0.1) in Toronto created on 1987-07-05
// Ports: [8001 8002]
// Server: beta (ip: 10.0.0.2) in New Jersey created on 1887-01-05
// Ports: [9001 9002]
}
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
// Example Unmarshaler shows how to decode TOML strings into your own
// custom data type.
func Example_unmarshaler() {
blob := `
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
`
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
// Code to implement the TextUnmarshaler interface for `duration`:
//
// type duration struct {
// time.Duration
// }
//
// func (d *duration) UnmarshalText(text []byte) error {
// var err error
// d.Duration, err = time.ParseDuration(string(text))
// return err
// }
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
// Output:
// Thunder Road (4m49s)
// Stairway to Heaven (8m3s)
}
// Example StrictDecoding shows how to detect whether there are keys in the
// TOML document that weren't decoded into the value given. This is useful
// for returning an error to the user if they've included extraneous fields
// in their configuration.
func Example_strictDecoding() {
var blob = `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
type config struct {
Key1 string
Key3 string
}
var conf config
md, err := Decode(blob, &conf)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Undecoded keys: %q\n", md.Undecoded())
// Output:
// Undecoded keys: ["key2"]
}
// Example UnmarshalTOML shows how to implement a struct type that knows how to
// unmarshal itself. The struct must take full responsibility for mapping the
// values passed into the struct. The method may be used with interfaces in a
// struct in cases where the actual type is not known until the data is
// examined.
func Example_unmarshalTOML() {
var blob = `
[[parts]]
type = "valve"
id = "valve-1"
size = 1.2
rating = 4
[[parts]]
type = "valve"
id = "valve-2"
size = 2.1
rating = 5
[[parts]]
type = "pipe"
id = "pipe-1"
length = 2.1
diameter = 12
[[parts]]
type = "cable"
id = "cable-1"
length = 12
rating = 3.1
`
o := &order{}
err := Unmarshal([]byte(blob), o)
if err != nil {
log.Fatal(err)
}
fmt.Println(len(o.parts))
for _, part := range o.parts {
fmt.Println(part.Name())
}
// Code to implement UmarshalJSON.
// type order struct {
// // NOTE `order.parts` is a private slice of type `part` which is an
// // interface and may only be loaded from toml using the
// // UnmarshalTOML() method of the Umarshaler interface.
// parts parts
// }
// func (o *order) UnmarshalTOML(data interface{}) error {
// // NOTE the example below contains detailed type casting to show how
// // the 'data' is retrieved. In operational use, a type cast wrapper
// // may be prefered e.g.
// //
// // func AsMap(v interface{}) (map[string]interface{}, error) {
// // return v.(map[string]interface{})
// // }
// //
// // resulting in:
// // d, _ := AsMap(data)
// //
// d, _ := data.(map[string]interface{})
// parts, _ := d["parts"].([]map[string]interface{})
// for _, p := range parts {
// typ, _ := p["type"].(string)
// id, _ := p["id"].(string)
// // detect the type of part and handle each case
// switch p["type"] {
// case "valve":
// size := float32(p["size"].(float64))
// rating := int(p["rating"].(int64))
// valve := &valve{
// Type: typ,
// ID: id,
// Size: size,
// Rating: rating,
// }
// o.parts = append(o.parts, valve)
// case "pipe":
// length := float32(p["length"].(float64))
// diameter := int(p["diameter"].(int64))
// pipe := &pipe{
// Type: typ,
// ID: id,
// Length: length,
// Diameter: diameter,
// }
// o.parts = append(o.parts, pipe)
// case "cable":
// length := int(p["length"].(int64))
// rating := float32(p["rating"].(float64))
// cable := &cable{
// Type: typ,
// ID: id,
// Length: length,
// Rating: rating,
// }
// o.parts = append(o.parts, cable)
// }
// }
// return nil
// }
// type parts []part
// type part interface {
// Name() string
// }
// type valve struct {
// Type string
// ID string
// Size float32
// Rating int
// }
// func (v *valve) Name() string {
// return fmt.Sprintf("VALVE: %s", v.ID)
// }
// type pipe struct {
// Type string
// ID string
// Length float32
// Diameter int
// }
// func (p *pipe) Name() string {
// return fmt.Sprintf("PIPE: %s", p.ID)
// }
// type cable struct {
// Type string
// ID string
// Length int
// Rating float32
// }
// func (c *cable) Name() string {
// return fmt.Sprintf("CABLE: %s", c.ID)
// }
// Output:
// 4
// VALVE: valve-1
// VALVE: valve-2
// PIPE: pipe-1
// CABLE: cable-1
}
type order struct {
// NOTE `order.parts` is a private slice of type `part` which is an
// interface and may only be loaded from toml using the UnmarshalTOML()
// method of the Umarshaler interface.
parts parts
}
func (o *order) UnmarshalTOML(data interface{}) error {
// NOTE the example below contains detailed type casting to show how
// the 'data' is retrieved. In operational use, a type cast wrapper
// may be prefered e.g.
//
// func AsMap(v interface{}) (map[string]interface{}, error) {
// return v.(map[string]interface{})
// }
//
// resulting in:
// d, _ := AsMap(data)
//
d, _ := data.(map[string]interface{})
parts, _ := d["parts"].([]map[string]interface{})
for _, p := range parts {
typ, _ := p["type"].(string)
id, _ := p["id"].(string)
// detect the type of part and handle each case
switch p["type"] {
case "valve":
size := float32(p["size"].(float64))
rating := int(p["rating"].(int64))
valve := &valve{
Type: typ,
ID: id,
Size: size,
Rating: rating,
}
o.parts = append(o.parts, valve)
case "pipe":
length := float32(p["length"].(float64))
diameter := int(p["diameter"].(int64))
pipe := &pipe{
Type: typ,
ID: id,
Length: length,
Diameter: diameter,
}
o.parts = append(o.parts, pipe)
case "cable":
length := int(p["length"].(int64))
rating := float32(p["rating"].(float64))
cable := &cable{
Type: typ,
ID: id,
Length: length,
Rating: rating,
}
o.parts = append(o.parts, cable)
}
}
return nil
}
type parts []part
type part interface {
Name() string
}
type valve struct {
Type string
ID string
Size float32
Rating int
}
func (v *valve) Name() string {
return fmt.Sprintf("VALVE: %s", v.ID)
}
type pipe struct {
Type string
ID string
Length float32
Diameter int
}
func (p *pipe) Name() string {
return fmt.Sprintf("PIPE: %s", p.ID)
}
type cable struct {
Type string
ID string
Length int
Rating float32
}
func (c *cable) Name() string {
return fmt.Sprintf("CABLE: %s", c.ID)
}

View File

@@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/mojombo/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View File

@@ -0,0 +1,551 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types")
errArrayNilElement = errors.New(
"can't encode array with nil element")
errNonString = errors.New(
"can't encode a map with non-string key type")
errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"TOML array element can't contain a table")
errNoKey = errors.New(
"top-level values must be a Go map or struct")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("Unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("Unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra new line between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexporded fields
if f.PkgPath != "" {
continue
}
frv := rv.Field(i)
if f.Anonymous {
frv := eindirect(frv)
t := frv.Type()
if t.Kind() != reflect.Struct {
encPanic(errAnonNonStruct)
}
addFields(t, frv, f.Index)
} else if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
keyName := sft.Tag.Get("toml")
if keyName == "-" {
continue
}
if keyName == "" {
keyName = sft.Name
}
keyName, opts := getOptions(keyName)
if _, ok := opts["omitempty"]; ok && isEmpty(sf) {
continue
} else if _, ok := opts["omitzero"]; ok && isZero(sf) {
continue
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
} else {
return tomlArray
}
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
func getOptions(keyName string) (string, map[string]struct{}) {
opts := make(map[string]struct{})
ss := strings.Split(keyName, ",")
name := ss[0]
if len(ss) > 1 {
for _, opt := range ss {
opts[opt] = struct{}{}
}
}
return name, opts
}
func isZero(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rv.Int() == 0 {
return true
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if rv.Uint() == 0 {
return true
}
case reflect.Float32, reflect.Float64:
if rv.Float() == 0.0 {
return true
}
}
return false
}
func isEmpty(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.String:
if len(strings.TrimSpace(rv.String())) == 0 {
return true
}
case reflect.Array, reflect.Slice, reflect.Map:
if rv.Len() == 0 {
return true
}
}
return false
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View File

@@ -0,0 +1,542 @@
package toml
import (
"bytes"
"fmt"
"log"
"net"
"testing"
"time"
)
func TestEncodeRoundTrip(t *testing.T) {
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time
Ipaddress net.IP
}
var inputs = Config{
13,
[]string{"one", "two", "three"},
3.145,
[]int{11, 2, 3, 4},
time.Now(),
net.ParseIP("192.168.59.254"),
}
var firstBuffer bytes.Buffer
e := NewEncoder(&firstBuffer)
err := e.Encode(inputs)
if err != nil {
t.Fatal(err)
}
var outputs Config
if _, err := Decode(firstBuffer.String(), &outputs); err != nil {
log.Printf("Could not decode:\n-----\n%s\n-----\n",
firstBuffer.String())
t.Fatal(err)
}
// could test each value individually, but I'm lazy
var secondBuffer bytes.Buffer
e2 := NewEncoder(&secondBuffer)
err = e2.Encode(outputs)
if err != nil {
t.Fatal(err)
}
if firstBuffer.String() != secondBuffer.String() {
t.Error(
firstBuffer.String(),
"\n\n is not identical to\n\n",
secondBuffer.String())
}
}
// XXX(burntsushi)
// I think these tests probably should be removed. They are good, but they
// ought to be obsolete by toml-test.
func TestEncode(t *testing.T) {
type Embedded struct {
Int int `toml:"_int"`
}
type NonStruct int
date := time.Date(2014, 5, 11, 20, 30, 40, 0, time.FixedZone("IST", 3600))
dateStr := "2014-05-11T19:30:40Z"
tests := map[string]struct {
input interface{}
wantOutput string
wantError error
}{
"bool field": {
input: struct {
BoolTrue bool
BoolFalse bool
}{true, false},
wantOutput: "BoolTrue = true\nBoolFalse = false\n",
},
"int fields": {
input: struct {
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
}{1, 2, 3, 4, 5},
wantOutput: "Int = 1\nInt8 = 2\nInt16 = 3\nInt32 = 4\nInt64 = 5\n",
},
"uint fields": {
input: struct {
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
}{1, 2, 3, 4, 5},
wantOutput: "Uint = 1\nUint8 = 2\nUint16 = 3\nUint32 = 4" +
"\nUint64 = 5\n",
},
"float fields": {
input: struct {
Float32 float32
Float64 float64
}{1.5, 2.5},
wantOutput: "Float32 = 1.5\nFloat64 = 2.5\n",
},
"string field": {
input: struct{ String string }{"foo"},
wantOutput: "String = \"foo\"\n",
},
"string field and unexported field": {
input: struct {
String string
unexported int
}{"foo", 0},
wantOutput: "String = \"foo\"\n",
},
"datetime field in UTC": {
input: struct{ Date time.Time }{date},
wantOutput: fmt.Sprintf("Date = %s\n", dateStr),
},
"datetime field as primitive": {
// Using a map here to fail if isStructOrMap() returns true for
// time.Time.
input: map[string]interface{}{
"Date": date,
"Int": 1,
},
wantOutput: fmt.Sprintf("Date = %s\nInt = 1\n", dateStr),
},
"array fields": {
input: struct {
IntArray0 [0]int
IntArray3 [3]int
}{[0]int{}, [3]int{1, 2, 3}},
wantOutput: "IntArray0 = []\nIntArray3 = [1, 2, 3]\n",
},
"slice fields": {
input: struct{ IntSliceNil, IntSlice0, IntSlice3 []int }{
nil, []int{}, []int{1, 2, 3},
},
wantOutput: "IntSlice0 = []\nIntSlice3 = [1, 2, 3]\n",
},
"datetime slices": {
input: struct{ DatetimeSlice []time.Time }{
[]time.Time{date, date},
},
wantOutput: fmt.Sprintf("DatetimeSlice = [%s, %s]\n",
dateStr, dateStr),
},
"nested arrays and slices": {
input: struct {
SliceOfArrays [][2]int
ArrayOfSlices [2][]int
SliceOfArraysOfSlices [][2][]int
ArrayOfSlicesOfArrays [2][][2]int
SliceOfMixedArrays [][2]interface{}
ArrayOfMixedSlices [2][]interface{}
}{
[][2]int{{1, 2}, {3, 4}},
[2][]int{{1, 2}, {3, 4}},
[][2][]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[2][][2]int{
{
{1, 2}, {3, 4},
},
{
{5, 6}, {7, 8},
},
},
[][2]interface{}{
{1, 2}, {"a", "b"},
},
[2][]interface{}{
{1, 2}, {"a", "b"},
},
},
wantOutput: `SliceOfArrays = [[1, 2], [3, 4]]
ArrayOfSlices = [[1, 2], [3, 4]]
SliceOfArraysOfSlices = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
ArrayOfSlicesOfArrays = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
SliceOfMixedArrays = [[1, 2], ["a", "b"]]
ArrayOfMixedSlices = [[1, 2], ["a", "b"]]
`,
},
"empty slice": {
input: struct{ Empty []interface{} }{[]interface{}{}},
wantOutput: "Empty = []\n",
},
"(error) slice with element type mismatch (string and integer)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, "a"}},
wantError: errArrayMixedElementTypes,
},
"(error) slice with element type mismatch (integer and float)": {
input: struct{ Mixed []interface{} }{[]interface{}{1, 2.5}},
wantError: errArrayMixedElementTypes,
},
"slice with elems of differing Go types, same TOML types": {
input: struct {
MixedInts []interface{}
MixedFloats []interface{}
}{
[]interface{}{
int(1), int8(2), int16(3), int32(4), int64(5),
uint(1), uint8(2), uint16(3), uint32(4), uint64(5),
},
[]interface{}{float32(1.5), float64(2.5)},
},
wantOutput: "MixedInts = [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]\n" +
"MixedFloats = [1.5, 2.5]\n",
},
"(error) slice w/ element type mismatch (one is nested array)": {
input: struct{ Mixed []interface{} }{
[]interface{}{1, []interface{}{2}},
},
wantError: errArrayMixedElementTypes,
},
"(error) slice with 1 nil element": {
input: struct{ NilElement1 []interface{} }{[]interface{}{nil}},
wantError: errArrayNilElement,
},
"(error) slice with 1 nil element (and other non-nil elements)": {
input: struct{ NilElement []interface{} }{
[]interface{}{1, nil},
},
wantError: errArrayNilElement,
},
"simple map": {
input: map[string]int{"a": 1, "b": 2},
wantOutput: "a = 1\nb = 2\n",
},
"map with interface{} value type": {
input: map[string]interface{}{"a": 1, "b": "c"},
wantOutput: "a = 1\nb = \"c\"\n",
},
"map with interface{} value type, some of which are structs": {
input: map[string]interface{}{
"a": struct{ Int int }{2},
"b": 1,
},
wantOutput: "b = 1\n\n[a]\n Int = 2\n",
},
"nested map": {
input: map[string]map[string]int{
"a": {"b": 1},
"c": {"d": 2},
},
wantOutput: "[a]\n b = 1\n\n[c]\n d = 2\n",
},
"nested struct": {
input: struct{ Struct struct{ Int int } }{
struct{ Int int }{1},
},
wantOutput: "[Struct]\n Int = 1\n",
},
"nested struct and non-struct field": {
input: struct {
Struct struct{ Int int }
Bool bool
}{struct{ Int int }{1}, true},
wantOutput: "Bool = true\n\n[Struct]\n Int = 1\n",
},
"2 nested structs": {
input: struct{ Struct1, Struct2 struct{ Int int } }{
struct{ Int int }{1}, struct{ Int int }{2},
},
wantOutput: "[Struct1]\n Int = 1\n\n[Struct2]\n Int = 2\n",
},
"deeply nested structs": {
input: struct {
Struct1, Struct2 struct{ Struct3 *struct{ Int int } }
}{
struct{ Struct3 *struct{ Int int } }{&struct{ Int int }{1}},
struct{ Struct3 *struct{ Int int } }{nil},
},
wantOutput: "[Struct1]\n [Struct1.Struct3]\n Int = 1" +
"\n\n[Struct2]\n",
},
"nested struct with nil struct elem": {
input: struct {
Struct struct{ Inner *struct{ Int int } }
}{
struct{ Inner *struct{ Int int } }{nil},
},
wantOutput: "[Struct]\n",
},
"nested struct with no fields": {
input: struct {
Struct struct{ Inner struct{} }
}{
struct{ Inner struct{} }{struct{}{}},
},
wantOutput: "[Struct]\n [Struct.Inner]\n",
},
"struct with tags": {
input: struct {
Struct struct {
Int int `toml:"_int"`
} `toml:"_struct"`
Bool bool `toml:"_bool"`
}{
struct {
Int int `toml:"_int"`
}{1}, true,
},
wantOutput: "_bool = true\n\n[_struct]\n _int = 1\n",
},
"embedded struct": {
input: struct{ Embedded }{Embedded{1}},
wantOutput: "_int = 1\n",
},
"embedded *struct": {
input: struct{ *Embedded }{&Embedded{1}},
wantOutput: "_int = 1\n",
},
"nested embedded struct": {
input: struct {
Struct struct{ Embedded } `toml:"_struct"`
}{struct{ Embedded }{Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"nested embedded *struct": {
input: struct {
Struct struct{ *Embedded } `toml:"_struct"`
}{struct{ *Embedded }{&Embedded{1}}},
wantOutput: "[_struct]\n _int = 1\n",
},
"array of tables": {
input: struct {
Structs []*struct{ Int int } `toml:"struct"`
}{
[]*struct{ Int int }{{1}, {3}},
},
wantOutput: "[[struct]]\n Int = 1\n\n[[struct]]\n Int = 3\n",
},
"array of tables order": {
input: map[string]interface{}{
"map": map[string]interface{}{
"zero": 5,
"arr": []map[string]int{
map[string]int{
"friend": 5,
},
},
},
},
wantOutput: "[map]\n zero = 5\n\n [[map.arr]]\n friend = 5\n",
},
"(error) top-level slice": {
input: []struct{ Int int }{{1}, {2}, {3}},
wantError: errNoKey,
},
"(error) slice of slice": {
input: struct {
Slices [][]struct{ Int int }
}{
[][]struct{ Int int }{{{1}}, {{2}}, {{3}}},
},
wantError: errArrayNoTable,
},
"(error) map no string key": {
input: map[int]string{1: ""},
wantError: errNonString,
},
"(error) anonymous non-struct": {
input: struct{ NonStruct }{5},
wantError: errAnonNonStruct,
},
"(error) empty key name": {
input: map[string]int{"": 1},
wantError: errAnything,
},
"(error) empty map name": {
input: map[string]interface{}{
"": map[string]int{"v": 1},
},
wantError: errAnything,
},
}
for label, test := range tests {
encodeExpected(t, label, test.input, test.wantOutput, test.wantError)
}
}
func TestEncodeNestedTableArrays(t *testing.T) {
type song struct {
Name string `toml:"name"`
}
type album struct {
Name string `toml:"name"`
Songs []song `toml:"songs"`
}
type springsteen struct {
Albums []album `toml:"albums"`
}
value := springsteen{
[]album{
{"Born to Run",
[]song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA",
[]song{{"Glory Days"}, {"Dancing in the Dark"}}},
},
}
expected := `[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
encodeExpected(t, "nested table arrays", value, expected, nil)
}
func TestEncodeArrayHashWithNormalHashOrder(t *testing.T) {
type Alpha struct {
V int
}
type Beta struct {
V int
}
type Conf struct {
V int
A Alpha
B []Beta
}
val := Conf{
V: 1,
A: Alpha{2},
B: []Beta{{3}},
}
expected := "V = 1\n\n[A]\n V = 2\n\n[[B]]\n V = 3\n"
encodeExpected(t, "array hash with normal hash order", val, expected, nil)
}
func TestEncodeWithOmitEmpty(t *testing.T) {
type simple struct {
User string `toml:"user"`
Pass string `toml:"password,omitempty"`
}
value := simple{"Testing", ""}
expected := fmt.Sprintf("user = %q\n", value.User)
encodeExpected(t, "simple with omitempty, is empty", value, expected, nil)
value.Pass = "some password"
expected = fmt.Sprintf("user = %q\npassword = %q\n", value.User, value.Pass)
encodeExpected(t, "simple with omitempty, not empty", value, expected, nil)
}
func TestEncodeWithOmitZero(t *testing.T) {
type simple struct {
Number int `toml:"number,omitzero"`
Real float64 `toml:"real,omitzero"`
Unsigned uint `toml:"unsigned,omitzero"`
}
value := simple{0, 0.0, uint(0)}
expected := ""
encodeExpected(t, "simple with omitzero, all zero", value, expected, nil)
value.Number = 10
value.Real = 20
value.Unsigned = 5
expected = `number = 10
real = 20.0
unsigned = 5
`
encodeExpected(t, "simple with omitzero, non-zero", value, expected, nil)
}
func encodeExpected(
t *testing.T, label string, val interface{}, wantStr string, wantErr error,
) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(val)
if err != wantErr {
if wantErr != nil {
if wantErr == errAnything && err != nil {
return
}
t.Errorf("%s: want Encode error %v, got %v", label, wantErr, err)
} else {
t.Errorf("%s: Encode failed: %s", label, err)
}
}
if err != nil {
return
}
if got := buf.String(); wantStr != got {
t.Errorf("%s: want\n-----\n%q\n-----\nbut got\n-----\n%q\n-----\n",
label, wantStr, got)
}
}
func ExampleEncoder_Encode() {
date, _ := time.Parse(time.RFC822, "14 Mar 10 18:00 UTC")
var config = map[string]interface{}{
"date": date,
"counts": []int{1, 1, 2, 3, 5, 8},
"hash": map[string]string{
"key1": "val1",
"key2": "val2",
},
}
buf := new(bytes.Buffer)
if err := NewEncoder(buf).Encode(config); err != nil {
log.Fatal(err)
}
fmt.Println(buf.String())
// Output:
// counts = [1, 1, 2, 3, 5, 8]
// date = 2010-03-14T18:00:00Z
//
// [hash]
// key1 = "val1"
// key2 = "val2"
}

View File

@@ -0,0 +1,19 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View File

@@ -0,0 +1,18 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

874
Godeps/_workspace/src/github.com/BurntSushi/toml/lex.go generated vendored Normal file
View File

@@ -0,0 +1,874 @@
package toml
import (
"fmt"
"strings"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input + "\n",
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.pos >= len(lx.input) {
lx.width = 0
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got %q instead.", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " +
"be empty.)")
case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " +
"be empty.)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
case isWhitespace(r):
return lexTableNameStart
default:
return lexBareTableName
}
}
// lexTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+
"instead.", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.emitTrim(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emitTrim(itemText)
return lexKeyEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("Expected key separator %q, but got %q instead.",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new
// lines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case r == rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
}
return lx.errorf("Expected value but found %q instead.", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator %q.",
arrayValTerm)
case r == arrayEnd:
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator %q or an array "+
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '\\':
return lexMultilineStringEscape
case r == stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
lx.next()
return lexMultilineString
} else {
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("Invalid escape character %q. Only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+
"\\uXXXX and \\UXXXXXXXX.", r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either a (positive) integer, float or
// datetime. It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumberOrDate
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
}
}
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that
// a negative sign has already been read, but that *no* digits have been
// consumed. lexNumberStart will move to the appropriate integer or float
// states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
return lexNumber
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been
// consumed.
func lexConst(lx *lexer, s string) stateFn {
for i := range s[1:] {
if r := lx.next(); r != rune(s[i+1]) {
return lx.errorf("Expected %q, but found %q instead.", s[:i+1],
s[:i]+string(r))
}
}
return nil
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

View File

@@ -0,0 +1,498 @@
package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if err != nil {
// See comment below for floats describing why we make a
// distinction between a bug and a user error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:len(s)]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
// BUG(burntsushi)
// I honestly don't understand how this works. I can't seem
// to find a way to make this fail. I figured this would fail on invalid
// UTF-8 characters like U+DCFF, but it doesn't.
if !utf8.ValidString(string(rune(hex))) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View File

@@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View File

@@ -0,0 +1,91 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View File

@@ -0,0 +1,241 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" { // unexported
continue
}
name := sf.Tag.Get("toml")
if name == "-" {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

View File

@@ -0,0 +1,105 @@
// Package awserr represents API error interface accessors for the SDK.
package awserr
// An Error wraps lower level errors with code, message and an original error.
// The underlying concrete error type may also satisfy other interfaces which
// can be to used to obtain more specific information about the error.
//
// Calling Error() or String() will always include the full information about
// an error based on its underlying type.
//
// Example:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Get error details
// log.Println("Error:", err.Code(), err.Message())
//
// // Prints out full error message, including original error if there was one.
// log.Println("Error:", err.Error())
//
// // Get original error
// if origErr := err.Err(); origErr != nil {
// // operate on original error.
// }
// } else {
// fmt.Println(err.Error())
// }
// }
//
type Error interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErr() error
}
// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
if e, ok := origErr.(Error); ok && e != nil {
return e
}
return newBaseError(code, message, origErr)
}
// A RequestFailure is an interface to extract request failure information from
// an Error such as the request ID of the failed request returned by a service.
// RequestFailures may not always have a requestID value if the request failed
// prior to reaching the service such as a connection error.
//
// Example:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if reqerr, ok := err.(RequestFailure); ok {
// log.Printf("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// } else {
// log.Printf("Error:", err.Error()
// }
// }
//
// Combined with awserr.Error:
//
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Generic AWS Error with Code, Message, and original error (if any)
// fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
//
// if reqErr, ok := err.(awserr.RequestFailure); ok {
// // A service error occurred
// fmt.Println(reqErr.StatusCode(), reqErr.RequestID())
// }
// } else {
// fmt.Println(err.Error())
// }
// }
//
type RequestFailure interface {
Error
// The status code of the HTTP response.
StatusCode() int
// The request ID returned by the service for a request failure. This will
// be empty if no request ID is available such as the request failed due
// to a connection error.
RequestID() string
}
// NewRequestFailure returns a new request error wrapper for the given Error
// provided.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID)
}

View File

@@ -0,0 +1,135 @@
package awserr
import "fmt"
// SprintError returns a string of the formatted error code.
//
// Both extra and origErr are optional. If they are included their lines
// will be added, but if they are not included their lines will be ignored.
func SprintError(code, message, extra string, origErr error) string {
msg := fmt.Sprintf("%s: %s", code, message)
if extra != "" {
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
}
if origErr != nil {
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
}
return msg
}
// A baseError wraps the code and message which defines an error. It also
// can be used to wrap an original error object.
//
// Should be used as the root for errors satisfying the awserr.Error. Also
// for any error which does not fit into a specific error wrapper type.
type baseError struct {
// Classification of error
code string
// Detailed information about error
message string
// Optional original error this error is based off of. Allows building
// chained errors.
origErr error
}
// newBaseError returns an error object for the code, message, and err.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the error.
//
// origErr is the error object which will be nested under the new error to be returned.
func newBaseError(code, message string, origErr error) *baseError {
return &baseError{
code: code,
message: message,
origErr: origErr,
}
}
// Error returns the string representation of the error.
//
// See ErrorWithExtra for formatting.
//
// Satisfies the error interface.
func (b baseError) Error() string {
return SprintError(b.code, b.message, "", b.origErr)
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (b baseError) String() string {
return b.Error()
}
// Code returns the short phrase depicting the classification of the error.
func (b baseError) Code() string {
return b.code
}
// Message returns the error details message.
func (b baseError) Message() string {
return b.message
}
// OrigErr returns the original error if one was set. Nil is returned if no error
// was set.
func (b baseError) OrigErr() error {
return b.origErr
}
// So that the Error interface type can be included as an anonymous field
// in the requestError struct and not conflict with the error.Error() method.
type awsError Error
// A requestError wraps a request or service error.
//
// Composed of baseError for code, message, and original error.
type requestError struct {
awsError
statusCode int
requestID string
}
// newRequestError returns a wrapped error with additional information for request
// status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
//
// Also wraps original errors via the baseError.
func newRequestError(err Error, statusCode int, requestID string) *requestError {
return &requestError{
awsError: err,
statusCode: statusCode,
requestID: requestID,
}
}
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (r requestError) String() string {
return r.Error()
}
// StatusCode returns the wrapped status code for the error
func (r requestError) StatusCode() int {
return r.statusCode
}
// RequestID returns the wrapped requestID
func (r requestError) RequestID() string {
return r.requestID
}

View File

@@ -0,0 +1,103 @@
package awsutil
import (
"io"
"reflect"
)
// Copy deeply copies a src structure to dst. Useful for copying request and
// response structures.
//
// Can copy between structs of different type, but will only copy fields which
// are assignable, and exist in both structs. Fields which are not assignable,
// or do not exist in both structs are ignored.
func Copy(dst, src interface{}) {
dstval := reflect.ValueOf(dst)
if !dstval.IsValid() {
panic("Copy dst cannot be nil")
}
rcopy(dstval, reflect.ValueOf(src), true)
}
// CopyOf returns a copy of src while also allocating the memory for dst.
// src must be a pointer type or this operation will fail.
func CopyOf(src interface{}) (dst interface{}) {
dsti := reflect.New(reflect.TypeOf(src).Elem())
dst = dsti.Interface()
rcopy(dsti, reflect.ValueOf(src), true)
return
}
// rcopy performs a recursive copy of values from the source to destination.
//
// root is used to skip certain aspects of the copy which are not valid
// for the root node of a object.
func rcopy(dst, src reflect.Value, root bool) {
if !src.IsValid() {
return
}
switch src.Kind() {
case reflect.Ptr:
if _, ok := src.Interface().(io.Reader); ok {
if dst.Kind() == reflect.Ptr && dst.Elem().CanSet() {
dst.Elem().Set(src)
} else if dst.CanSet() {
dst.Set(src)
}
} else {
e := src.Type().Elem()
if dst.CanSet() && !src.IsNil() {
dst.Set(reflect.New(e))
}
if src.Elem().IsValid() {
// Keep the current root state since the depth hasn't changed
rcopy(dst.Elem(), src.Elem(), root)
}
}
case reflect.Struct:
if !root {
dst.Set(reflect.New(src.Type()).Elem())
}
t := dst.Type()
for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Name
srcval := src.FieldByName(name)
if srcval.IsValid() {
rcopy(dst.FieldByName(name), srcval, false)
}
}
case reflect.Slice:
if src.IsNil() {
break
}
s := reflect.MakeSlice(src.Type(), src.Len(), src.Cap())
dst.Set(s)
for i := 0; i < src.Len(); i++ {
rcopy(dst.Index(i), src.Index(i), false)
}
case reflect.Map:
if src.IsNil() {
break
}
s := reflect.MakeMap(src.Type())
dst.Set(s)
for _, k := range src.MapKeys() {
v := src.MapIndex(k)
v2 := reflect.New(v.Type()).Elem()
rcopy(v2, v, false)
dst.SetMapIndex(k, v2)
}
default:
// Assign the value if possible. If its not assignable, the value would
// need to be converted and the impact of that may be unexpected, or is
// not compatible with the dst type.
if src.Type().AssignableTo(dst.Type()) {
dst.Set(src)
}
}
}

View File

@@ -0,0 +1,201 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy(t *testing.T) {
type Foo struct {
A int
B []*string
C map[string]*int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
f2.B[0] = &str3
f2.C["B"] = &int3
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
B []string
C map[string]string
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
}
func TestCopyDifferentStructs(t *testing.T) {
type SrcFoo struct {
A int
B []*string
C map[string]*int
SrcUnique string
SameNameDiffType int
}
type DstFoo struct {
A int
B []*string
C map[string]*int
DstUnique int
SameNameDiffType string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &SrcFoo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
SrcUnique: "unique",
SameNameDiffType: 1,
}
// Do the copy
var f2 DstFoo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}

View File

@@ -0,0 +1,187 @@
package awsutil
import (
"reflect"
"regexp"
"strconv"
"strings"
)
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
// rValuesAtPath returns a slice of values found in value v. The values
// in v are explored recursively so all nested values are collected.
func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) []reflect.Value {
pathparts := strings.Split(path, "||")
if len(pathparts) > 1 {
for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, create, caseSensitive)
if vals != nil && len(vals) > 0 {
return vals
}
}
return nil
}
values := []reflect.Value{reflect.Indirect(reflect.ValueOf(v))}
components := strings.Split(path, ".")
for len(values) > 0 && len(components) > 0 {
var index *int64
var indexStar bool
c := strings.TrimSpace(components[0])
if c == "" { // no actual component, illegal syntax
return nil
} else if caseSensitive && c != "*" && strings.ToLower(c[0:1]) == c[0:1] {
// TODO normalize case for user
return nil // don't support unexported fields
}
// parse this component
if m := indexRe.FindStringSubmatch(c); m != nil {
c = m[1]
if m[2] == "" {
index = nil
indexStar = true
} else {
i, _ := strconv.ParseInt(m[2], 10, 32)
index = &i
indexStar = false
}
}
nextvals := []reflect.Value{}
for _, value := range values {
// pull component name out of struct member
if value.Kind() != reflect.Struct {
continue
}
if c == "*" { // pull all members
for i := 0; i < value.NumField(); i++ {
if f := reflect.Indirect(value.Field(i)); f.IsValid() {
nextvals = append(nextvals, f)
}
}
continue
}
value = value.FieldByNameFunc(func(name string) bool {
if c == name {
return true
} else if !caseSensitive && strings.ToLower(name) == strings.ToLower(c) {
return true
}
return false
})
if create && value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
value = value.Elem()
} else {
value = reflect.Indirect(value)
}
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
if indexStar || index != nil {
nextvals = []reflect.Value{}
for _, value := range values {
value := reflect.Indirect(value)
if value.Kind() != reflect.Slice {
continue
}
if indexStar { // grab all indices
for i := 0; i < value.Len(); i++ {
idx := reflect.Indirect(value.Index(i))
if idx.IsValid() {
nextvals = append(nextvals, idx)
}
}
continue
}
// pull out index
i := int(*index)
if i >= value.Len() { // check out of bounds
if create {
// TODO resize slice
} else {
continue
}
} else if i < 0 { // support negative indexing
i = value.Len() + i
}
value = reflect.Indirect(value.Index(i))
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() {
value = reflect.ValueOf(nil)
}
}
if value.IsValid() {
nextvals = append(nextvals, value)
}
}
values = nextvals
}
components = components[1:]
}
return values
}
// ValuesAtPath returns a list of objects at the lexical path inside of a structure
func ValuesAtPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false, true); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
}
return nil
}
// ValuesAtAnyPath returns a list of objects at the case-insensitive lexical
// path inside of a structure
func ValuesAtAnyPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false, false); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
}
return nil
}
// SetValueAtPath sets an object at the lexical path inside of a structure
func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, true); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
}
}
}
// SetValueAtAnyPath sets an object at the case insensitive lexical path inside
// of a structure
func SetValueAtAnyPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
}
}
}

View File

@@ -0,0 +1,68 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
A []Struct
z []Struct
B *Struct
D *Struct
C string
}
var data = Struct{
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
func TestValueAtPathSuccess(t *testing.T) {
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "C"))
assert.Equal(t, []interface{}{"value1"}, awsutil.ValuesAtPath(data, "A[0].C"))
assert.Equal(t, []interface{}{"value2"}, awsutil.ValuesAtPath(data, "A[1].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[2].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtAnyPath(data, "a[2].c"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[-1].C"))
assert.Equal(t, []interface{}{"value1", "value2", "value3"}, awsutil.ValuesAtPath(data, "A[].C"))
assert.Equal(t, []interface{}{"terminal"}, awsutil.ValuesAtPath(data, "B . B . C"))
assert.Equal(t, []interface{}{"terminal", "terminal2"}, awsutil.ValuesAtPath(data, "B.*.C"))
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "A.D.X || C"))
}
func TestValueAtPathFailure(t *testing.T) {
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "C.x"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, ".x"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "X.Y.Z"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[100].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[3].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "B.B.C.Z"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "z[-1].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(nil, "A.B.C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(Struct{}, "A"))
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
var s2 Struct
awsutil.SetValueAtAnyPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
awsutil.SetValueAtAnyPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
}

View File

@@ -0,0 +1,103 @@
package awsutil
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
)
// Prettify returns the string representation of a value.
func Prettify(i interface{}) string {
var buf bytes.Buffer
prettify(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
// prettify will recursively walk value v to build a textual
// representation of the value.
func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
strtype := v.Type().String()
if strtype == "time.Time" {
fmt.Fprintf(buf, "%s", v.Interface())
break
} else if strings.HasPrefix(strtype, "io.") {
buf.WriteString("<buffer>")
break
}
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice || f.Kind() == reflect.Map) && f.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
prettify(val, indent+2, buf)
if i < len(names)-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
prettify(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
prettify(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
case io.ReadSeeker, io.Reader:
format = "buffer(%p)"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

View File

@@ -0,0 +1,254 @@
package aws
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(DefaultRetries).
WithLogger(NewDefaultLogger()).
WithLogLevel(LogOff)
// A Config provides service configuration for service clients. By default,
// all clients will use the {DefaultConfig} structure.
type Config struct {
// The credentials object to use when signing requests. Defaults to
// {DefaultChainCredentials}.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
// to `false`.
DisableSSL *bool
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
LogLevel *LogLevelType
// The logger writer interface to write logging messages to. Defaults to
// standard out.
Logger Logger
// The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service specific
// configuration.
MaxRetries *int
// Disables semantic parameter validation, which validates input for missing
// required fields and/or other semantic request input errors.
DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g.,
// CRC32 checksums in Amazon DynamoDB.
DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will
// use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
}
// NewConfig returns a new Config pointer that can be chained with builder methods to
// set multiple configuration values inline without using pointers.
//
// svc := s3.New(aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10))
//
func NewConfig() *Config {
return &Config{}
}
// WithCredentials sets a config Credentials value returning a Config pointer
// for chaining.
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config {
c.Credentials = creds
return c
}
// WithEndpoint sets a config Endpoint value returning a Config pointer for
// chaining.
func (c *Config) WithEndpoint(endpoint string) *Config {
c.Endpoint = &endpoint
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
c.Region = &region
return c
}
// WithDisableSSL sets a config DisableSSL value returning a Config pointer
// for chaining.
func (c *Config) WithDisableSSL(disable bool) *Config {
c.DisableSSL = &disable
return c
}
// WithHTTPClient sets a config HTTPClient value returning a Config pointer
// for chaining.
func (c *Config) WithHTTPClient(client *http.Client) *Config {
c.HTTPClient = client
return c
}
// WithMaxRetries sets a config MaxRetries value returning a Config pointer
// for chaining.
func (c *Config) WithMaxRetries(max int) *Config {
c.MaxRetries = &max
return c
}
// WithDisableParamValidation sets a config DisableParamValidation value
// returning a Config pointer for chaining.
func (c *Config) WithDisableParamValidation(disable bool) *Config {
c.DisableParamValidation = &disable
return c
}
// WithDisableComputeChecksums sets a config DisableComputeChecksums value
// returning a Config pointer for chaining.
func (c *Config) WithDisableComputeChecksums(disable bool) *Config {
c.DisableComputeChecksums = &disable
return c
}
// WithLogLevel sets a config LogLevel value returning a Config pointer for
// chaining.
func (c *Config) WithLogLevel(level LogLevelType) *Config {
c.LogLevel = &level
return c
}
// WithLogger sets a config Logger value returning a Config pointer for
// chaining.
func (c *Config) WithLogger(logger Logger) *Config {
c.Logger = logger
return c
}
// WithS3ForcePathStyle sets a config S3ForcePathStyle value returning a Config
// pointer for chaining.
func (c *Config) WithS3ForcePathStyle(force bool) *Config {
c.S3ForcePathStyle = &force
return c
}
// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
func (c Config) Merge(other *Config) *Config {
if other == nil {
return &c
}
dst := c
if other.Credentials != nil {
dst.Credentials = other.Credentials
}
if other.Endpoint != nil {
dst.Endpoint = other.Endpoint
}
if other.Region != nil {
dst.Region = other.Region
}
if other.DisableSSL != nil {
dst.DisableSSL = other.DisableSSL
}
if other.HTTPClient != nil {
dst.HTTPClient = other.HTTPClient
}
if other.LogLevel != nil {
dst.LogLevel = other.LogLevel
}
if other.Logger != nil {
dst.Logger = other.Logger
}
if other.MaxRetries != nil {
dst.MaxRetries = other.MaxRetries
}
if other.DisableParamValidation != nil {
dst.DisableParamValidation = other.DisableParamValidation
}
if other.DisableComputeChecksums != nil {
dst.DisableComputeChecksums = other.DisableComputeChecksums
}
if other.S3ForcePathStyle != nil {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
return &dst
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() *Config {
dst := c
return &dst
}

View File

@@ -0,0 +1,87 @@
package aws
import (
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("CopyTestEndpoint"),
Region: String("COPY_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(DefaultRetries),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(*got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", got, &want)
}
}
var mergeTestZeroValueConfig = Config{}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("MergeTestEndpoint"),
Region: String("MERGE_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(10),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for i, tt := range mergeTests {
got := tt.cfg.Merge(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %d %+v", i, tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}

View File

@@ -0,0 +1,357 @@
package aws
import "time"
// String returns a pointer to of the string value passed in.
func String(v string) *string {
return &v
}
// StringValue returns the value of the string pointer passed in or
// "" if the pointer is nil.
func StringValue(v *string) string {
if v != nil {
return *v
}
return ""
}
// StringSlice converts a slice of string values into a slice of
// string pointers
func StringSlice(src []string) []*string {
dst := make([]*string, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// StringValueSlice converts a slice of string pointers into a slice of
// string values
func StringValueSlice(src []*string) []string {
dst := make([]string, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// StringMap converts a string map of string values into a string
// map of string pointers
func StringMap(src map[string]string) map[string]*string {
dst := make(map[string]*string)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// StringValueMap converts a string map of string pointers into a string
// map of string values
func StringValueMap(src map[string]*string) map[string]string {
dst := make(map[string]string)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Bool returns a pointer to of the bool value passed in.
func Bool(v bool) *bool {
return &v
}
// BoolValue returns the value of the bool pointer passed in or
// false if the pointer is nil.
func BoolValue(v *bool) bool {
if v != nil {
return *v
}
return false
}
// BoolSlice converts a slice of bool values into a slice of
// bool pointers
func BoolSlice(src []bool) []*bool {
dst := make([]*bool, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// BoolValueSlice converts a slice of bool pointers into a slice of
// bool values
func BoolValueSlice(src []*bool) []bool {
dst := make([]bool, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// BoolMap converts a string map of bool values into a string
// map of bool pointers
func BoolMap(src map[string]bool) map[string]*bool {
dst := make(map[string]*bool)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// BoolValueMap converts a string map of bool pointers into a string
// map of bool values
func BoolValueMap(src map[string]*bool) map[string]bool {
dst := make(map[string]bool)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int returns a pointer to of the int value passed in.
func Int(v int) *int {
return &v
}
// IntValue returns the value of the int pointer passed in or
// 0 if the pointer is nil.
func IntValue(v *int) int {
if v != nil {
return *v
}
return 0
}
// IntSlice converts a slice of int values into a slice of
// int pointers
func IntSlice(src []int) []*int {
dst := make([]*int, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// IntValueSlice converts a slice of int pointers into a slice of
// int values
func IntValueSlice(src []*int) []int {
dst := make([]int, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// IntMap converts a string map of int values into a string
// map of int pointers
func IntMap(src map[string]int) map[string]*int {
dst := make(map[string]*int)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// IntValueMap converts a string map of int pointers into a string
// map of int values
func IntValueMap(src map[string]*int) map[string]int {
dst := make(map[string]int)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Int64 returns a pointer to of the int64 value passed in.
func Int64(v int64) *int64 {
return &v
}
// Int64Value returns the value of the int64 pointer passed in or
// 0 if the pointer is nil.
func Int64Value(v *int64) int64 {
if v != nil {
return *v
}
return 0
}
// Int64Slice converts a slice of int64 values into a slice of
// int64 pointers
func Int64Slice(src []int64) []*int64 {
dst := make([]*int64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Int64ValueSlice converts a slice of int64 pointers into a slice of
// int64 values
func Int64ValueSlice(src []*int64) []int64 {
dst := make([]int64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Int64Map converts a string map of int64 values into a string
// map of int64 pointers
func Int64Map(src map[string]int64) map[string]*int64 {
dst := make(map[string]*int64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Int64ValueMap converts a string map of int64 pointers into a string
// map of int64 values
func Int64ValueMap(src map[string]*int64) map[string]int64 {
dst := make(map[string]int64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Float64 returns a pointer to of the float64 value passed in.
func Float64(v float64) *float64 {
return &v
}
// Float64Value returns the value of the float64 pointer passed in or
// 0 if the pointer is nil.
func Float64Value(v *float64) float64 {
if v != nil {
return *v
}
return 0
}
// Float64Slice converts a slice of float64 values into a slice of
// float64 pointers
func Float64Slice(src []float64) []*float64 {
dst := make([]*float64, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// Float64ValueSlice converts a slice of float64 pointers into a slice of
// float64 values
func Float64ValueSlice(src []*float64) []float64 {
dst := make([]float64, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// Float64Map converts a string map of float64 values into a string
// map of float64 pointers
func Float64Map(src map[string]float64) map[string]*float64 {
dst := make(map[string]*float64)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// Float64ValueMap converts a string map of float64 pointers into a string
// map of float64 values
func Float64ValueMap(src map[string]*float64) map[string]float64 {
dst := make(map[string]float64)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}
// Time returns a pointer to of the time.Time value passed in.
func Time(v time.Time) *time.Time {
return &v
}
// TimeValue returns the value of the time.Time pointer passed in or
// time.Time{} if the pointer is nil.
func TimeValue(v *time.Time) time.Time {
if v != nil {
return *v
}
return time.Time{}
}
// TimeSlice converts a slice of time.Time values into a slice of
// time.Time pointers
func TimeSlice(src []time.Time) []*time.Time {
dst := make([]*time.Time, len(src))
for i := 0; i < len(src); i++ {
dst[i] = &(src[i])
}
return dst
}
// TimeValueSlice converts a slice of time.Time pointers into a slice of
// time.Time values
func TimeValueSlice(src []*time.Time) []time.Time {
dst := make([]time.Time, len(src))
for i := 0; i < len(src); i++ {
if src[i] != nil {
dst[i] = *(src[i])
}
}
return dst
}
// TimeMap converts a string map of time.Time values into a string
// map of time.Time pointers
func TimeMap(src map[string]time.Time) map[string]*time.Time {
dst := make(map[string]*time.Time)
for k, val := range src {
v := val
dst[k] = &v
}
return dst
}
// TimeValueMap converts a string map of time.Time pointers into a string
// map of time.Time values
func TimeValueMap(src map[string]*time.Time) map[string]time.Time {
dst := make(map[string]time.Time)
for k, val := range src {
if val != nil {
dst[k] = *val
}
}
return dst
}

View File

@@ -0,0 +1,438 @@
package aws_test
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
var testCasesStringSlice = [][]string{
{"a", "b", "c", "d", "e"},
{"a", "b", "", "", "e"},
}
func TestStringSlice(t *testing.T) {
for idx, in := range testCasesStringSlice {
if in == nil {
continue
}
out := aws.StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesStringValueSlice = [][]*string{
{aws.String("a"), aws.String("b"), nil, aws.String("c")},
}
func TestStringValueSlice(t *testing.T) {
for idx, in := range testCasesStringValueSlice {
if in == nil {
continue
}
out := aws.StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesStringMap = []map[string]string{
{"a": "1", "b": "2", "c": "3"},
}
func TestStringMap(t *testing.T) {
for idx, in := range testCasesStringMap {
if in == nil {
continue
}
out := aws.StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolSlice = [][]bool{
{true, true, false, false},
}
func TestBoolSlice(t *testing.T) {
for idx, in := range testCasesBoolSlice {
if in == nil {
continue
}
out := aws.BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolValueSlice = [][]*bool{}
func TestBoolValueSlice(t *testing.T) {
for idx, in := range testCasesBoolValueSlice {
if in == nil {
continue
}
out := aws.BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesBoolMap = []map[string]bool{
{"a": true, "b": false, "c": true},
}
func TestBoolMap(t *testing.T) {
for idx, in := range testCasesBoolMap {
if in == nil {
continue
}
out := aws.BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntSlice = [][]int{
{1, 2, 3, 4},
}
func TestIntSlice(t *testing.T) {
for idx, in := range testCasesIntSlice {
if in == nil {
continue
}
out := aws.IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntValueSlice = [][]*int{}
func TestIntValueSlice(t *testing.T) {
for idx, in := range testCasesIntValueSlice {
if in == nil {
continue
}
out := aws.IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesIntMap = []map[string]int{
{"a": 3, "b": 2, "c": 1},
}
func TestIntMap(t *testing.T) {
for idx, in := range testCasesIntMap {
if in == nil {
continue
}
out := aws.IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64Slice = [][]int64{
{1, 2, 3, 4},
}
func TestInt64Slice(t *testing.T) {
for idx, in := range testCasesInt64Slice {
if in == nil {
continue
}
out := aws.Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64ValueSlice = [][]*int64{}
func TestInt64ValueSlice(t *testing.T) {
for idx, in := range testCasesInt64ValueSlice {
if in == nil {
continue
}
out := aws.Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesInt64Map = []map[string]int64{
{"a": 3, "b": 2, "c": 1},
}
func TestInt64Map(t *testing.T) {
for idx, in := range testCasesInt64Map {
if in == nil {
continue
}
out := aws.Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64Slice = [][]float64{
{1, 2, 3, 4},
}
func TestFloat64Slice(t *testing.T) {
for idx, in := range testCasesFloat64Slice {
if in == nil {
continue
}
out := aws.Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64ValueSlice = [][]*float64{}
func TestFloat64ValueSlice(t *testing.T) {
for idx, in := range testCasesFloat64ValueSlice {
if in == nil {
continue
}
out := aws.Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesFloat64Map = []map[string]float64{
{"a": 3, "b": 2, "c": 1},
}
func TestFloat64Map(t *testing.T) {
for idx, in := range testCasesFloat64Map {
if in == nil {
continue
}
out := aws.Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeSlice = [][]time.Time{
{time.Now(), time.Now().AddDate(100, 0, 0)},
}
func TestTimeSlice(t *testing.T) {
for idx, in := range testCasesTimeSlice {
if in == nil {
continue
}
out := aws.TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeValueSlice = [][]*time.Time{}
func TestTimeValueSlice(t *testing.T) {
for idx, in := range testCasesTimeValueSlice {
if in == nil {
continue
}
out := aws.TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesTimeMap = []map[string]time.Time{
{"a": time.Now().AddDate(-100, 0, 0), "b": time.Now()},
}
func TestTimeMap(t *testing.T) {
for idx, in := range testCasesTimeMap {
if in == nil {
continue
}
out := aws.TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}

View File

@@ -0,0 +1,85 @@
package credentials
import (
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
// providers in the ChainProvider.
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", "no valid providers in chain", nil)
)
// A ChainProvider will search for a provider which returns credentials
// and cache that provider until Retrieve is called again.
//
// The ChainProvider provides a way of chaining multiple providers together
// which will pick the first available using priority order of the Providers
// in the list.
//
// If none of the Providers retrieve valid credentials Value, ChainProvider's
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
//
// If a Provider is found which returns valid credentials Value ChainProvider
// will cache that Provider for all calls to IsExpired(), until Retrieve is
// called again.
//
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
// In this example EnvProvider will first check if any credentials are available
// vai the environment variables. If there are none ChainProvider will check
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
// does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain
//
// creds := NewChainCredentials(
// []Provider{
// &EnvProvider{},
// &EC2RoleProvider{},
// })
//
// // Usage of ChainCredentials with aws.Config
// svc := ec2.New(&aws.Config{Credentials: creds})
//
type ChainProvider struct {
Providers []Provider
curr Provider
}
// NewChainCredentials returns a pointer to a new Credentials object
// wrapping a chain of providers.
func NewChainCredentials(providers []Provider) *Credentials {
return NewCredentials(&ChainProvider{
Providers: append([]Provider{}, providers...),
})
}
// Retrieve returns the credentials value or error if no provider returned
// without error.
//
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) Retrieve() (Value, error) {
for _, p := range c.Providers {
if creds, err := p.Retrieve(); err == nil {
c.curr = p
return creds, nil
}
}
c.curr = nil
// TODO better error reporting. maybe report error for each failed retrieve?
return Value{}, ErrNoValidProvidersFoundInChain
}
// IsExpired will returned the expired state of the currently cached provider
// if there is one. If there is no current provider, true will be returned.
func (c *ChainProvider) IsExpired() bool {
if c.curr != nil {
return c.curr.IsExpired()
}
return true
}

View File

@@ -0,0 +1,73 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}

View File

@@ -0,0 +1,220 @@
// Package credentials provides credential retrieval and management
//
// The Credentials is the primary method of getting access to and managing
// credentials Values. Using dependency injection retrieval of the credential
// values is handled by a object which satisfies the Provider interface.
//
// By default the Credentials.Get() will cache the successful result of a
// Provider's Retrieve() until Provider.IsExpired() returns true. At which
// point Credentials will call Provider's Retrieve() to get new credential Value.
//
// The Provider is responsible for determining when credentials Value have expired.
// It is also important to note that Credentials will always call Retrieve the
// first time Credentials.Get() is called.
//
// Example of using the environment variable credentials.
//
// creds := NewEnvCredentials()
//
// // Retrieve the credentials value
// credValue, err := creds.Get()
// if err != nil {
// // handle error
// }
//
// Example of forcing credentials to expire and be refreshed on the next Get().
// This may be helpful to proactively expire credentials and refresh them sooner
// than they would naturally expire on their own.
//
// creds := NewCredentials(&EC2RoleProvider{})
// creds.Expire()
// credsValue, err := creds.Get()
// // New credentials will be retrieved instead of from cache.
//
//
// Custom Provider
//
// Each Provider built into this package also provides a helper method to generate
// a Credentials pointer setup with the provider. To use a custom Provider just
// create a type which satisfies the Provider interface and pass it to the
// NewCredentials method.
//
// type MyProvider struct{}
// func (m *MyProvider) Retrieve() (Value, error) {...}
// func (m *MyProvider) IsExpired() bool {...}
//
// creds := NewCredentials(&MyProvider{})
// credValue, err := creds.Get()
//
package credentials
import (
"sync"
"time"
)
// Create an empty Credential object that can be used as dummy placeholder
// credentials for requests that do not need signed.
//
// This Credentials can be used to configure a service to not sign requests
// when making service API calls. For example, when accessing public
// s3 buckets.
//
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
type Value struct {
// AWS Access key ID
AccessKeyID string
// AWS Secret Access Key
SecretAccessKey string
// AWS Session Token
SessionToken string
}
// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
//
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
// Refresh returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
IsExpired() bool
}
// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//
// The best method to use this struct is as an anonymous field within the
// provider's struct.
//
// Example:
// type EC2RoleProvider struct {
// Expiry
// ...
// }
type Expiry struct {
// The date/time when to expire on
expiration time.Time
// If set will be used by IsExpired to determine the current time.
// Defaults to time.Now if CurrentTime is not set. Available for testing
// to be able to mock out the current time.
CurrentTime func() time.Time
}
// SetExpiration sets the expiration IsExpired will check when called.
//
// If window is greater than 0 the expiration time will be reduced by the
// window value.
//
// Using a window is helpful to trigger credentials to expire sooner than
// the expiration time given to ensure no requests are made with expired
// tokens.
func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
e.expiration = expiration
if window > 0 {
e.expiration = e.expiration.Add(-window)
}
}
// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
// Credentials is safe to use across multiple goroutines and will manage the
// synchronous state so the Providers do not need to implement their own
// synchronization.
//
// The first Credentials.Get() will always call Provider.Retrieve() to get the
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
creds Value
forceRefresh bool
m sync.Mutex
provider Provider
}
// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
return &Credentials{
provider: provider,
forceRefresh: true,
}
}
// Get returns the credentials value, or error if the credentials Value failed
// to be retrieved.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
c.m.Lock()
defer c.m.Unlock()
if c.isExpired() {
creds, err := c.provider.Retrieve()
if err != nil {
return Value{}, err
}
c.creds = creds
c.forceRefresh = false
}
return c.creds, nil
}
// Expire expires the credentials and forces them to be retrieved on the
// next call to Get().
//
// This will override the Provider's expired state, and force Credentials
// to call the Provider's Retrieve().
func (c *Credentials) Expire() {
c.m.Lock()
defer c.m.Unlock()
c.forceRefresh = true
}
// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
//
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.Lock()
defer c.m.Unlock()
return c.isExpired()
}
// isExpired helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}

View File

@@ -0,0 +1,62 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}

View File

@@ -0,0 +1,162 @@
package credentials
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
// Timeout: 10 * time.Second,
// },
// // Use default EC2 Role metadata endpoint, Alternate endpoints can be
// // specified setting Endpoint to something else.
// Endpoint: "",
// // Do not use early expiry of credentials. If a non zero value is
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
type EC2RoleProvider struct {
Expiry
// Endpoint must be fully quantified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
// things such as timeout.
//
// Endpoint is the URL that the EC2RoleProvider will connect to when retrieving
// role and credentials.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
Client: client,
ExpiryWindow: window,
})
}
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
}
credsList, err := requestCredList(m.Client, m.Endpoint)
if err != nil {
return Value{}, err
}
if len(credsList) == 0 {
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
if err != nil {
return Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
}, nil
}
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
}
return credsList, nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
if err != nil {
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
return respCreds, nil
}

View File

@@ -0,0 +1,108 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,73 @@
package credentials
import (
"os"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)
// A EnvProvider retrieves credentials from the environment variables of the
// running process. Environment credentials never expire.
//
// Environment variables used:
//
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
type EnvProvider struct {
retrieved bool
}
// NewEnvCredentials returns a pointer to a new Credentials object
// wrapping the environment variable provider.
func NewEnvCredentials() *Credentials {
return NewCredentials(&EnvProvider{})
}
// Retrieve retrieves the keys from the environment.
func (e *EnvProvider) Retrieve() (Value, error) {
e.retrieved = false
id := os.Getenv("AWS_ACCESS_KEY_ID")
if id == "" {
id = os.Getenv("AWS_ACCESS_KEY")
}
secret := os.Getenv("AWS_SECRET_ACCESS_KEY")
if secret == "" {
secret = os.Getenv("AWS_SECRET_KEY")
}
if id == "" {
return Value{}, ErrAccessKeyIDNotFound
}
if secret == "" {
return Value{}, ErrSecretAccessKeyNotFound
}
e.retrieved = true
return Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: os.Getenv("AWS_SESSION_TOKEN"),
}, nil
}
// IsExpired returns if the credentials have been retrieved.
func (e *EnvProvider) IsExpired() bool {
return !e.retrieved
}

View File

@@ -0,0 +1,70 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}

View File

@@ -0,0 +1,8 @@
[default]
aws_access_key_id = accessKey
aws_secret_access_key = secret
aws_session_token = token
[no_token]
aws_access_key_id = accessKey
aws_secret_access_key = secret

View File

@@ -0,0 +1,135 @@
package credentials
import (
"fmt"
"os"
"path/filepath"
"github.com/vaughan0/go-ini"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
//
// @readonly
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
)
// A SharedCredentialsProvider retrieves credentials from the current user's home
// directory, and keeps track if those credentials are expired.
//
// Profile ini file example: $HOME/.aws/credentials
type SharedCredentialsProvider struct {
// Path to the shared credentials file. If empty will default to current user's
// home directory.
Filename string
// AWS Profile to extract credentials from the shared credentials file. If empty
// will default to environment variable "AWS_PROFILE" or "default" if
// environment variable is also not set.
Profile string
// retrieved states if the credentials have been successfully retrieved.
retrieved bool
}
// NewSharedCredentials returns a pointer to a new Credentials object
// wrapping the Profile file provider.
func NewSharedCredentials(filename, profile string) *Credentials {
return NewCredentials(&SharedCredentialsProvider{
Filename: filename,
Profile: profile,
})
}
// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *SharedCredentialsProvider) Retrieve() (Value, error) {
p.retrieved = false
filename, err := p.filename()
if err != nil {
return Value{}, err
}
creds, err := loadProfile(filename, p.profile())
if err != nil {
return Value{}, err
}
p.retrieved = true
return creds, nil
}
// IsExpired returns if the shared credentials have expired.
func (p *SharedCredentialsProvider) IsExpired() bool {
return !p.retrieved
}
// loadProfiles loads from the file pointed to by shared credentials filename for profile.
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.LoadFile(filename)
if err != nil {
return Value{}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile := config.Section(profile)
id, ok := iniProfile["aws_access_key_id"]
if !ok {
return Value{}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
nil)
}
secret, ok := iniProfile["aws_secret_access_key"]
if !ok {
return Value{}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
token := iniProfile["aws_session_token"]
return Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
}, nil
}
// filename returns the filename to use to read AWS shared credentials.
//
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")
}
if homeDir == "" {
return "", ErrSharedCredentialsHomeNotFound
}
p.Filename = filepath.Join(homeDir, ".aws", "credentials")
}
return p.Filename, nil
}
// profile returns the AWS shared credentials profile. If empty will read
// environment variable "AWS_PROFILE". If that is not set profile will
// return "default".
func (p *SharedCredentialsProvider) profile() string {
if p.Profile == "" {
p.Profile = os.Getenv("AWS_PROFILE")
}
if p.Profile == "" {
p.Profile = "default"
}
return p.Profile
}

View File

@@ -0,0 +1,77 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,44 @@
package credentials
import (
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)
// A StaticProvider is a set of credentials which are set pragmatically,
// and will never expire.
type StaticProvider struct {
Value
}
// NewStaticCredentials returns a pointer to a new Credentials object
// wrapping a static credentials value provider.
func NewStaticCredentials(id, secret, token string) *Credentials {
return NewCredentials(&StaticProvider{Value: Value{
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
}})
}
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{}, ErrStaticCredentialsEmpty
}
return s.Value, nil
}
// IsExpired returns if the credentials are expired.
//
// For StaticProvider, the credentials never expired.
func (s *StaticProvider) IsExpired() bool {
return false
}

View File

@@ -0,0 +1,34 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}

View File

@@ -0,0 +1,120 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
//
// Example how to configure a service to use this provider:
//
// config := &aws.Config{
// Credentials: stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second),
// })
// // Use config for creating your AWS service.
//
// Example how to obtain customised credentials:
//
// provider := &stscreds.Provider{
// // Extend the duration to 1 hour.
// Duration: time.Hour,
// // Custom role name.
// RoleSessionName: "custom-session-name",
// }
// creds := credentials.NewCredentials(provider)
//
type AssumeRoleProvider struct {
credentials.Expiry
// Custom STS client. If not set the default STS client will be used.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// The sts and roleARN parameters are used for building the "AssumeRole" call.
// Pass nil as sts to use the default client.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewCredentials(client AssumeRoler, roleARN string, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&AssumeRoleProvider{
Client: client,
RoleARN: roleARN,
ExpiryWindow: window,
})
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.Client == nil {
p.Client = sts.New(nil)
}
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = 15 * time.Minute
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
})
if err != nil {
return credentials.Value{}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil
}

View File

@@ -0,0 +1,58 @@
package stscreds
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@@ -0,0 +1,157 @@
package aws
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
}
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
return
}
var length int64
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
}
}
if r.HTTPRequest == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = Bool(true) // network errors are retryable
}
}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
r.Retryable = Bool(r.Service.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && StringValue(r.Service.Config.Region) == "" {
r.Error = ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
}
}

View File

@@ -0,0 +1,81 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: Int(1)})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}

View File

@@ -0,0 +1,85 @@
package aws
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
type Handlers struct {
Validate HandlerList
Build HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
}
// copy returns of this handler's lists.
func (h *Handlers) copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
}
}
// Clear removes callback functions for all handlers
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
}
// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
}
}

View File

@@ -0,0 +1,31 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

View File

@@ -0,0 +1,89 @@
package aws
import (
"log"
"os"
)
// A LogLevelType defines the level logging should be performed at. Used to instruct
// the SDK which statements should be logged.
type LogLevelType uint
// LogLevel returns the pointer to a LogLevel. Should be used to workaround
// not being able to take the address of a non-composite literal.
func LogLevel(l LogLevelType) *LogLevelType {
return &l
}
// Value returns the LogLevel value or the default value LogOff if the LogLevel
// is nil. Safe to use on nil value LogLevelTypes.
func (l *LogLevelType) Value() LogLevelType {
if l != nil {
return *l
}
return LogOff
}
// Matches returns true if the v LogLevel is enabled by this LogLevel. Should be
// used with logging sub levels. Is safe to use on nil value LogLevelTypes. If
// LogLevel is nill, will default to LogOff comparison.
func (l *LogLevelType) Matches(v LogLevelType) bool {
c := l.Value()
return c&v == v
}
// AtLeast returns true if this LogLevel is at least high enough to satisfies v.
// Is safe to use on nil value LogLevelTypes. If LogLevel is nill, will default
// to LogOff comparison.
func (l *LogLevelType) AtLeast(v LogLevelType) bool {
c := l.Value()
return c >= v
}
const (
// LogOff states that no logging should be performed by the SDK. This is the
// default state of the SDK, and should be use to disable all logging.
LogOff LogLevelType = iota * 0x1000
// LogDebug state that debug output should be logged by the SDK. This should
// be used to inspect request made and responses received.
LogDebug
)
// Debug Logging Sub Levels
const (
// LogDebugWithSigning states that the SDK should log request signing and
// presigning events. This should be used to log the signing details of
// requests for debugging. Will also enable LogDebug.
LogDebugWithSigning LogLevelType = LogDebug | (1 << iota)
// LogDebugWithHTTPBody states the SDK should log HTTP request and response
// HTTP bodys in addition to the headers and path. This should be used to
// see the body content of requests and responses made while using the SDK
// Will also enable LogDebug.
LogDebugWithHTTPBody
)
// A Logger is a minimalistic interface for the SDK to log messages to. Should
// be used to provide custom logging writers for the SDK to use.
type Logger interface {
Log(...interface{})
}
// NewDefaultLogger returns a Logger which will write log messages to stdout, and
// use same formatting runes as the stdlib log.Logger
func NewDefaultLogger() Logger {
return &defaultLogger{
logger: log.New(os.Stdout, "", log.LstdFlags),
}
}
// A defaultLogger provides a minimalistic logger satisfying the Logger interface.
type defaultLogger struct {
logger *log.Logger
}
// Log logs the parameters to the stdlib logger. See log.Println.
func (l defaultLogger) Log(args ...interface{}) {
l.logger.Println(args...)
}

View File

@@ -0,0 +1,89 @@
package aws
import (
"fmt"
"reflect"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}

View File

@@ -0,0 +1,84 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}

View File

@@ -0,0 +1,312 @@
package aws
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
// A Request is the service request to be made.
type Request struct {
*Service
Handlers Handlers
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount uint
Retryable *bool
RetryDelay time.Duration
built bool
}
// An Operation is the service API operation to be made.
type Operation struct {
Name string
HTTPMethod string
HTTPPath string
*Paginator
}
// Paginator keeps track of pagination configuration for an API operation.
type Paginator struct {
InputTokens []string
OutputTokens []string
LimitToken string
TruncationToken string
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
}
p := operation.HTTPPath
if p == "" {
p = "/"
}
httpReq, _ := http.NewRequest(method, "", nil)
httpReq.URL, _ = url.Parse(service.Endpoint + p)
r := &Request{
Service: service,
Handlers: service.Handlers.copy(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
HTTPRequest: httpReq,
Body: nil,
Params: params,
Error: nil,
Data: data,
}
r.SetBufferBody([]byte{})
return r
}
// WillRetry returns if the request's can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.Service.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
func (r *Request) ParamsFilled() bool {
return r.Params != nil && reflect.ValueOf(r.Params).Elem().IsValid()
}
// DataFilled returns true if the request's data for response deserialization
// target has been set and is a valid. False is returned if data is not
// set, or is invalid.
func (r *Request) DataFilled() bool {
return r.Data != nil && reflect.ValueOf(r.Data).Elem().IsValid()
}
// SetBufferBody will set the request's body bytes that will be sent to
// the service API.
func (r *Request) SetBufferBody(buf []byte) {
r.SetReaderBody(bytes.NewReader(buf))
}
// SetStringBody sets the body of the request to be backed by a string.
func (r *Request) SetStringBody(s string) {
r.SetReaderBody(strings.NewReader(s))
}
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.HTTPRequest.Body = ioutil.NopCloser(reader)
r.Body = reader
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
func (r *Request) Presign(expireTime time.Duration) (string, error) {
r.ExpireTime = expireTime
r.Sign()
if r.Error != nil {
return "", r.Error
}
return r.HTTPRequest.URL.String(), nil
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Anny additional build Handlers set on this request will be run
// in the order they were set.
//
// The request will only be built once. Multiple calls to build will have
// no effect.
//
// If any Validate or Build errors occur the build will stop and the error
// which occurred will be returned.
func (r *Request) Build() error {
if !r.built {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
return r.Error
}
r.Handlers.Build.Run(r)
r.built = true
}
return r.Error
}
// Sign will sign the request retuning error if errors are encountered.
//
// Send will build the request prior to signing. All Sign Handlers will
// be executed in the order they were set.
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
return r.Error
}
r.Handlers.Sign.Run(r)
return r.Error
}
// Send will send the request returning error if errors are encountered.
//
// Send will sign the request prior to sending. All Send Handlers will
// be executed in the order they were set.
func (r *Request) Send() error {
for {
r.Sign()
if r.Error != nil {
return r.Error
}
if BoolValue(r.Retryable) {
// Re-seek the body back to the original point in for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
}
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
break
}
return nil
}
// HasNextPage returns true if this request has more pages of data available.
func (r *Request) HasNextPage() bool {
return r.nextPageTokens() != nil
}
// nextPageTokens returns the tokens to use when asking for the next page of
// data.
func (r *Request) nextPageTokens() []interface{} {
if r.Operation.Paginator == nil {
return nil
}
if r.Operation.TruncationToken != "" {
tr := awsutil.ValuesAtAnyPath(r.Data, r.Operation.TruncationToken)
if tr == nil || len(tr) == 0 {
return nil
}
switch v := tr[0].(type) {
case bool:
if v == false {
return nil
}
}
}
found := false
tokens := make([]interface{}, len(r.Operation.OutputTokens))
for i, outtok := range r.Operation.OutputTokens {
v := awsutil.ValuesAtAnyPath(r.Data, outtok)
if v != nil && len(v) > 0 {
found = true
tokens[i] = v[0]
}
}
if found {
return tokens
}
return nil
}
// NextPage returns a new Request that can be executed to return the next
// page of result data. Call .Send() on this request to execute it.
func (r *Request) NextPage() *Request {
tokens := r.nextPageTokens()
if tokens == nil {
return nil
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i])
}
return nr
}
// EachPage iterates over each page of a paginated request object. The fn
// parameter should be a function with the following sample signature:
//
// func(page *T, lastPage bool) bool {
// return true // return false to stop iterating
// }
//
// Where "T" is the structure type matching the output structure of the given
// operation. For example, a request object generated by
// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput
// as the structure "T". The lastPage value represents whether the page is
// the last page of data or not. The return value of this function should
// return true to keep iterating or false to stop.
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
for page := r; page != nil; page = page.NextPage() {
page.Send()
shouldContinue := fn(page.Data, !page.HasNextPage())
if page.Error != nil || !shouldContinue {
return page.Error
}
}
return nil
}

View File

@@ -0,0 +1,305 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
for _, t := range p.TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
assert.Nil(t, params.ExclusiveStartTableName)
}
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
req, _ := db.ListTablesRequest(params)
err := req.EachPage(func(p interface{}, last bool) bool {
numPages++
for _, t := range p.(*dynamodb.ListTablesOutput).TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
}
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(nil)
numPages, gotToEnd := 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
if numPages == 2 {
return false
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, 2, numPages)
assert.False(t, gotToEnd)
assert.Nil(t, err)
}
func TestSkipPagination(t *testing.T) {
client := s3.New(nil)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = &s3.HeadBucketOutput{}
})
req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")})
numPages, gotToEnd := 0, false
req.EachPage(func(p interface{}, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
assert.Equal(t, 1, numPages)
assert.True(t, gotToEnd)
}
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
count := 0
client := s3.New(nil)
reqNum := &count
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[*reqNum]
*reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
results := []string{}
err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results)
assert.Nil(t, err)
// Try again without truncation token at all
count = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Bool(true)
results = []string{}
err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err)
}
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(nil)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
return db
}
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error {
page, _ := db.ListTablesRequest(input)
for ; page != nil; page = page.NextPage() {
page.Send()
out := page.Data.(*dynamodb.ListTablesOutput)
if result := fn(out, !page.HasNextPage()); page.Error != nil || !result {
return page.Error
}
}
return nil
}
for i := 0; i < b.N; i++ {
reqNum = 0
iter(func(p *dynamodb.ListTablesOutput, last bool) bool {
return true
})
}
}
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
for i := 0; i < b.N; i++ {
reqNum = 0
req, _ := db.ListTablesRequest(input)
req.EachPage(func(p interface{}, last bool) bool {
return true
})
}
}

View File

@@ -0,0 +1,225 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
type testData struct {
Data string
}
func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
}
return
}
func unmarshalError(req *Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
return
}
if len(bodyBytes) == 0 {
req.Error = awserr.NewRequestFailure(
awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err)
return
}
req.Error = awserr.NewRequestFailure(
awserr.New(jsonErr.Code, jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
"",
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}
// test that retries occur for 5xx status codes
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 401, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code())
assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message())
assert.Equal(t, 0, int(r.RetryCount))
}
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(NewConfig().WithMaxRetries(DefaultRetries))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 500, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "UnknownError", err.(awserr.Error).Code())
assert.Equal(t, "An error occurred.", err.(awserr.Error).Message())
assert.Equal(t, 3, int(r.RetryCount))
expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}}
for i, v := range delays {
min := expectDelays[i].min * time.Millisecond
max := expectDelays[i].max * time.Millisecond
assert.True(t, min <= v && v <= max,
"Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max)
}
}
// test that the request is retried after the credentials are expired.
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")
assert.Equal(t, 1, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}

View File

@@ -0,0 +1,194 @@
package aws
import (
"fmt"
"math"
"math/rand"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new Server object initialized.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *Request) {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ServiceName, r.Operation.Name, msg))
}
// MaxRetries returns the number of maximum returns the service will use to make
// an individual API request.
func (s *Service) MaxRetries() uint {
if IntValue(s.Config.MaxRetries) < 0 {
return s.DefaultMaxRetries
}
return uint(IntValue(s.Config.MaxRetries))
}
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}

View File

@@ -0,0 +1,55 @@
package aws
import (
"io"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeakerCloser
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
// io.Closer interfaces to the underlying object if they are available.
type ReaderSeekerCloser struct {
r io.Reader
}
// Read reads from the reader up to size of p. The number of bytes read, and
// error if it occurred will be returned.
//
// If the reader is not an io.Reader zero bytes read, and nil error will be returned.
//
// Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
switch t := r.r.(type) {
case io.Reader:
return t.Read(p)
}
return 0, nil
}
// Seek sets the offset for the next Read to offset, interpreted according to
// whence: 0 means relative to the origin of the file, 1 means relative to the
// current offset, and 2 means relative to the end. Seek returns the new offset
// and an error, if any.
//
// If the ReaderSeekerCloser is not an io.Seeker nothing will be done.
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
switch t := r.r.(type) {
case io.Seeker:
return t.Seek(offset, whence)
}
return int64(0), nil
}
// Close closes the ReaderSeekerCloser.
//
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
func (r ReaderSeekerCloser) Close() error {
switch t := r.r.(type) {
case io.Closer:
return t.Close()
}
return nil
}

View File

@@ -0,0 +1,8 @@
// Package aws provides core functionality for making requests to AWS services.
package aws
// SDKName is the name of this AWS SDK
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.7.3"

View File

@@ -0,0 +1,31 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import "strings"
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// if the service and region pair are not found endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string) (endpoint, signingRegion string) {
derivedKeys := []string{
region + "/" + svcName,
region + "/*",
"*/" + svcName,
"*/*",
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
return
}
}
return
}

View File

@@ -0,0 +1,77 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"us-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-west-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-northeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"sa-east-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com",
"signatureVersion": "v4"
}
}
}

View File

@@ -0,0 +1,89 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
},
}

View File

@@ -0,0 +1,28 @@
package endpoints
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts"}
for _, name := range svcs {
ep, sr := EndpointForRegion(name, region)
assert.Equal(t, name+".amazonaws.com", ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3"}
for _, name := range svcs {
ep, _ := EndpointForRegion(name, region)
assert.Equal(t, name+"."+region+".amazonaws.com.cn", ep)
}
}

View File

@@ -0,0 +1,33 @@
// Package query provides serialisation of AWS query requests, and responses.
package query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/input/query.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/query/queryutil"
)
// Build builds a request for an AWS Query service.
func Build(r *aws.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.Service.APIVersion},
}
if err := queryutil.Parse(body, r.Params, false); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding Query request", err)
return
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,223 @@
package queryutil
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
q := queryParser{isEC2: isEC2}
return q.parseValue(body, reflect.ValueOf(i), "", "")
}
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
type queryParser struct {
isEC2 bool
}
func (q *queryParser) parseValue(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
value = elemOf(value)
// no need to handle zero values
if !value.IsValid() {
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
return q.parseStruct(v, value, prefix)
case "list":
return q.parseList(v, value, prefix, tag)
case "map":
return q.parseMap(v, value, prefix, tag)
default:
return q.parseScalar(v, value, prefix, tag)
}
}
func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix string) error {
if !value.IsValid() {
return nil
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
value := elemOf(value.Field(i))
field := t.Field(i)
var name string
if q.isEC2 {
name = field.Tag.Get("queryName")
}
if name == "" {
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
if name != "" && q.isEC2 {
name = strings.ToUpper(name[0:1]) + name[1:]
}
}
if name == "" {
name = field.Name
}
if prefix != "" {
name = prefix + "." + name
}
if err := q.parseValue(v, value, name, field.Tag); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".member"
}
for i := 0; i < value.Len(); i++ {
slicePrefix := prefix
if slicePrefix == "" {
slicePrefix = strconv.Itoa(i + 1)
} else {
slicePrefix = slicePrefix + "." + strconv.Itoa(i+1)
}
if err := q.parseValue(v, value.Index(i), slicePrefix, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseMap(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value
if !value.IsNil() && value.Len() == 0 {
v.Set(prefix, "")
return nil
}
// check for unflattened list member
if !q.isEC2 && tag.Get("flattened") == "" {
prefix += ".entry"
}
// sort keys for improved serialization consistency.
// this is not strictly necessary for protocol support.
mapKeyValues := value.MapKeys()
mapKeys := map[string]reflect.Value{}
mapKeyNames := make([]string, len(mapKeyValues))
for i, mapKey := range mapKeyValues {
name := mapKey.String()
mapKeys[name] = mapKey
mapKeyNames[i] = name
}
sort.Strings(mapKeyNames)
for i, mapKeyName := range mapKeyNames {
mapKey := mapKeys[mapKeyName]
mapValue := value.MapIndex(mapKey)
kname := tag.Get("locationNameKey")
if kname == "" {
kname = "key"
}
vname := tag.Get("locationNameValue")
if vname == "" {
vname = "value"
}
// serialize key
var keyName string
if prefix == "" {
keyName = strconv.Itoa(i+1) + "." + kname
} else {
keyName = prefix + "." + strconv.Itoa(i+1) + "." + kname
}
if err := q.parseValue(v, mapKey, keyName, ""); err != nil {
return err
}
// serialize value
var valueName string
if prefix == "" {
valueName = strconv.Itoa(i+1) + "." + vname
} else {
valueName = prefix + "." + strconv.Itoa(i+1) + "." + vname
}
if err := q.parseValue(v, mapValue, valueName, ""); err != nil {
return err
}
}
return nil
}
func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, tag reflect.StructTag) error {
switch value := r.Interface().(type) {
case string:
v.Set(name, value)
case []byte:
if !r.IsNil() {
v.Set(name, base64.StdEncoding.EncodeToString(value))
}
case bool:
v.Set(name, strconv.FormatBool(value))
case int64:
v.Set(name, strconv.FormatInt(value, 10))
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
v.Set(name, value.UTC().Format(ISO8601UTC))
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)", name, r.Interface(), r.Type().Name())
}
return nil
}

View File

@@ -0,0 +1,29 @@
package query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/output/query.json unmarshal_test.go
import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response for an AWS Query service.
func Unmarshal(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals header response values for an AWS Query service.
func UnmarshalMeta(r *aws.Request) {
// TODO implement unmarshaling of request IDs
}

View File

@@ -0,0 +1,33 @@
package query
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"ErrorResponse"`
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
// Package rest provides RESTful serialisation of AWS requests and responses.
package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/url"
"path"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// Build builds the REST component of a service request.
func Build(r *aws.Request) {
if r.ParamsFilled() {
v := reflect.ValueOf(r.Params).Elem()
buildLocationElements(r, v)
buildBody(r, v)
}
}
func buildLocationElements(r *aws.Request, v reflect.Value) {
query := r.HTTPRequest.URL.Query()
for i := 0; i < v.NumField(); i++ {
m := v.Field(i)
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
field := v.Type().Field(i)
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
if m.Kind() == reflect.Ptr {
m = m.Elem()
}
if !m.IsValid() {
continue
}
switch field.Tag.Get("location") {
case "headers": // header maps
buildHeaderMap(r, m, field.Tag.Get("locationName"))
case "header":
buildHeader(r, m, name)
case "uri":
buildURI(r, m, name)
case "querystring":
buildQueryString(r, m, name, query)
}
}
if r.Error != nil {
return
}
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
}
func buildBody(r *aws.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := reflect.Indirect(v.FieldByName(payloadName))
if payload.IsValid() && payload.Interface() != nil {
switch reader := payload.Interface().(type) {
case io.ReadSeeker:
r.SetReaderBody(reader)
case []byte:
r.SetBufferBody(reader)
case string:
r.SetStringBody(reader)
default:
r.Error = awserr.New("SerializationError",
"failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
func buildHeader(r *aws.Request, v reflect.Value, name string) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(name, *str)
}
}
func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(prefix+key.String(), *str)
}
}
}
func buildURI(r *aws.Request, v reflect.Value, name string) {
value, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if value != nil {
uri := r.HTTPRequest.URL.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(*value, true), -1)
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(*value, false), -1)
r.HTTPRequest.URL.Path = uri
}
}
func buildQueryString(r *aws.Request, v reflect.Value, name string, query url.Values) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
query.Set(name, *str)
}
}
func updatePath(url *url.URL, urlPath string) {
scheme, query := url.Scheme, url.RawQuery
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path
urlPath = path.Clean(urlPath)
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""
s := url.String()
url.Scheme = scheme
url.RawQuery = query
// build opaque URI
url.Opaque = s + urlPath
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
buf.WriteByte('%')
buf.WriteString(strings.ToUpper(strconv.FormatUint(uint64(c), 16)))
}
}
return buf.String()
}
func convertType(v reflect.Value) (*string, error) {
v = reflect.Indirect(v)
if !v.IsValid() {
return nil, nil
}
var str string
switch value := v.Interface().(type) {
case string:
str = value
case []byte:
str = base64.StdEncoding.EncodeToString(value)
case bool:
str = strconv.FormatBool(value)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
case time.Time:
str = value.UTC().Format(RFC822)
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return nil, err
}
return &str, nil
}

View File

@@ -0,0 +1,45 @@
package rest
import "reflect"
// PayloadMember returns the payload field member of i if there is one, or nil.
func PayloadMember(i interface{}) interface{} {
if i == nil {
return nil
}
v := reflect.ValueOf(i).Elem()
if !v.IsValid() {
return nil
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
field, _ := v.Type().FieldByName(payloadName)
if field.Tag.Get("type") != "structure" {
return nil
}
payload := v.FieldByName(payloadName)
if payload.IsValid() || (payload.Kind() == reflect.Ptr && !payload.IsNil()) {
return payload.Interface()
}
}
}
return nil
}
// PayloadType returns the type of a payload field member of i if there is one, or "".
func PayloadType(i interface{}) string {
v := reflect.Indirect(reflect.ValueOf(i))
if !v.IsValid() {
return ""
}
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
if member, ok := v.Type().FieldByName(payloadName); ok {
return member.Tag.Get("type")
}
}
}
return ""
}

View File

@@ -0,0 +1,174 @@
package rest
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Unmarshal unmarshals the REST component of a response in a REST service.
func Unmarshal(r *aws.Request) {
if r.DataFilled() {
v := reflect.Indirect(reflect.ValueOf(r.Data))
unmarshalBody(r, v)
unmarshalLocationElements(r, v)
}
}
func unmarshalBody(r *aws.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := v.FieldByName(payloadName)
if payload.IsValid() {
switch payload.Interface().(type) {
case []byte:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
}
case *string:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
str := string(b)
payload.Set(reflect.ValueOf(&str))
}
default:
switch payload.Type().String() {
case "io.ReadSeeker":
payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
case "aws.ReadSeekCloser", "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
default:
r.Error = awserr.New("SerializationError",
"failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
}
func unmarshalLocationElements(r *aws.Request, v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
m, field := v.Field(i), v.Type().Field(i)
if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
switch field.Tag.Get("location") {
case "statusCode":
unmarshalStatusCode(m, r.HTTPResponse.StatusCode)
case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
}
}
if r.Error != nil {
return
}
}
}
func unmarshalStatusCode(v reflect.Value, statusCode int) {
if !v.IsValid() {
return
}
switch v.Interface().(type) {
case *int64:
s := int64(statusCode)
v.Set(reflect.ValueOf(&s))
}
}
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string) error {
switch r.Interface().(type) {
case map[string]*string: // we only support string map value types
out := map[string]*string{}
for k, v := range headers {
k = http.CanonicalHeaderKey(k)
if strings.HasPrefix(strings.ToLower(k), strings.ToLower(prefix)) {
out[k[len(prefix):]] = &v[0]
}
}
r.Set(reflect.ValueOf(out))
}
return nil
}
func unmarshalHeader(v reflect.Value, header string) error {
if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
return nil
}
switch v.Interface().(type) {
case *string:
v.Set(reflect.ValueOf(&header))
case []byte:
b, err := base64.StdEncoding.DecodeString(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *bool:
b, err := strconv.ParseBool(header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&b))
case *int64:
i, err := strconv.ParseInt(header, 10, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&i))
case *float64:
f, err := strconv.ParseFloat(header, 64)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&f))
case *time.Time:
t, err := time.Parse(RFC822, header)
if err != nil {
return err
}
v.Set(reflect.ValueOf(&t))
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return err
}
return nil
}

View File

@@ -0,0 +1,287 @@
// Package xmlutil provides XML serialisation of AWS requests and responses.
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// BuildXML will serialize params into an xml.Encoder.
// Error will be returned if the serialization of any of the params or nested values fails.
func BuildXML(params interface{}, e *xml.Encoder) error {
b := xmlBuilder{encoder: e, namespaces: map[string]string{}}
root := NewXMLElement(xml.Name{})
if err := b.buildValue(reflect.ValueOf(params), root, ""); err != nil {
return err
}
for _, c := range root.Children {
for _, v := range c {
return StructToXML(e, v, false)
}
}
return nil
}
// Returns the reflection element of a value, if it is a pointer.
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}
// A xmlBuilder serializes values from Go code to XML
type xmlBuilder struct {
encoder *xml.Encoder
namespaces map[string]string
}
// buildValue generic XMLNode builder for any type. Will build value for their specific type
// struct, list, map, scalar.
//
// Also takes a "type" tag value to set what type a value should be converted to XMLNode as. If
// type is not provided reflect will be used to determine the value's type.
func (b *xmlBuilder) buildValue(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
value = elemOf(value)
if !value.IsValid() { // no need to handle zero values
return nil
} else if tag.Get("location") != "" { // don't handle non-body location values
return nil
}
t := tag.Get("type")
if t == "" {
switch value.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := value.Type().FieldByName("SDKShapeTraits"); ok {
tag = tag + reflect.StructTag(" ") + field.Tag
}
return b.buildStruct(value, current, tag)
case "list":
return b.buildList(value, current, tag)
case "map":
return b.buildMap(value, current, tag)
default:
return b.buildScalar(value, current, tag)
}
}
// buildStruct adds a struct and its fields to the current XMLNode. All fields any any nested
// types are converted to XMLNodes also.
func (b *xmlBuilder) buildStruct(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if !value.IsValid() {
return nil
}
fieldAdded := false
// unwrap payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := value.Type().FieldByName(payload)
tag = field.Tag
value = elemOf(value.FieldByName(payload))
if !value.IsValid() {
return nil
}
}
child := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
// there is an xmlNamespace associated with this struct
if prefix, uri := tag.Get("xmlPrefix"), tag.Get("xmlURI"); uri != "" {
ns := xml.Attr{
Name: xml.Name{Local: "xmlns"},
Value: uri,
}
if prefix != "" {
b.namespaces[prefix] = uri // register the namespace
ns.Name.Local = "xmlns:" + prefix
}
child.Attr = append(child.Attr, ns)
}
t := value.Type()
for i := 0; i < value.NumField(); i++ {
if c := t.Field(i).Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
member := elemOf(value.Field(i))
field := t.Field(i)
mTag := field.Tag
if mTag.Get("location") != "" { // skip non-body members
continue
}
memberName := mTag.Get("locationName")
if memberName == "" {
memberName = field.Name
mTag = reflect.StructTag(string(mTag) + ` locationName:"` + memberName + `"`)
}
if err := b.buildValue(member, child, mTag); err != nil {
return err
}
fieldAdded = true
}
if fieldAdded { // only append this child if we have one ore more valid members
current.AddChild(child)
}
return nil
}
// buildList adds the value's list items to the current XMLNode as children nodes. All
// nested values in the list are converted to XMLNodes also.
func (b *xmlBuilder) buildList(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted lists
return nil
}
// check for unflattened list member
flattened := tag.Get("flattened") != ""
xname := xml.Name{Local: tag.Get("locationName")}
if flattened {
for i := 0; i < value.Len(); i++ {
child := NewXMLElement(xname)
current.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
} else {
list := NewXMLElement(xname)
current.AddChild(list)
for i := 0; i < value.Len(); i++ {
iname := tag.Get("locationNameList")
if iname == "" {
iname = "member"
}
child := NewXMLElement(xml.Name{Local: iname})
list.AddChild(child)
if err := b.buildValue(value.Index(i), child, ""); err != nil {
return err
}
}
}
return nil
}
// buildMap adds the value's key/value pairs to the current XMLNode as children nodes. All
// nested values in the map are converted to XMLNodes also.
//
// Error will be returned if it is unable to build the map's values into XMLNodes
func (b *xmlBuilder) buildMap(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
if value.IsNil() { // don't build omitted maps
return nil
}
maproot := NewXMLElement(xml.Name{Local: tag.Get("locationName")})
current.AddChild(maproot)
current = maproot
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
// sorting is not required for compliance, but it makes testing easier
keys := make([]string, value.Len())
for i, k := range value.MapKeys() {
keys[i] = k.String()
}
sort.Strings(keys)
for _, k := range keys {
v := value.MapIndex(reflect.ValueOf(k))
mapcur := current
if tag.Get("flattened") == "" { // add "entry" tag to non-flat maps
child := NewXMLElement(xml.Name{Local: "entry"})
mapcur.AddChild(child)
mapcur = child
}
kchild := NewXMLElement(xml.Name{Local: kname})
kchild.Text = k
vchild := NewXMLElement(xml.Name{Local: vname})
mapcur.AddChild(kchild)
mapcur.AddChild(vchild)
if err := b.buildValue(v, vchild, ""); err != nil {
return err
}
}
return nil
}
// buildScalar will convert the value into a string and append it as a attribute or child
// of the current XMLNode.
//
// The value will be added as an attribute if tag contains a "xmlAttribute" attribute value.
//
// Error will be returned if the value type is unsupported.
func (b *xmlBuilder) buildScalar(value reflect.Value, current *XMLNode, tag reflect.StructTag) error {
var str string
switch converted := value.Interface().(type) {
case string:
str = converted
case []byte:
if !value.IsNil() {
str = base64.StdEncoding.EncodeToString(converted)
}
case bool:
str = strconv.FormatBool(converted)
case int64:
str = strconv.FormatInt(converted, 10)
case int:
str = strconv.Itoa(converted)
case float64:
str = strconv.FormatFloat(converted, 'f', -1, 64)
case float32:
str = strconv.FormatFloat(float64(converted), 'f', -1, 32)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
str = converted.UTC().Format(ISO8601UTC)
default:
return fmt.Errorf("unsupported value for param %s: %v (%s)",
tag.Get("locationName"), value.Interface(), value.Type().Name())
}
xname := xml.Name{Local: tag.Get("locationName")}
if tag.Get("xmlAttribute") != "" { // put into current node's attribute list
attr := xml.Attr{Name: xname, Value: str}
current.Attr = append(current.Attr, attr)
} else { // regular text node
current.AddChild(&XMLNode{Name: xname, Text: str})
}
return nil
}

View File

@@ -0,0 +1,260 @@
package xmlutil
import (
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
// UnmarshalXML deserializes an xml.Decoder into the container v. V
// needs to match the shape of the XML expected to be decoded.
// If the shape doesn't match unmarshaling will fail.
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
n, _ := XMLToStruct(d, nil)
if n.Children != nil {
for _, root := range n.Children {
for _, c := range root {
if wrappedChild, ok := c.Children[wrapper]; ok {
c = wrappedChild[0] // pull out wrapped element
}
err := parse(reflect.ValueOf(v), c, "")
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
return nil
}
return nil
}
// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
rtype := r.Type()
if rtype.Kind() == reflect.Ptr {
rtype = rtype.Elem() // check kind of actual element type
}
t := tag.Get("type")
if t == "" {
switch rtype.Kind() {
case reflect.Struct:
t = "structure"
case reflect.Slice:
t = "list"
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := rtype.FieldByName("SDKShapeTraits"); ok {
tag = field.Tag
}
return parseStruct(r, node, tag)
case "list":
return parseList(r, node, tag)
case "map":
return parseMap(r, node, tag)
default:
return parseScalar(r, node, tag)
}
}
// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
// types in the structure will also be deserialized.
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if r.Kind() == reflect.Ptr {
if r.IsNil() { // create the structure if it's nil
s := reflect.New(r.Type().Elem())
r.Set(s)
r = s
}
r = r.Elem()
t = t.Elem()
}
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return parseStruct(r.FieldByName(payload), node, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if c := field.Name[0:1]; strings.ToLower(c) == c {
continue // ignore unexported fields
}
// figure out what this field is called
name := field.Name
if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
name = field.Tag.Get("locationNameList")
} else if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
// try to find the field by name in elements
elems := node.Children[name]
if elems == nil { // try to find the field in attributes
for _, a := range node.Attr {
if name == a.Name.Local {
// turn this into a text node for de-serializing
elems = []*XMLNode{{Text: a.Value}}
}
}
}
member := r.FieldByName(field.Name)
for _, elem := range elems {
err := parse(member, elem, field.Tag)
if err != nil {
return err
}
}
}
return nil
}
// parseList deserializes a list of values from an XML node. Each list entry
// will also be deserialized.
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
t := r.Type()
if tag.Get("flattened") == "" { // look at all item entries
mname := "member"
if name := tag.Get("locationNameList"); name != "" {
mname = name
}
if Children, ok := node.Children[mname]; ok {
if r.IsNil() {
r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
}
for i, c := range Children {
err := parse(r.Index(i), c, "")
if err != nil {
return err
}
}
}
} else { // flattened list means this is a single element
if r.IsNil() {
r.Set(reflect.MakeSlice(t, 0, 0))
}
childR := reflect.Zero(t.Elem())
r.Set(reflect.Append(r, childR))
err := parse(r.Index(r.Len()-1), node, "")
if err != nil {
return err
}
}
return nil
}
// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
// will also be deserialized as map entries.
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
if r.IsNil() {
r.Set(reflect.MakeMap(r.Type()))
}
if tag.Get("flattened") == "" { // look at all child entries
for _, entry := range node.Children["entry"] {
parseMapEntry(r, entry, tag)
}
} else { // this element is itself an entry
parseMapEntry(r, node, tag)
}
return nil
}
// parseMapEntry deserializes a map entry from a XML node.
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
kname, vname := "key", "value"
if n := tag.Get("locationNameKey"); n != "" {
kname = n
}
if n := tag.Get("locationNameValue"); n != "" {
vname = n
}
keys, ok := node.Children[kname]
values := node.Children[vname]
if ok {
for i, key := range keys {
keyR := reflect.ValueOf(key.Text)
value := values[i]
valueR := reflect.New(r.Type().Elem()).Elem()
parse(valueR, value, "")
r.SetMapIndex(keyR, valueR)
}
}
return nil
}
// parseScaller deserializes an XMLNode value into a concrete type based on the
// interface type of r.
//
// Error is returned if the deserialization fails due to invalid type conversion,
// or unsupported interface type.
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
switch r.Interface().(type) {
case *string:
r.Set(reflect.ValueOf(&node.Text))
return nil
case []byte:
b, err := base64.StdEncoding.DecodeString(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(b))
case *bool:
v, err := strconv.ParseBool(node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *int64:
v, err := strconv.ParseInt(node.Text, 10, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *float64:
v, err := strconv.ParseFloat(node.Text, 64)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&v))
case *time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
t, err := time.Parse(ISO8601UTC, node.Text)
if err != nil {
return err
}
r.Set(reflect.ValueOf(&t))
default:
return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
}
return nil
}

View File

@@ -0,0 +1,105 @@
package xmlutil
import (
"encoding/xml"
"io"
"sort"
)
// A XMLNode contains the values to be encoded or decoded.
type XMLNode struct {
Name xml.Name `json:",omitempty"`
Children map[string][]*XMLNode `json:",omitempty"`
Text string `json:",omitempty"`
Attr []xml.Attr `json:",omitempty"`
}
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
func NewXMLElement(name xml.Name) *XMLNode {
return &XMLNode{
Name: name,
Children: map[string][]*XMLNode{},
Attr: []xml.Attr{},
}
}
// AddChild adds child to the XMLNode.
func (n *XMLNode) AddChild(child *XMLNode) {
if _, ok := n.Children[child.Name.Local]; !ok {
n.Children[child.Name.Local] = []*XMLNode{}
}
n.Children[child.Name.Local] = append(n.Children[child.Name.Local], child)
}
// XMLToStruct converts a xml.Decoder stream to XMLNode with nested values.
func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
out := &XMLNode{}
for {
tok, err := d.Token()
if tok == nil || err == io.EOF {
break
}
if err != nil {
return out, err
}
switch typed := tok.(type) {
case xml.CharData:
out.Text = string(typed.Copy())
case xml.StartElement:
el := typed.Copy()
out.Attr = el.Attr
if out.Children == nil {
out.Children = map[string][]*XMLNode{}
}
name := typed.Name.Local
slice := out.Children[name]
if slice == nil {
slice = []*XMLNode{}
}
node, e := XMLToStruct(d, &el)
if e != nil {
return out, e
}
node.Name = typed.Name
slice = append(slice, node)
out.Children[name] = slice
case xml.EndElement:
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
return out, nil
}
}
}
return out, nil
}
// StructToXML writes an XMLNode to a xml.Encoder as tokens.
func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error {
e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr})
if node.Text != "" {
e.EncodeToken(xml.CharData([]byte(node.Text)))
} else if sorted {
sortedNames := []string{}
for k := range node.Children {
sortedNames = append(sortedNames, k)
}
sort.Strings(sortedNames)
for _, k := range sortedNames {
for _, v := range node.Children[k] {
StructToXML(e, v, sorted)
}
}
} else {
for _, c := range node.Children {
for _, v := range c {
StructToXML(e, v, sorted)
}
}
}
e.EncodeToken(xml.EndElement{Name: node.Name})
return e.Flush()
}

View File

@@ -0,0 +1,43 @@
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestPresignHandler(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View File

@@ -0,0 +1,364 @@
// Package v4 implements signing for AWS V4 signer
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
)
const (
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
var ignoredHeaders = map[string]bool{
"Authorization": true,
"Content-Type": true,
"Content-Length": true,
"User-Agent": true,
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
CredValues credentials.Value
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug aws.LogLevelType
Logger aws.Logger
isPresign bool
formattedTime string
formattedShortTime string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign requests with signature version 4.
//
// Will sign the requests with the service config's Credentials object
// Signing is skipped if the credentials is the credentials.AnonymousCredentials
// object.
func Sign(req *aws.Request) {
// If the request does not need to be signed ignore the signing of the
// request if the AnonymousCredentials object is used.
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
return
}
region := req.Service.SigningRegion
if region == "" {
region = aws.StringValue(req.Service.Config.Region)
}
name := req.Service.SigningName
if name == "" {
name = req.Service.ServiceName
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
Credentials: req.Service.Config.Credentials,
Debug: req.Service.Config.LogLevel.Value(),
Logger: req.Service.Config.Logger,
}
req.Error = s.sign()
}
func (v4 *signer) sign() error {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isRequestSigned() {
if !v4.Credentials.IsExpired() {
// If the request is already signed, and the credentials have not
// expired yet ignore the signing request.
return nil
}
// The credentials have expired for this request. The current signing
// is invalid, and needs to be request because the request will fail.
if v4.isPresign {
v4.removePresign()
// Update the request's query string to ensure the values stays in
// sync in the case retrieving the new credentials fails.
v4.Request.URL.RawQuery = v4.Query.Encode()
}
}
var err error
v4.CredValues, err = v4.Credentials.Get()
if err != nil {
return err
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.CredValues.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.CredValues.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
}
v4.build()
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo()
}
return nil
}
const logSignInfoMsg = `DEBUG: Request Signiture:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *signer) logSigningInfo() {
signedURLMsg := ""
if v4.isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, v4.Request.URL.String())
}
msg := fmt.Sprintf(logSignInfoMsg, v4.canonicalString, v4.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
v4.buildQuery() // no depends
}
v4.buildCanonicalHeaders() // depends on cred string
v4.buildCanonicalString() // depends on canon headers / signed headers
v4.buildStringToSign() // depends on canon string
v4.buildSignature() // depends on string to sign
if v4.isPresign {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.CredValues.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
v4.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
}
func (v4 *signer) buildTime() {
v4.formattedTime = v4.Time.UTC().Format(timeFormat)
v4.formattedShortTime = v4.Time.UTC().Format(shortTimeFormat)
if v4.isPresign {
duration := int64(v4.ExpireTime / time.Second)
v4.Query.Set("X-Amz-Date", v4.formattedTime)
v4.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
v4.Request.Header.Set("X-Amz-Date", v4.formattedTime)
}
}
func (v4 *signer) buildCredentialString() {
v4.credentialString = strings.Join([]string{
v4.formattedShortTime,
v4.Region,
v4.ServiceName,
"aws4_request",
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.CredValues.AccessKeyID+"/"+v4.credentialString)
}
}
func (v4 *signer) buildQuery() {
for k, h := range v4.Request.Header {
if strings.HasPrefix(http.CanonicalHeaderKey(k), "X-Amz-") {
continue // never hoist x-amz-* headers, they must be signed
}
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // never hoist ignored headers
}
v4.Request.Header.Del(k)
v4.Query.Del(k)
for _, v := range h {
v4.Query.Add(k, v)
}
}
}
func (v4 *signer) buildCanonicalHeaders() {
var headers []string
headers = append(headers, "host")
for k := range v4.Request.Header {
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // ignored header
}
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
v4.signedHeaders = strings.Join(headers, ";")
if v4.isPresign {
v4.Query.Set("X-Amz-SignedHeaders", v4.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
headerValues[i] = "host:" + v4.Request.URL.Host
} else {
headerValues[i] = k + ":" +
strings.Join(v4.Request.Header[http.CanonicalHeaderKey(k)], ",")
}
}
v4.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (v4 *signer) buildCanonicalString() {
v4.Request.URL.RawQuery = strings.Replace(v4.Query.Encode(), "+", "%20", -1)
uri := v4.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = v4.Request.URL.Path
}
if uri == "" {
uri = "/"
}
if v4.ServiceName != "s3" {
uri = rest.EscapePath(uri, false)
}
v4.canonicalString = strings.Join([]string{
v4.Request.Method,
uri,
v4.Request.URL.RawQuery,
v4.canonicalHeaders + "\n",
v4.signedHeaders,
v4.bodyDigest(),
}, "\n")
}
func (v4 *signer) buildStringToSign() {
v4.stringToSign = strings.Join([]string{
authHeaderPrefix,
v4.formattedTime,
v4.credentialString,
hex.EncodeToString(makeSha256([]byte(v4.canonicalString))),
}, "\n")
}
func (v4 *signer) buildSignature() {
secret := v4.CredValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(v4.stringToSign))
v4.signature = hex.EncodeToString(signature)
}
func (v4 *signer) bodyDigest() string {
hash := v4.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if v4.isPresign && v4.ServiceName == "s3" {
hash = "UNSIGNED-PAYLOAD"
} else if v4.Body == nil {
hash = hex.EncodeToString(makeSha256([]byte{}))
} else {
hash = hex.EncodeToString(makeSha256Reader(v4.Body))
}
v4.Request.Header.Add("X-Amz-Content-Sha256", hash)
}
return hash
}
// isRequestSigned returns if the request is currently signed or presigned
func (v4 *signer) isRequestSigned() bool {
if v4.isPresign && v4.Query.Get("X-Amz-Signature") != "" {
return true
}
if v4.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// unsign removes signing flags for both signed and presigned requests.
func (v4 *signer) removePresign() {
v4.Query.Del("X-Amz-Algorithm")
v4.Query.Del("X-Amz-Signature")
v4.Query.Del("X-Amz-Security-Token")
v4.Query.Del("X-Amz-Date")
v4.Query.Del("X-Amz-Expires")
v4.Query.Del("X-Amz-Credential")
v4.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
io.Copy(hash, reader)
return hash.Sum(nil)
}

View File

@@ -0,0 +1,245 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "5eeedebf6f995145ce56daa02902d10485246d3defb34f97b973c1f40ab82d36"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=69ada33fec48180dab153576e4dd80c4e04124f80dda3eccfed8a67c2b91ed5e"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: credentials.AnonymousCredentials}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{credentials.Value{"AKID", "SECRET", "SESSION"}}
creds := credentials.NewCredentials(provider)
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Package cloudwatchiface provides an interface for the Amazon CloudWatch.
package cloudwatchiface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
// CloudWatchAPI is the interface type for cloudwatch.CloudWatch.
type CloudWatchAPI interface {
DeleteAlarmsRequest(*cloudwatch.DeleteAlarmsInput) (*aws.Request, *cloudwatch.DeleteAlarmsOutput)
DeleteAlarms(*cloudwatch.DeleteAlarmsInput) (*cloudwatch.DeleteAlarmsOutput, error)
DescribeAlarmHistoryRequest(*cloudwatch.DescribeAlarmHistoryInput) (*aws.Request, *cloudwatch.DescribeAlarmHistoryOutput)
DescribeAlarmHistory(*cloudwatch.DescribeAlarmHistoryInput) (*cloudwatch.DescribeAlarmHistoryOutput, error)
DescribeAlarmHistoryPages(*cloudwatch.DescribeAlarmHistoryInput, func(*cloudwatch.DescribeAlarmHistoryOutput, bool) bool) error
DescribeAlarmsRequest(*cloudwatch.DescribeAlarmsInput) (*aws.Request, *cloudwatch.DescribeAlarmsOutput)
DescribeAlarms(*cloudwatch.DescribeAlarmsInput) (*cloudwatch.DescribeAlarmsOutput, error)
DescribeAlarmsPages(*cloudwatch.DescribeAlarmsInput, func(*cloudwatch.DescribeAlarmsOutput, bool) bool) error
DescribeAlarmsForMetricRequest(*cloudwatch.DescribeAlarmsForMetricInput) (*aws.Request, *cloudwatch.DescribeAlarmsForMetricOutput)
DescribeAlarmsForMetric(*cloudwatch.DescribeAlarmsForMetricInput) (*cloudwatch.DescribeAlarmsForMetricOutput, error)
DisableAlarmActionsRequest(*cloudwatch.DisableAlarmActionsInput) (*aws.Request, *cloudwatch.DisableAlarmActionsOutput)
DisableAlarmActions(*cloudwatch.DisableAlarmActionsInput) (*cloudwatch.DisableAlarmActionsOutput, error)
EnableAlarmActionsRequest(*cloudwatch.EnableAlarmActionsInput) (*aws.Request, *cloudwatch.EnableAlarmActionsOutput)
EnableAlarmActions(*cloudwatch.EnableAlarmActionsInput) (*cloudwatch.EnableAlarmActionsOutput, error)
GetMetricStatisticsRequest(*cloudwatch.GetMetricStatisticsInput) (*aws.Request, *cloudwatch.GetMetricStatisticsOutput)
GetMetricStatistics(*cloudwatch.GetMetricStatisticsInput) (*cloudwatch.GetMetricStatisticsOutput, error)
ListMetricsRequest(*cloudwatch.ListMetricsInput) (*aws.Request, *cloudwatch.ListMetricsOutput)
ListMetrics(*cloudwatch.ListMetricsInput) (*cloudwatch.ListMetricsOutput, error)
ListMetricsPages(*cloudwatch.ListMetricsInput, func(*cloudwatch.ListMetricsOutput, bool) bool) error
PutMetricAlarmRequest(*cloudwatch.PutMetricAlarmInput) (*aws.Request, *cloudwatch.PutMetricAlarmOutput)
PutMetricAlarm(*cloudwatch.PutMetricAlarmInput) (*cloudwatch.PutMetricAlarmOutput, error)
PutMetricDataRequest(*cloudwatch.PutMetricDataInput) (*aws.Request, *cloudwatch.PutMetricDataOutput)
PutMetricData(*cloudwatch.PutMetricDataInput) (*cloudwatch.PutMetricDataOutput, error)
SetAlarmStateRequest(*cloudwatch.SetAlarmStateInput) (*aws.Request, *cloudwatch.SetAlarmStateOutput)
SetAlarmState(*cloudwatch.SetAlarmStateInput) (*cloudwatch.SetAlarmStateOutput, error)
}

View File

@@ -0,0 +1,15 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package cloudwatchiface_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/cloudwatch"
"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
"github.com/stretchr/testify/assert"
)
func TestInterface(t *testing.T) {
assert.Implements(t, (*cloudwatchiface.CloudWatchAPI)(nil), cloudwatch.New(nil))
}

View File

@@ -0,0 +1,426 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package cloudwatch_test
import (
"bytes"
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/cloudwatch"
)
var _ time.Duration
var _ bytes.Buffer
func ExampleCloudWatch_DeleteAlarms() {
svc := cloudwatch.New(nil)
params := &cloudwatch.DeleteAlarmsInput{
AlarmNames: []*string{ // Required
aws.String("AlarmName"), // Required
// More values...
},
}
resp, err := svc.DeleteAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_DescribeAlarmHistory() {
svc := cloudwatch.New(nil)
params := &cloudwatch.DescribeAlarmHistoryInput{
AlarmName: aws.String("AlarmName"),
EndDate: aws.Time(time.Now()),
HistoryItemType: aws.String("HistoryItemType"),
MaxRecords: aws.Int64(1),
NextToken: aws.String("NextToken"),
StartDate: aws.Time(time.Now()),
}
resp, err := svc.DescribeAlarmHistory(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_DescribeAlarms() {
svc := cloudwatch.New(nil)
params := &cloudwatch.DescribeAlarmsInput{
ActionPrefix: aws.String("ActionPrefix"),
AlarmNamePrefix: aws.String("AlarmNamePrefix"),
AlarmNames: []*string{
aws.String("AlarmName"), // Required
// More values...
},
MaxRecords: aws.Int64(1),
NextToken: aws.String("NextToken"),
StateValue: aws.String("StateValue"),
}
resp, err := svc.DescribeAlarms(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_DescribeAlarmsForMetric() {
svc := cloudwatch.New(nil)
params := &cloudwatch.DescribeAlarmsForMetricInput{
MetricName: aws.String("MetricName"), // Required
Namespace: aws.String("Namespace"), // Required
Dimensions: []*cloudwatch.Dimension{
{ // Required
Name: aws.String("DimensionName"), // Required
Value: aws.String("DimensionValue"), // Required
},
// More values...
},
Period: aws.Int64(1),
Statistic: aws.String("Statistic"),
Unit: aws.String("StandardUnit"),
}
resp, err := svc.DescribeAlarmsForMetric(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_DisableAlarmActions() {
svc := cloudwatch.New(nil)
params := &cloudwatch.DisableAlarmActionsInput{
AlarmNames: []*string{ // Required
aws.String("AlarmName"), // Required
// More values...
},
}
resp, err := svc.DisableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_EnableAlarmActions() {
svc := cloudwatch.New(nil)
params := &cloudwatch.EnableAlarmActionsInput{
AlarmNames: []*string{ // Required
aws.String("AlarmName"), // Required
// More values...
},
}
resp, err := svc.EnableAlarmActions(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_GetMetricStatistics() {
svc := cloudwatch.New(nil)
params := &cloudwatch.GetMetricStatisticsInput{
EndTime: aws.Time(time.Now()), // Required
MetricName: aws.String("MetricName"), // Required
Namespace: aws.String("Namespace"), // Required
Period: aws.Int64(1), // Required
StartTime: aws.Time(time.Now()), // Required
Statistics: []*string{ // Required
aws.String("Statistic"), // Required
// More values...
},
Dimensions: []*cloudwatch.Dimension{
{ // Required
Name: aws.String("DimensionName"), // Required
Value: aws.String("DimensionValue"), // Required
},
// More values...
},
Unit: aws.String("StandardUnit"),
}
resp, err := svc.GetMetricStatistics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_ListMetrics() {
svc := cloudwatch.New(nil)
params := &cloudwatch.ListMetricsInput{
Dimensions: []*cloudwatch.DimensionFilter{
{ // Required
Name: aws.String("DimensionName"), // Required
Value: aws.String("DimensionValue"),
},
// More values...
},
MetricName: aws.String("MetricName"),
Namespace: aws.String("Namespace"),
NextToken: aws.String("NextToken"),
}
resp, err := svc.ListMetrics(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_PutMetricAlarm() {
svc := cloudwatch.New(nil)
params := &cloudwatch.PutMetricAlarmInput{
AlarmName: aws.String("AlarmName"), // Required
ComparisonOperator: aws.String("ComparisonOperator"), // Required
EvaluationPeriods: aws.Int64(1), // Required
MetricName: aws.String("MetricName"), // Required
Namespace: aws.String("Namespace"), // Required
Period: aws.Int64(1), // Required
Statistic: aws.String("Statistic"), // Required
Threshold: aws.Float64(1.0), // Required
ActionsEnabled: aws.Bool(true),
AlarmActions: []*string{
aws.String("ResourceName"), // Required
// More values...
},
AlarmDescription: aws.String("AlarmDescription"),
Dimensions: []*cloudwatch.Dimension{
{ // Required
Name: aws.String("DimensionName"), // Required
Value: aws.String("DimensionValue"), // Required
},
// More values...
},
InsufficientDataActions: []*string{
aws.String("ResourceName"), // Required
// More values...
},
OKActions: []*string{
aws.String("ResourceName"), // Required
// More values...
},
Unit: aws.String("StandardUnit"),
}
resp, err := svc.PutMetricAlarm(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_PutMetricData() {
svc := cloudwatch.New(nil)
params := &cloudwatch.PutMetricDataInput{
MetricData: []*cloudwatch.MetricDatum{ // Required
{ // Required
MetricName: aws.String("MetricName"), // Required
Dimensions: []*cloudwatch.Dimension{
{ // Required
Name: aws.String("DimensionName"), // Required
Value: aws.String("DimensionValue"), // Required
},
// More values...
},
StatisticValues: &cloudwatch.StatisticSet{
Maximum: aws.Float64(1.0), // Required
Minimum: aws.Float64(1.0), // Required
SampleCount: aws.Float64(1.0), // Required
Sum: aws.Float64(1.0), // Required
},
Timestamp: aws.Time(time.Now()),
Unit: aws.String("StandardUnit"),
Value: aws.Float64(1.0),
},
// More values...
},
Namespace: aws.String("Namespace"), // Required
}
resp, err := svc.PutMetricData(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}
func ExampleCloudWatch_SetAlarmState() {
svc := cloudwatch.New(nil)
params := &cloudwatch.SetAlarmStateInput{
AlarmName: aws.String("AlarmName"), // Required
StateReason: aws.String("StateReason"), // Required
StateValue: aws.String("StateValue"), // Required
StateReasonData: aws.String("StateReasonData"),
}
resp, err := svc.SetAlarmState(params)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
// Generic AWS error with Code, Message, and original error (if any)
fmt.Println(awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
if reqErr, ok := err.(awserr.RequestFailure); ok {
// A service error occurred
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
}
// Pretty-print the response data.
fmt.Println(awsutil.Prettify(resp))
}

View File

@@ -0,0 +1,96 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package cloudwatch
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
)
// This is the Amazon CloudWatch API Reference. This guide provides detailed
// information about Amazon CloudWatch actions, data types, parameters, and
// errors. For detailed information about Amazon CloudWatch features and their
// associated API calls, go to the Amazon CloudWatch Developer Guide (http://docs.aws.amazon.com/AmazonCloudWatch/latest/DeveloperGuide).
//
// Amazon CloudWatch is a web service that enables you to publish, monitor,
// and manage various metrics, as well as configure alarm actions based on data
// from metrics. For more information about this product go to http://aws.amazon.com/cloudwatch
// (http://aws.amazon.com/cloudwatch).
//
// For information about the namespace, metric names, and dimensions that
// other Amazon Web Services products use to send metrics to Cloudwatch, go
// to Amazon CloudWatch Metrics, Namespaces, and Dimensions Reference (http://docs.aws.amazon.com/AmazonCloudWatch/latest/DeveloperGuide/CW_Support_For_AWS.html)
// in the Amazon CloudWatch Developer Guide.
//
// Use the following links to get started using the Amazon CloudWatch API Reference:
//
// Actions (http://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_Operations.html):
// An alphabetical list of all Amazon CloudWatch actions. Data Types (http://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_Types.html):
// An alphabetical list of all Amazon CloudWatch data types. Common Parameters
// (http://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/CommonParameters.html):
// Parameters that all Query actions can use. Common Errors (http://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/CommonErrors.html):
// Client and server errors that all actions can return. Regions and Endpoints
// (http://docs.aws.amazon.com/general/latest/gr/index.html?rande.html): Itemized
// regions and endpoints for all AWS products. WSDL Location (http://monitoring.amazonaws.com/doc/2010-08-01/CloudWatch.wsdl):
// http://monitoring.amazonaws.com/doc/2010-08-01/CloudWatch.wsdl In addition
// to using the Amazon CloudWatch API, you can also use the following SDKs and
// third-party libraries to access Amazon CloudWatch programmatically.
//
// AWS SDK for Java Documentation (http://aws.amazon.com/documentation/sdkforjava/)
// AWS SDK for .NET Documentation (http://aws.amazon.com/documentation/sdkfornet/)
// AWS SDK for PHP Documentation (http://aws.amazon.com/documentation/sdkforphp/)
// AWS SDK for Ruby Documentation (http://aws.amazon.com/documentation/sdkforruby/)
// Developers in the AWS developer community also provide their own libraries,
// which you can find at the following AWS developer centers:
//
// AWS Java Developer Center (http://aws.amazon.com/java/) AWS PHP Developer
// Center (http://aws.amazon.com/php/) AWS Python Developer Center (http://aws.amazon.com/python/)
// AWS Ruby Developer Center (http://aws.amazon.com/ruby/) AWS Windows and .NET
// Developer Center (http://aws.amazon.com/net/)
type CloudWatch struct {
*aws.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
// New returns a new CloudWatch client.
func New(config *aws.Config) *CloudWatch {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "monitoring",
APIVersion: "2010-08-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(query.Build)
service.Handlers.Unmarshal.PushBack(query.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(query.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(query.UnmarshalError)
// Run custom service initialization if present
if initService != nil {
initService(service)
}
return &CloudWatch{service}
}
// newRequest creates a new request for a CloudWatch operation and runs any
// custom request initialization.
func (c *CloudWatch) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
// Run custom request initialization if present
if initRequest != nil {
initRequest(req)
}
return req
}

View File

@@ -0,0 +1,136 @@
// Copyright (c) 2015 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is not running on Google App Engine and "-tags disableunsafe"
// is not added to the go build command line.
// +build !appengine,!disableunsafe
package spew
import (
"reflect"
"unsafe"
)
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = false
// ptrSize is the size of a pointer on the current arch.
ptrSize = unsafe.Sizeof((*byte)(nil))
)
var (
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
// internal reflect.Value fields. These values are valid before golang
// commit ecccf07e7f9d which changed the format. The are also valid
// after commit 82f48826c6c7 which changed the format again to mirror
// the original format. Code in the init function updates these offsets
// as necessary.
offsetPtr = uintptr(ptrSize)
offsetScalar = uintptr(0)
offsetFlag = uintptr(ptrSize * 2)
// flagKindWidth and flagKindShift indicate various bits that the
// reflect package uses internally to track kind information.
//
// flagRO indicates whether or not the value field of a reflect.Value is
// read-only.
//
// flagIndir indicates whether the value field of a reflect.Value is
// the actual data or a pointer to the data.
//
// These values are valid before golang commit 90a7c3c86944 which
// changed their positions. Code in the init function updates these
// flags as necessary.
flagKindWidth = uintptr(5)
flagKindShift = uintptr(flagKindWidth - 1)
flagRO = uintptr(1 << 0)
flagIndir = uintptr(1 << 1)
)
func init() {
// Older versions of reflect.Value stored small integers directly in the
// ptr field (which is named val in the older versions). Versions
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
// scalar for this purpose which unfortunately came before the flag
// field, so the offset of the flag field is different for those
// versions.
//
// This code constructs a new reflect.Value from a known small integer
// and checks if the size of the reflect.Value struct indicates it has
// the scalar field. When it does, the offsets are updated accordingly.
vv := reflect.ValueOf(0xf00)
if unsafe.Sizeof(vv) == (ptrSize * 4) {
offsetScalar = ptrSize * 2
offsetFlag = ptrSize * 3
}
// Commit 90a7c3c86944 changed the flag positions such that the low
// order bits are the kind. This code extracts the kind from the flags
// field and ensures it's the correct type. When it's not, the flag
// order has been changed to the newer format, so the flags are updated
// accordingly.
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
upfv := *(*uintptr)(upf)
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
flagKindShift = 0
flagRO = 1 << 5
flagIndir = 1 << 6
}
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
// the typical safety restrictions preventing access to unaddressable and
// unexported data. It works by digging the raw pointer to the underlying
// value out of the protected value and generating a new unprotected (unsafe)
// reflect.Value to it.
//
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
indirects := 1
vt := v.Type()
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
if rvf&flagIndir != 0 {
vt = reflect.PtrTo(v.Type())
indirects++
} else if offsetScalar != 0 {
// The value is in the scalar field when it's not one of the
// reference types.
switch vt.Kind() {
case reflect.Uintptr:
case reflect.Chan:
case reflect.Func:
case reflect.Map:
case reflect.Ptr:
case reflect.UnsafePointer:
default:
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
offsetScalar)
}
}
pv := reflect.NewAt(vt, upv)
rv = pv
for i := 0; i < indirects; i++ {
rv = rv.Elem()
}
return rv
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2015 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when either the code is running on Google App Engine or "-tags disableunsafe"
// is added to the go build command line.
// +build appengine disableunsafe
package spew
import "reflect"
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = true
)
// unsafeReflectValue typically converts the passed reflect.Value into a one
// that bypasses the typical safety restrictions preventing access to
// unaddressable and unexported data. However, doing this relies on access to
// the unsafe package. This is a stub version which simply returns the passed
// reflect.Value when the unsafe package is not available.
func unsafeReflectValue(v reflect.Value) reflect.Value {
return v
}

View File

@@ -0,0 +1,341 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"reflect"
"sort"
"strconv"
)
// Some constants in the form of bytes to avoid string overhead. This mirrors
// the technique used in the fmt package.
var (
panicBytes = []byte("(PANIC=")
plusBytes = []byte("+")
iBytes = []byte("i")
trueBytes = []byte("true")
falseBytes = []byte("false")
interfaceBytes = []byte("(interface {})")
commaNewlineBytes = []byte(",\n")
newlineBytes = []byte("\n")
openBraceBytes = []byte("{")
openBraceNewlineBytes = []byte("{\n")
closeBraceBytes = []byte("}")
asteriskBytes = []byte("*")
colonBytes = []byte(":")
colonSpaceBytes = []byte(": ")
openParenBytes = []byte("(")
closeParenBytes = []byte(")")
spaceBytes = []byte(" ")
pointerChainBytes = []byte("->")
nilAngleBytes = []byte("<nil>")
maxNewlineBytes = []byte("<max depth reached>\n")
maxShortBytes = []byte("<max>")
circularBytes = []byte("<already shown>")
circularShortBytes = []byte("<shown>")
invalidAngleBytes = []byte("<invalid>")
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
percentBytes = []byte("%")
precisionBytes = []byte(".")
openAngleBytes = []byte("<")
closeAngleBytes = []byte(">")
openMapBytes = []byte("map[")
closeMapBytes = []byte("]")
lenEqualsBytes = []byte("len=")
capEqualsBytes = []byte("cap=")
)
// hexDigits is used to map a decimal value to a hex digit.
var hexDigits = "0123456789abcdef"
// catchPanic handles any panics that might occur during the handleMethods
// calls.
func catchPanic(w io.Writer, v reflect.Value) {
if err := recover(); err != nil {
w.Write(panicBytes)
fmt.Fprintf(w, "%v", err)
w.Write(closeParenBytes)
}
}
// handleMethods attempts to call the Error and String methods on the underlying
// type the passed reflect.Value represents and outputes the result to Writer w.
//
// It handles panics in any called methods by catching and displaying the error
// as the formatted value.
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
// We need an interface to check if the type implements the error or
// Stringer interface. However, the reflect package won't give us an
// interface on certain things like unexported struct fields in order
// to enforce visibility rules. We use unsafe, when it's available,
// to bypass these restrictions since this package does not mutate the
// values.
if !v.CanInterface() {
if UnsafeDisabled {
return false
}
v = unsafeReflectValue(v)
}
// Choose whether or not to do error and Stringer interface lookups against
// the base type or a pointer to the base type depending on settings.
// Technically calling one of these methods with a pointer receiver can
// mutate the value, however, types which choose to satisify an error or
// Stringer interface with a pointer receiver should not be mutating their
// state inside these interface methods.
if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() {
v = unsafeReflectValue(v)
}
if v.CanAddr() {
v = v.Addr()
}
// Is it an error or Stringer?
switch iface := v.Interface().(type) {
case error:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.Error()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.Error()))
return true
case fmt.Stringer:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.String()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.String()))
return true
}
return false
}
// printBool outputs a boolean value as true or false to Writer w.
func printBool(w io.Writer, val bool) {
if val {
w.Write(trueBytes)
} else {
w.Write(falseBytes)
}
}
// printInt outputs a signed integer value to Writer w.
func printInt(w io.Writer, val int64, base int) {
w.Write([]byte(strconv.FormatInt(val, base)))
}
// printUint outputs an unsigned integer value to Writer w.
func printUint(w io.Writer, val uint64, base int) {
w.Write([]byte(strconv.FormatUint(val, base)))
}
// printFloat outputs a floating point value using the specified precision,
// which is expected to be 32 or 64bit, to Writer w.
func printFloat(w io.Writer, val float64, precision int) {
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
}
// printComplex outputs a complex value using the specified float precision
// for the real and imaginary parts to Writer w.
func printComplex(w io.Writer, c complex128, floatPrecision int) {
r := real(c)
w.Write(openParenBytes)
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
i := imag(c)
if i >= 0 {
w.Write(plusBytes)
}
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
w.Write(iBytes)
w.Write(closeParenBytes)
}
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
// prefix to Writer w.
func printHexPtr(w io.Writer, p uintptr) {
// Null pointer.
num := uint64(p)
if num == 0 {
w.Write(nilAngleBytes)
return
}
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
buf := make([]byte, 18)
// It's simpler to construct the hex string right to left.
base := uint64(16)
i := len(buf) - 1
for num >= base {
buf[i] = hexDigits[num%base]
num /= base
i--
}
buf[i] = hexDigits[num]
// Add '0x' prefix.
i--
buf[i] = 'x'
i--
buf[i] = '0'
// Strip unused leading bytes.
buf = buf[i:]
w.Write(buf)
}
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
// elements to be sorted.
type valuesSorter struct {
values []reflect.Value
strings []string // either nil or same len and values
cs *ConfigState
}
// newValuesSorter initializes a valuesSorter instance, which holds a set of
// surrogate keys on which the data should be sorted. It uses flags in
// ConfigState to decide if and how to populate those surrogate keys.
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
vs := &valuesSorter{values: values, cs: cs}
if canSortSimply(vs.values[0].Kind()) {
return vs
}
if !cs.DisableMethods {
vs.strings = make([]string, len(values))
for i := range vs.values {
b := bytes.Buffer{}
if !handleMethods(cs, &b, vs.values[i]) {
vs.strings = nil
break
}
vs.strings[i] = b.String()
}
}
if vs.strings == nil && cs.SpewKeys {
vs.strings = make([]string, len(values))
for i := range vs.values {
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
}
}
return vs
}
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
// directly, or whether it should be considered for sorting by surrogate keys
// (if the ConfigState allows it).
func canSortSimply(kind reflect.Kind) bool {
// This switch parallels valueSortLess, except for the default case.
switch kind {
case reflect.Bool:
return true
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return true
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.String:
return true
case reflect.Uintptr:
return true
case reflect.Array:
return true
}
return false
}
// Len returns the number of values in the slice. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Len() int {
return len(s.values)
}
// Swap swaps the values at the passed indices. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Swap(i, j int) {
s.values[i], s.values[j] = s.values[j], s.values[i]
if s.strings != nil {
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
}
}
// valueSortLess returns whether the first value should sort before the second
// value. It is used by valueSorter.Less as part of the sort.Interface
// implementation.
func valueSortLess(a, b reflect.Value) bool {
switch a.Kind() {
case reflect.Bool:
return !a.Bool() && b.Bool()
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return a.Int() < b.Int()
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return a.Uint() < b.Uint()
case reflect.Float32, reflect.Float64:
return a.Float() < b.Float()
case reflect.String:
return a.String() < b.String()
case reflect.Uintptr:
return a.Uint() < b.Uint()
case reflect.Array:
// Compare the contents of both arrays.
l := a.Len()
for i := 0; i < l; i++ {
av := a.Index(i)
bv := b.Index(i)
if av.Interface() == bv.Interface() {
continue
}
return valueSortLess(av, bv)
}
}
return a.String() < b.String()
}
// Less returns whether the value at index i should sort before the
// value at index j. It is part of the sort.Interface implementation.
func (s *valuesSorter) Less(i, j int) bool {
if s.strings == nil {
return valueSortLess(s.values[i], s.values[j])
}
return s.strings[i] < s.strings[j]
}
// sortValues is a sort function that handles both native types and any type that
// can be converted to error or Stringer. Other inputs are sorted according to
// their Value.String() value to ensure display stability.
func sortValues(values []reflect.Value, cs *ConfigState) {
if len(values) == 0 {
return
}
sort.Sort(newValuesSorter(values, cs))
}

Some files were not shown because too many files have changed in this diff Show More