Upgrading server dependency. (#8807)

This commit is contained in:
Christopher Speller
2018-05-18 07:32:31 -07:00
committed by GitHub
parent d3ead7dc85
commit d5e1f7e298
178 changed files with 20244 additions and 8861 deletions

View File

@@ -1,6 +1,7 @@
package imaging
import (
"bytes"
"errors"
"image"
"image/color"
@@ -46,6 +47,26 @@ func (f Format) String() string {
}
}
var formatFromExt = map[string]Format{
".jpg": JPEG,
".jpeg": JPEG,
".png": PNG,
".tif": TIFF,
".tiff": TIFF,
".bmp": BMP,
".gif": GIF,
}
// FormatFromFilename parses image format from filename extension:
// "jpg" (or "jpeg"), "png", "gif", "tif" (or "tiff") and "bmp" are supported.
func FormatFromFilename(filename string) (Format, error) {
ext := strings.ToLower(filepath.Ext(filename))
if f, ok := formatFromExt[ext]; ok {
return f, nil
}
return -1, ErrUnsupportedFormat
}
var (
// ErrUnsupportedFormat means the given image format (or file extension) is unsupported.
ErrUnsupportedFormat = errors.New("imaging: unsupported image format")
@@ -199,22 +220,10 @@ func Encode(w io.Writer, img image.Image, format Format, opts ...EncodeOption) e
// err := imaging.Save(img, "out.jpg", imaging.JPEGQuality(80))
//
func Save(img image.Image, filename string, opts ...EncodeOption) (err error) {
formats := map[string]Format{
".jpg": JPEG,
".jpeg": JPEG,
".png": PNG,
".tif": TIFF,
".tiff": TIFF,
".bmp": BMP,
".gif": GIF,
f, err := FormatFromFilename(filename)
if err != nil {
return err
}
ext := strings.ToLower(filepath.Ext(filename))
f, ok := formats[ext]
if !ok {
return ErrUnsupportedFormat
}
file, err := fs.Create(filename)
if err != nil {
return err
@@ -236,33 +245,16 @@ func New(width, height int, fillColor color.Color) *image.NRGBA {
return &image.NRGBA{}
}
dst := image.NewNRGBA(image.Rect(0, 0, width, height))
c := color.NRGBAModel.Convert(fillColor).(color.NRGBA)
if c.R == 0 && c.G == 0 && c.B == 0 && c.A == 0 {
return dst
if (c == color.NRGBA{0, 0, 0, 0}) {
return image.NewNRGBA(image.Rect(0, 0, width, height))
}
// Fill the first row.
i := 0
for x := 0; x < width; x++ {
dst.Pix[i+0] = c.R
dst.Pix[i+1] = c.G
dst.Pix[i+2] = c.B
dst.Pix[i+3] = c.A
i += 4
return &image.NRGBA{
Pix: bytes.Repeat([]byte{c.R, c.G, c.B, c.A}, width*height),
Stride: 4 * width,
Rect: image.Rect(0, 0, width, height),
}
// Copy the first row to other rows.
size := width * 4
parallel(1, height, func(ys <-chan int) {
for y := range ys {
i = y * dst.Stride
copy(dst.Pix[i:i+size], dst.Pix[0:size])
}
})
return dst
}
// Clone returns a copy of the given image.

View File

@@ -76,6 +76,14 @@ func Commaf(v float64) string {
return buf.String()
}
// CommafWithDigits works like the Commaf but limits the resulting
// string to the given number of decimal places.
//
// e.g. CommafWithDigits(834142.32, 1) -> 834,142.3
func CommafWithDigits(f float64, decimals int) string {
return stripTrailingDigits(Commaf(f), decimals)
}
// BigComma produces a string form of the given big.Int in base 10
// with commas after every three orders of magnitude.
func BigComma(b *big.Int) string {

View File

@@ -1,6 +1,9 @@
package humanize
import "strconv"
import (
"strconv"
"strings"
)
func stripTrailingZeros(s string) string {
offset := len(s) - 1
@@ -17,7 +20,27 @@ func stripTrailingZeros(s string) string {
return s[:offset+1]
}
func stripTrailingDigits(s string, digits int) string {
if i := strings.Index(s, "."); i >= 0 {
if digits <= 0 {
return s[:i]
}
i++
if i+digits >= len(s) {
return s
}
return s[:i+digits]
}
return s
}
// Ftoa converts a float to a string with no trailing zeros.
func Ftoa(num float64) string {
return stripTrailingZeros(strconv.FormatFloat(num, 'f', 6, 64))
}
// FtoaWithDigits converts a float to a string but limits the resulting string
// to the given number of decimal places, and no trailing zeros.
func FtoaWithDigits(num float64, digits int) string {
return stripTrailingZeros(stripTrailingDigits(strconv.FormatFloat(num, 'f', 6, 64), digits))
}

View File

@@ -93,6 +93,16 @@ func SI(input float64, unit string) string {
return Ftoa(value) + " " + prefix + unit
}
// SIWithDigits works like SI but limits the resulting string to the
// given number of decimal places.
//
// e.g. SIWithDigits(1000000, 0, "B") -> 1 MB
// e.g. SIWithDigits(2.2345e-12, 2, "F") -> 2.23 pF
func SIWithDigits(input float64, decimals int, unit string) string {
value, prefix := ComputeSI(input)
return FtoaWithDigits(value, decimals) + " " + prefix + unit
}
var errInvalid = errors.New("invalid input")
// ParseSI parses an SI string back into the number and unit.

View File

@@ -17,6 +17,7 @@ type Image struct {
Type string `json:"type"`
Width uint64 `json:"width"`
Height uint64 `json:"height"`
draft bool `json:"-"`
}
// Video defines Open Graph Video type
@@ -26,6 +27,7 @@ type Video struct {
Type string `json:"type"`
Width uint64 `json:"width"`
Height uint64 `json:"height"`
draft bool `json:"-"`
}
// Audio defines Open Graph Audio Type
@@ -33,6 +35,7 @@ type Audio struct {
URL string `json:"url"`
SecureURL string `json:"secure_url"`
Type string `json:"type"`
draft bool `json:"-"`
}
// Article contain Open Graph Article structure
@@ -133,6 +136,27 @@ func (og *OpenGraph) ProcessHTML(buffer io.Reader) error {
}
}
func (og *OpenGraph) ensureHasVideo() {
if len(og.Videos) > 0 {
return
}
og.Videos = append(og.Videos, &Video{draft: true})
}
func (og *OpenGraph) ensureHasImage() {
if len(og.Images) > 0 {
return
}
og.Images = append(og.Images, &Image{draft: true})
}
func (og *OpenGraph) ensureHasAudio() {
if len(og.Audios) > 0 {
return
}
og.Audios = append(og.Audios, &Audio{draft: true})
}
// ProcessMeta processes meta attributes and adds them to Open Graph structure if they are suitable for that
func (og *OpenGraph) ProcessMeta(metaAttrs map[string]string) {
switch metaAttrs["property"] {
@@ -160,61 +184,74 @@ func (og *OpenGraph) ProcessMeta(metaAttrs map[string]string) {
og.Locale = metaAttrs["content"]
case "og:locale:alternate":
og.LocalesAlternate = append(og.LocalesAlternate, metaAttrs["content"])
case "og:audio":
if len(og.Audios)>0 && og.Audios[len(og.Audios)-1].draft {
og.Audios[len(og.Audios)-1].URL = metaAttrs["content"]
og.Audios[len(og.Audios)-1].draft = false
} else {
og.Audios = append(og.Audios, &Audio{URL: metaAttrs["content"]})
}
case "og:audio:secure_url":
og.ensureHasAudio()
og.Audios[len(og.Audios)-1].SecureURL = metaAttrs["content"]
case "og:audio:type":
og.ensureHasAudio()
og.Audios[len(og.Audios)-1].Type = metaAttrs["content"]
case "og:image":
og.Images = append(og.Images, &Image{URL: metaAttrs["content"]})
case "og:image:url":
if len(og.Images) > 0 {
if len(og.Images)>0 && og.Images[len(og.Images)-1].draft {
og.Images[len(og.Images)-1].URL = metaAttrs["content"]
og.Images[len(og.Images)-1].draft = false
} else {
og.Images = append(og.Images, &Image{URL: metaAttrs["content"]})
}
case "og:image:url":
og.ensureHasImage()
og.Images[len(og.Images)-1].URL = metaAttrs["content"]
case "og:image:secure_url":
if len(og.Images) > 0 {
og.Images[len(og.Images)-1].SecureURL = metaAttrs["content"]
}
og.ensureHasImage()
og.Images[len(og.Images)-1].SecureURL = metaAttrs["content"]
case "og:image:type":
if len(og.Images) > 0 {
og.Images[len(og.Images)-1].Type = metaAttrs["content"]
}
og.ensureHasImage()
og.Images[len(og.Images)-1].Type = metaAttrs["content"]
case "og:image:width":
if len(og.Images) > 0 {
w, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.Images[len(og.Images)-1].Width = w
}
w, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.ensureHasImage()
og.Images[len(og.Images)-1].Width = w
}
case "og:image:height":
if len(og.Images) > 0 {
h, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.Images[len(og.Images)-1].Height = h
}
h, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.ensureHasImage()
og.Images[len(og.Images)-1].Height = h
}
case "og:video":
og.Videos = append(og.Videos, &Video{URL: metaAttrs["content"]})
case "og:video:url":
if len(og.Videos) > 0 {
if len(og.Videos)>0 && og.Videos[len(og.Videos)-1].draft {
og.Videos[len(og.Videos)-1].URL = metaAttrs["content"]
og.Videos[len(og.Videos)-1].draft = false
} else {
og.Videos = append(og.Videos, &Video{URL: metaAttrs["content"]})
}
case "og:video:url":
og.ensureHasVideo()
og.Videos[len(og.Videos)-1].URL = metaAttrs["content"]
case "og:video:secure_url":
if len(og.Videos) > 0 {
og.Videos[len(og.Videos)-1].SecureURL = metaAttrs["content"]
}
og.ensureHasVideo()
og.Videos[len(og.Videos)-1].SecureURL = metaAttrs["content"]
case "og:video:type":
if len(og.Videos) > 0 {
og.Videos[len(og.Videos)-1].Type = metaAttrs["content"]
}
og.ensureHasVideo()
og.Videos[len(og.Videos)-1].Type = metaAttrs["content"]
case "og:video:width":
if len(og.Videos) > 0 {
w, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.Videos[len(og.Videos)-1].Width = w
}
w, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.ensureHasVideo()
og.Videos[len(og.Videos)-1].Width = w
}
case "og:video:height":
if len(og.Videos) > 0 {
h, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.Videos[len(og.Videos)-1].Height = h
}
h, err := strconv.ParseUint(metaAttrs["content"], 10, 64)
if err == nil {
og.ensureHasVideo()
og.Videos[len(og.Videos)-1].Height = h
}
default:
if og.isArticle {

View File

@@ -32,7 +32,7 @@ const (
// Maximum allowed depth when recursively substituing variable names.
_DEPTH_VALUES = 99
_VERSION = "1.35.0"
_VERSION = "1.36.0"
)
// Version returns current package version literal.
@@ -140,6 +140,11 @@ type LoadOptions struct {
// AllowNestedValues indicates whether to allow AWS-like nested values.
// Docs: http://docs.aws.amazon.com/cli/latest/topic/config-vars.html#nested-values
AllowNestedValues bool
// AllowPythonMultilineValues indicates whether to allow Python-like multi-line values.
// Docs: https://docs.python.org/3/library/configparser.html#supported-ini-file-structure
// Relevant quote: Values can also span multiple lines, as long as they are indented deeper
// than the first line of the value.
AllowPythonMultilineValues bool
// UnescapeValueDoubleQuotes indicates whether to unescape double quotes inside value to regular format
// when value is surrounded by double quotes, e.g. key="a \"value\"" => key=a "value"
UnescapeValueDoubleQuotes bool

View File

@@ -19,11 +19,14 @@ import (
"bytes"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"unicode"
)
var pythonMultiline = regexp.MustCompile("^(\\s+)([^\n]+)")
type tokenType int
const (
@@ -194,7 +197,8 @@ func hasSurroundedQuote(in string, quote byte) bool {
}
func (p *parser) readValue(in []byte,
ignoreContinuation, ignoreInlineComment, unescapeValueDoubleQuotes, unescapeValueCommentSymbols bool) (string, error) {
parserBufferSize int,
ignoreContinuation, ignoreInlineComment, unescapeValueDoubleQuotes, unescapeValueCommentSymbols, allowPythonMultilines bool) (string, error) {
line := strings.TrimLeftFunc(string(in), unicode.IsSpace)
if len(line) == 0 {
@@ -224,11 +228,13 @@ func (p *parser) readValue(in []byte,
return line[startIdx : pos+startIdx], nil
}
lastChar := line[len(line)-1]
// Won't be able to reach here if value only contains whitespace
line = strings.TrimSpace(line)
trimmedLastChar := line[len(line)-1]
// Check continuation lines when desired
if !ignoreContinuation && line[len(line)-1] == '\\' {
if !ignoreContinuation && trimmedLastChar == '\\' {
return p.readContinuationLines(line[:len(line)-1])
}
@@ -252,7 +258,50 @@ func (p *parser) readValue(in []byte,
if strings.Contains(line, `\#`) {
line = strings.Replace(line, `\#`, "#", -1)
}
} else if allowPythonMultilines && lastChar == '\n' {
parserBufferPeekResult, _ := p.buf.Peek(parserBufferSize)
peekBuffer := bytes.NewBuffer(parserBufferPeekResult)
identSize := -1
val := line
for {
peekData, peekErr := peekBuffer.ReadBytes('\n')
if peekErr != nil {
if peekErr == io.EOF {
return val, nil
}
return "", peekErr
}
peekMatches := pythonMultiline.FindStringSubmatch(string(peekData))
if len(peekMatches) != 3 {
return val, nil
}
currentIdentSize := len(peekMatches[1])
// NOTE: Return if not a python-ini multi-line value.
if currentIdentSize < 0 {
return val, nil
}
identSize = currentIdentSize
// NOTE: Just advance the parser reader (buffer) in-sync with the peek buffer.
_, err := p.readUntil('\n')
if err != nil {
return "", err
}
val += fmt.Sprintf("\n%s", peekMatches[2])
}
// NOTE: If it was a Python multi-line value,
// return the appended value.
if identSize > 0 {
return val, nil
}
}
return line, nil
}
@@ -276,6 +325,29 @@ func (f *File) parse(reader io.Reader) (err error) {
var line []byte
var inUnparseableSection bool
// NOTE: Iterate and increase `currentPeekSize` until
// the size of the parser buffer is found.
// TODO: When Golang 1.10 is the lowest version supported,
// replace with `parserBufferSize := p.buf.Size()`.
parserBufferSize := 0
// NOTE: Peek 1kb at a time.
currentPeekSize := 1024
if f.options.AllowPythonMultilineValues {
for {
peekBytes, _ := p.buf.Peek(currentPeekSize)
peekBytesLength := len(peekBytes)
if parserBufferSize >= peekBytesLength {
break
}
currentPeekSize *= 2
parserBufferSize = peekBytesLength
}
}
for !p.isEOF {
line, err = p.readUntil('\n')
if err != nil {
@@ -352,10 +424,12 @@ func (f *File) parse(reader io.Reader) (err error) {
// Treat as boolean key when desired, and whole line is key name.
if IsErrDelimiterNotFound(err) && f.options.AllowBooleanKeys {
kname, err := p.readValue(line,
parserBufferSize,
f.options.IgnoreContinuation,
f.options.IgnoreInlineComment,
f.options.UnescapeValueDoubleQuotes,
f.options.UnescapeValueCommentSymbols)
f.options.UnescapeValueCommentSymbols,
f.options.AllowPythonMultilineValues)
if err != nil {
return err
}
@@ -379,10 +453,12 @@ func (f *File) parse(reader io.Reader) (err error) {
}
value, err := p.readValue(line[offset:],
parserBufferSize,
f.options.IgnoreContinuation,
f.options.IgnoreInlineComment,
f.options.UnescapeValueDoubleQuotes,
f.options.UnescapeValueCommentSymbols)
f.options.UnescapeValueCommentSymbols,
f.options.AllowPythonMultilineValues)
if err != nil {
return err
}

View File

@@ -1,43 +0,0 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
install:
go install
test: install generate-test-pbs
go test
generate-test-pbs:
make install
make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto
make

View File

@@ -35,22 +35,39 @@
package proto
import (
"fmt"
"log"
"reflect"
"strings"
)
// Clone returns a deep copy of a protocol buffer.
func Clone(pb Message) Message {
in := reflect.ValueOf(pb)
func Clone(src Message) Message {
in := reflect.ValueOf(src)
if in.IsNil() {
return pb
return src
}
out := reflect.New(in.Type().Elem())
// out is empty so a merge is a deep copy.
mergeStruct(out.Elem(), in.Elem())
return out.Interface().(Message)
dst := out.Interface().(Message)
Merge(dst, src)
return dst
}
// Merger is the interface representing objects that can merge messages of the same type.
type Merger interface {
// Merge merges src into this message.
// Required and optional fields that are set in src will be set to that value in dst.
// Elements of repeated fields will be appended.
//
// Merge may panic if called with a different argument type than the receiver.
Merge(src Message)
}
// generatedMerger is the custom merge method that generated protos will have.
// We must add this method since a generate Merge method will conflict with
// many existing protos that have a Merge data field already defined.
type generatedMerger interface {
XXX_Merge(src Message)
}
// Merge merges src into dst.
@@ -58,17 +75,24 @@ func Clone(pb Message) Message {
// Elements of repeated fields will be appended.
// Merge panics if src and dst are not the same type, or if dst is nil.
func Merge(dst, src Message) {
if m, ok := dst.(Merger); ok {
m.Merge(src)
return
}
in := reflect.ValueOf(src)
out := reflect.ValueOf(dst)
if out.IsNil() {
panic("proto: nil destination")
}
if in.Type() != out.Type() {
// Explicit test prior to mergeStruct so that mistyped nils will fail
panic("proto: type mismatch")
panic(fmt.Sprintf("proto.Merge(%T, %T) type mismatch", dst, src))
}
if in.IsNil() {
// Merging nil into non-nil is a quiet no-op
return // Merge from nil src is a noop
}
if m, ok := dst.(generatedMerger); ok {
m.XXX_Merge(src)
return
}
mergeStruct(out.Elem(), in.Elem())
@@ -84,7 +108,7 @@ func mergeStruct(out, in reflect.Value) {
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
}
if emIn, ok := extendable(in.Addr().Interface()); ok {
if emIn, err := extendable(in.Addr().Interface()); err == nil {
emOut, _ := extendable(out.Addr().Interface())
mIn, muIn := emIn.extensionsRead()
if mIn != nil {

View File

@@ -39,8 +39,6 @@ import (
"errors"
"fmt"
"io"
"os"
"reflect"
)
// errOverflow is returned when an integer is too large to be represented.
@@ -50,10 +48,6 @@ var errOverflow = errors.New("proto: integer overflow")
// wire type is encountered. It does not get returned to user code.
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
// DecodeVarint reads a varint-encoded integer from the slice.
// It returns the integer and the number of bytes consumed, or
// zero if there is not enough.
@@ -267,9 +261,6 @@ func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
return
}
// These are not ValueDecoders: they produce an array of bytes or a string.
// bytes, embedded messages
// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
// This is the format used for the bytes protocol buffer
// type and for embedded messages.
@@ -311,81 +302,29 @@ func (p *Buffer) DecodeStringBytes() (s string, err error) {
return string(buf), nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
// If the protocol buffer has extensions, and the field matches, add it as an extension.
// Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
oi := o.index
err := o.skip(t, tag, wire)
if err != nil {
return err
}
if !unrecField.IsValid() {
return nil
}
ptr := structPointer_Bytes(base, unrecField)
// Add the skipped field to struct field
obuf := o.buf
o.buf = *ptr
o.EncodeVarint(uint64(tag<<3 | wire))
*ptr = append(o.buf, obuf[oi:o.index]...)
o.buf = obuf
return nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
var u uint64
var err error
switch wire {
case WireVarint:
_, err = o.DecodeVarint()
case WireFixed64:
_, err = o.DecodeFixed64()
case WireBytes:
_, err = o.DecodeRawBytes(false)
case WireFixed32:
_, err = o.DecodeFixed32()
case WireStartGroup:
for {
u, err = o.DecodeVarint()
if err != nil {
break
}
fwire := int(u & 0x7)
if fwire == WireEndGroup {
break
}
ftag := int(u >> 3)
err = o.skip(t, ftag, fwire)
if err != nil {
break
}
}
default:
err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
}
return err
}
// Unmarshaler is the interface representing objects that can
// unmarshal themselves. The method should reset the receiver before
// decoding starts. The argument points to data that may be
// unmarshal themselves. The argument points to data that may be
// overwritten, so implementations should not keep references to the
// buffer.
// Unmarshal implementations should not clear the receiver.
// Any unmarshaled data should be merged into the receiver.
// Callers of Unmarshal that do not want to retain existing data
// should Reset the receiver before calling Unmarshal.
type Unmarshaler interface {
Unmarshal([]byte) error
}
// newUnmarshaler is the interface representing objects that can
// unmarshal themselves. The semantics are identical to Unmarshaler.
//
// This exists to support protoc-gen-go generated messages.
// The proto package will stop type-asserting to this interface in the future.
//
// DO NOT DEPEND ON THIS.
type newUnmarshaler interface {
XXX_Unmarshal([]byte) error
}
// Unmarshal parses the protocol buffer representation in buf and places the
// decoded result in pb. If the struct underlying pb does not match
// the data in buf, the results can be unpredictable.
@@ -395,7 +334,13 @@ type Unmarshaler interface {
// to preserve and append to existing data.
func Unmarshal(buf []byte, pb Message) error {
pb.Reset()
return UnmarshalMerge(buf, pb)
if u, ok := pb.(newUnmarshaler); ok {
return u.XXX_Unmarshal(buf)
}
if u, ok := pb.(Unmarshaler); ok {
return u.Unmarshal(buf)
}
return NewBuffer(buf).Unmarshal(pb)
}
// UnmarshalMerge parses the protocol buffer representation in buf and
@@ -405,8 +350,16 @@ func Unmarshal(buf []byte, pb Message) error {
// UnmarshalMerge merges into existing data in pb.
// Most code should use Unmarshal instead.
func UnmarshalMerge(buf []byte, pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(newUnmarshaler); ok {
return u.XXX_Unmarshal(buf)
}
if u, ok := pb.(Unmarshaler); ok {
// NOTE: The history of proto have unfortunately been inconsistent
// whether Unmarshaler should or should not implicitly clear itself.
// Some implementations do, most do not.
// Thus, calling this here may or may not do what people want.
//
// See https://github.com/golang/protobuf/issues/424
return u.Unmarshal(buf)
}
return NewBuffer(buf).Unmarshal(pb)
@@ -422,12 +375,17 @@ func (p *Buffer) DecodeMessage(pb Message) error {
}
// DecodeGroup reads a tag-delimited group from the Buffer.
// StartGroup tag is already consumed. This function consumes
// EndGroup tag.
func (p *Buffer) DecodeGroup(pb Message) error {
typ, base, err := getbase(pb)
if err != nil {
return err
b := p.buf[p.index:]
x, y := findEndGroup(b)
if x < 0 {
return io.ErrUnexpectedEOF
}
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
err := Unmarshal(b[:x], pb)
p.index += y
return err
}
// Unmarshal parses the protocol buffer representation in the
@@ -438,533 +396,33 @@ func (p *Buffer) DecodeGroup(pb Message) error {
// Unlike proto.Unmarshal, this does not reset pb before starting to unmarshal.
func (p *Buffer) Unmarshal(pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(newUnmarshaler); ok {
err := u.XXX_Unmarshal(p.buf[p.index:])
p.index = len(p.buf)
return err
}
if u, ok := pb.(Unmarshaler); ok {
// NOTE: The history of proto have unfortunately been inconsistent
// whether Unmarshaler should or should not implicitly clear itself.
// Some implementations do, most do not.
// Thus, calling this here may or may not do what people want.
//
// See https://github.com/golang/protobuf/issues/424
err := u.Unmarshal(p.buf[p.index:])
p.index = len(p.buf)
return err
}
typ, base, err := getbase(pb)
if err != nil {
return err
}
err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
if collectStats {
stats.Decode++
}
return err
}
// unmarshalType does the work of unmarshaling a structure.
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
var state errorState
required, reqFields := prop.reqCount, uint64(0)
var err error
for err == nil && o.index < len(o.buf) {
oi := o.index
var u uint64
u, err = o.DecodeVarint()
if err != nil {
break
}
wire := int(u & 0x7)
if wire == WireEndGroup {
if is_group {
if required > 0 {
// Not enough information to determine the exact field.
// (See below.)
return &RequiredNotSetError{"{Unknown}"}
}
return nil // input is satisfied
}
return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
}
tag := int(u >> 3)
if tag <= 0 {
return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
}
fieldnum, ok := prop.decoderTags.get(tag)
if !ok {
// Maybe it's an extension?
if prop.extendable {
if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
extmap := e.extensionsWrite()
ext := extmap[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
extmap[int32(tag)] = ext
}
continue
}
}
// Maybe it's a oneof?
if prop.oneofUnmarshaler != nil {
m := structPointer_Interface(base, st).(Message)
// First return value indicates whether tag is a oneof field.
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
if err == ErrInternalBadWireType {
// Map the error to something more descriptive.
// Do the formatting here to save generated code space.
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
}
if ok {
continue
}
}
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
p := prop.Prop[fieldnum]
if p.dec == nil {
fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
continue
}
dec := p.dec
if wire != WireStartGroup && wire != p.WireType {
if wire == WireBytes && p.packedDec != nil {
// a packable field
dec = p.packedDec
} else {
err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType)
continue
}
}
decErr := dec(o, p, base)
if decErr != nil && !state.shouldContinue(decErr, p) {
err = decErr
}
if err == nil && p.Required {
// Successfully decoded a required field.
if tag <= 64 {
// use bitmap for fields 1-64 to catch field reuse.
var mask uint64 = 1 << uint64(tag-1)
if reqFields&mask == 0 {
// new required field
reqFields |= mask
required--
}
} else {
// This is imprecise. It can be fooled by a required field
// with a tag > 64 that is encoded twice; that's very rare.
// A fully correct implementation would require allocating
// a data structure, which we would like to avoid.
required--
}
}
}
if err == nil {
if is_group {
return io.ErrUnexpectedEOF
}
if state.err != nil {
return state.err
}
if required > 0 {
// Not enough information to determine the exact field. If we use extra
// CPU, we could determine the field only if the missing required field
// has a tag <= 64 and we check reqFields.
return &RequiredNotSetError{"{Unknown}"}
}
}
return err
}
// Individual type decoders
// For each,
// u is the decoded value,
// v is a pointer to the field (pointer) in the struct
// Sizes of the pools to allocate inside the Buffer.
// The goal is modest amortization and allocation
// on at least 16-byte boundaries.
const (
boolPoolSize = 16
uint32PoolSize = 8
uint64PoolSize = 4
)
// Decode a bool.
func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
if len(o.bools) == 0 {
o.bools = make([]bool, boolPoolSize)
}
o.bools[0] = u != 0
*structPointer_Bool(base, p.field) = &o.bools[0]
o.bools = o.bools[1:]
return nil
}
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
*structPointer_BoolVal(base, p.field) = u != 0
return nil
}
// Decode an int32.
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32_Set(structPointer_Word32(base, p.field), o, uint32(u))
return nil
}
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
return nil
}
// Decode an int64.
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64_Set(structPointer_Word64(base, p.field), o, u)
return nil
}
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
return nil
}
// Decode a string.
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_String(base, p.field) = &s
return nil
}
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_StringVal(base, p.field) = s
return nil
}
// Decode a slice of bytes ([]byte).
func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
*structPointer_Bytes(base, p.field) = b
return nil
}
// Decode a slice of bools ([]bool).
func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
v := structPointer_BoolSlice(base, p.field)
*v = append(*v, u != 0)
return nil
}
// Decode a slice of bools ([]bool) in packed format.
func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error {
v := structPointer_BoolSlice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded bools
fin := o.index + nb
if fin < o.index {
return errOverflow
}
y := *v
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
y = append(y, u != 0)
}
*v = y
return nil
}
// Decode a slice of int32s ([]int32).
func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word32Slice(base, p.field).Append(uint32(u))
return nil
}
// Decode a slice of int32s ([]int32) in packed format.
func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error {
v := structPointer_Word32Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int32s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(uint32(u))
}
return nil
}
// Decode a slice of int64s ([]int64).
func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word64Slice(base, p.field).Append(u)
return nil
}
// Decode a slice of int64s ([]int64) in packed format.
func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error {
v := structPointer_Word64Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int64s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(u)
}
return nil
}
// Decode a slice of strings ([]string).
func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
v := structPointer_StringSlice(base, p.field)
*v = append(*v, s)
return nil
}
// Decode a slice of slice of bytes ([][]byte).
func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
v := structPointer_BytesSlice(base, p.field)
*v = append(*v, b)
return nil
}
// Decode a map field.
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
oi := o.index // index at the end of this map entry
o.index -= len(raw) // move buffer back to start of map entry
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
if mptr.Elem().IsNil() {
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
}
v := mptr.Elem() // map[K]V
// Prepare addressable doubly-indirect placeholders for the key and value types.
// See enc_new_map for why.
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
keybase := toStructPointer(keyptr.Addr()) // **K
var valbase structPointer
var valptr reflect.Value
switch p.mtype.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valptr = reflect.ValueOf(&dummy) // *[]byte
valbase = toStructPointer(valptr) // *[]byte
case reflect.Ptr:
// message; valptr is **Msg; need to allocate the intermediate pointer
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valptr.Set(reflect.New(valptr.Type().Elem()))
valbase = toStructPointer(valptr)
default:
// everything else
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valbase = toStructPointer(valptr.Addr()) // **V
}
// Decode.
// This parses a restricted wire format, namely the encoding of a message
// with two fields. See enc_new_map for the format.
for o.index < oi {
// tagcode for key and value properties are always a single byte
// because they have tags 1 and 2.
tagcode := o.buf[o.index]
o.index++
switch tagcode {
case p.mkeyprop.tagcode[0]:
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
return err
}
case p.mvalprop.tagcode[0]:
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
return err
}
default:
// TODO: Should we silently skip this instead?
return fmt.Errorf("proto: bad map data tag %d", raw[0])
}
}
keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() {
keyelem = reflect.Zero(p.mtype.Key())
}
if !valelem.IsValid() {
valelem = reflect.Zero(p.mtype.Elem())
}
v.SetMapIndex(keyelem, valelem)
return nil
}
// Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
return o.unmarshalType(p.stype, p.sprop, true, bas)
}
// Decode an embedded message.
func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) {
raw, e := o.DecodeRawBytes(false)
if e != nil {
return e
}
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := structPointer_Interface(bas, p.stype)
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, false, bas)
o.buf = obuf
o.index = oi
return err
}
// Decode a slice of embedded messages.
func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, false, base)
}
// Decode a slice of embedded groups.
func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, true, base)
}
// Decode a slice of structs ([]*struct).
func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error {
v := reflect.New(p.stype)
bas := toStructPointer(v)
structPointer_StructPointerSlice(base, p.field).Append(bas)
if is_group {
err := o.unmarshalType(p.stype, p.sprop, is_group, bas)
return err
}
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := v.Interface()
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, is_group, bas)
o.buf = obuf
o.index = oi
// Slow workaround for messages that aren't Unmarshalers.
// This includes some hand-coded .pb.go files and
// bootstrap protos.
// TODO: fix all of those and then add Unmarshal to
// the Message interface. Then:
// The cast above and code below can be deleted.
// The old unmarshaler can be deleted.
// Clients can call Unmarshal directly (can already do that, actually).
var info InternalMessageInfo
err := info.Unmarshal(pb, p.buf[p.index:])
p.index = len(p.buf)
return err
}

View File

@@ -35,8 +35,14 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
)
type generatedDiscarder interface {
XXX_DiscardUnknown()
}
// DiscardUnknown recursively discards all unknown fields from this message
// and all embedded messages.
//
@@ -49,9 +55,202 @@ import (
// For proto2 messages, the unknown fields of message extensions are only
// discarded from messages that have been accessed via GetExtension.
func DiscardUnknown(m Message) {
if m, ok := m.(generatedDiscarder); ok {
m.XXX_DiscardUnknown()
return
}
// TODO: Dynamically populate a InternalMessageInfo for legacy messages,
// but the master branch has no implementation for InternalMessageInfo,
// so it would be more work to replicate that approach.
discardLegacy(m)
}
// DiscardUnknown recursively discards all unknown fields.
func (a *InternalMessageInfo) DiscardUnknown(m Message) {
di := atomicLoadDiscardInfo(&a.discard)
if di == nil {
di = getDiscardInfo(reflect.TypeOf(m).Elem())
atomicStoreDiscardInfo(&a.discard, di)
}
di.discard(toPointer(&m))
}
type discardInfo struct {
typ reflect.Type
initialized int32 // 0: only typ is valid, 1: everything is valid
lock sync.Mutex
fields []discardFieldInfo
unrecognized field
}
type discardFieldInfo struct {
field field // Offset of field, guaranteed to be valid
discard func(src pointer)
}
var (
discardInfoMap = map[reflect.Type]*discardInfo{}
discardInfoLock sync.Mutex
)
func getDiscardInfo(t reflect.Type) *discardInfo {
discardInfoLock.Lock()
defer discardInfoLock.Unlock()
di := discardInfoMap[t]
if di == nil {
di = &discardInfo{typ: t}
discardInfoMap[t] = di
}
return di
}
func (di *discardInfo) discard(src pointer) {
if src.isNil() {
return // Nothing to do.
}
if atomic.LoadInt32(&di.initialized) == 0 {
di.computeDiscardInfo()
}
for _, fi := range di.fields {
sfp := src.offset(fi.field)
fi.discard(sfp)
}
// For proto2 messages, only discard unknown fields in message extensions
// that have been accessed via GetExtension.
if em, err := extendable(src.asPointerTo(di.typ).Interface()); err == nil {
// Ignore lock since DiscardUnknown is not concurrency safe.
emm, _ := em.extensionsRead()
for _, mx := range emm {
if m, ok := mx.value.(Message); ok {
DiscardUnknown(m)
}
}
}
if di.unrecognized.IsValid() {
*src.offset(di.unrecognized).toBytes() = nil
}
}
func (di *discardInfo) computeDiscardInfo() {
di.lock.Lock()
defer di.lock.Unlock()
if di.initialized != 0 {
return
}
t := di.typ
n := t.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
dfi := discardFieldInfo{field: toField(&f)}
tf := f.Type
// Unwrap tf to get its most basic type.
var isPointer, isSlice bool
if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
isSlice = true
tf = tf.Elem()
}
if tf.Kind() == reflect.Ptr {
isPointer = true
tf = tf.Elem()
}
if isPointer && isSlice && tf.Kind() != reflect.Struct {
panic(fmt.Sprintf("%v.%s cannot be a slice of pointers to primitive types", t, f.Name))
}
switch tf.Kind() {
case reflect.Struct:
switch {
case !isPointer:
panic(fmt.Sprintf("%v.%s cannot be a direct struct value", t, f.Name))
case isSlice: // E.g., []*pb.T
di := getDiscardInfo(tf)
dfi.discard = func(src pointer) {
sps := src.getPointerSlice()
for _, sp := range sps {
if !sp.isNil() {
di.discard(sp)
}
}
}
default: // E.g., *pb.T
di := getDiscardInfo(tf)
dfi.discard = func(src pointer) {
sp := src.getPointer()
if !sp.isNil() {
di.discard(sp)
}
}
}
case reflect.Map:
switch {
case isPointer || isSlice:
panic(fmt.Sprintf("%v.%s cannot be a pointer to a map or a slice of map values", t, f.Name))
default: // E.g., map[K]V
if tf.Elem().Kind() == reflect.Ptr { // Proto struct (e.g., *T)
dfi.discard = func(src pointer) {
sm := src.asPointerTo(tf).Elem()
if sm.Len() == 0 {
return
}
for _, key := range sm.MapKeys() {
val := sm.MapIndex(key)
DiscardUnknown(val.Interface().(Message))
}
}
} else {
dfi.discard = func(pointer) {} // Noop
}
}
case reflect.Interface:
// Must be oneof field.
switch {
case isPointer || isSlice:
panic(fmt.Sprintf("%v.%s cannot be a pointer to a interface or a slice of interface values", t, f.Name))
default: // E.g., interface{}
// TODO: Make this faster?
dfi.discard = func(src pointer) {
su := src.asPointerTo(tf).Elem()
if !su.IsNil() {
sv := su.Elem().Elem().Field(0)
if sv.Kind() == reflect.Ptr && sv.IsNil() {
return
}
switch sv.Type().Kind() {
case reflect.Ptr: // Proto struct (e.g., *T)
DiscardUnknown(sv.Interface().(Message))
}
}
}
}
default:
continue
}
di.fields = append(di.fields, dfi)
}
di.unrecognized = invalidField
if f, ok := t.FieldByName("XXX_unrecognized"); ok {
if f.Type != reflect.TypeOf([]byte{}) {
panic("expected XXX_unrecognized to be of type []byte")
}
di.unrecognized = toField(&f)
}
atomic.StoreInt32(&di.initialized, 1)
}
func discardLegacy(m Message) {
v := reflect.ValueOf(m)
if v.Kind() != reflect.Ptr || v.IsNil() {
@@ -139,7 +338,7 @@ func discardLegacy(m Message) {
// For proto2 messages, only discard unknown fields in message extensions
// that have been accessed via GetExtension.
if em, ok := extendable(m); ok {
if em, err := extendable(m); err == nil {
// Ignore lock since discardLegacy is not concurrency safe.
emm, _ := em.extensionsRead()
for _, mx := range emm {

File diff suppressed because it is too large Load Diff

View File

@@ -109,15 +109,6 @@ func equalStruct(v1, v2 reflect.Value) bool {
// set/unset mismatch
return false
}
b1, ok := f1.Interface().(raw)
if ok {
b2 := f2.Interface().(raw)
// RawMessage
if !bytes.Equal(b1.Bytes(), b2.Bytes()) {
return false
}
continue
}
f1, f2 = f1.Elem(), f2.Elem()
}
if !equalAny(f1, f2, sprop.Prop[i]) {
@@ -146,11 +137,7 @@ func equalStruct(v1, v2 reflect.Value) bool {
u1 := uf.Bytes()
u2 := v2.FieldByName("XXX_unrecognized").Bytes()
if !bytes.Equal(u1, u2) {
return false
}
return true
return bytes.Equal(u1, u2)
}
// v1 and v2 are known to have the same type.
@@ -261,6 +248,15 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
m1, m2 := e1.value, e2.value
if m1 == nil && m2 == nil {
// Both have only encoded form.
if bytes.Equal(e1.enc, e2.enc) {
continue
}
// The bytes are different, but the extensions might still be
// equal. We need to decode them to compare.
}
if m1 != nil && m2 != nil {
// Both are unencoded.
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
@@ -276,8 +272,12 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
desc = m[extNum]
}
if desc == nil {
// If both have only encoded form and the bytes are the same,
// it is handled above. We get here when the bytes are different.
// We don't know how to decode it, so just compare them as byte
// slices.
log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
continue
return false
}
var err error
if m1 == nil {

View File

@@ -38,6 +38,7 @@ package proto
import (
"errors"
"fmt"
"io"
"reflect"
"strconv"
"sync"
@@ -91,14 +92,29 @@ func (n notLocker) Unlock() {}
// extendable returns the extendableProto interface for the given generated proto message.
// If the proto message has the old extension format, it returns a wrapper that implements
// the extendableProto interface.
func extendable(p interface{}) (extendableProto, bool) {
if ep, ok := p.(extendableProto); ok {
return ep, ok
func extendable(p interface{}) (extendableProto, error) {
switch p := p.(type) {
case extendableProto:
if isNilPtr(p) {
return nil, fmt.Errorf("proto: nil %T is not extendable", p)
}
return p, nil
case extendableProtoV1:
if isNilPtr(p) {
return nil, fmt.Errorf("proto: nil %T is not extendable", p)
}
return extensionAdapter{p}, nil
}
if ep, ok := p.(extendableProtoV1); ok {
return extensionAdapter{ep}, ok
}
return nil, false
// Don't allocate a specific error containing %T:
// this is the hot path for Clone and MarshalText.
return nil, errNotExtendable
}
var errNotExtendable = errors.New("proto: not an extendable proto.Message")
func isNilPtr(x interface{}) bool {
v := reflect.ValueOf(x)
return v.Kind() == reflect.Ptr && v.IsNil()
}
// XXX_InternalExtensions is an internal representation of proto extensions.
@@ -143,9 +159,6 @@ func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Loc
return e.p.extensionMap, &e.p.mu
}
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
type ExtensionDesc struct {
@@ -179,8 +192,8 @@ type Extension struct {
// SetRawExtension is for testing only.
func SetRawExtension(base Message, id int32, b []byte) {
epb, ok := extendable(base)
if !ok {
epb, err := extendable(base)
if err != nil {
return
}
extmap := epb.extensionsWrite()
@@ -205,7 +218,7 @@ func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
pbi = ea.extendableProtoV1
}
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
}
// Check the range.
if !isExtensionField(pb, extension.Field) {
@@ -250,85 +263,11 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
return prop
}
// encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensions(e *XXX_InternalExtensions) error {
m, mu := e.extensionsRead()
if m == nil {
return nil // fast path
}
mu.Lock()
defer mu.Unlock()
return encodeExtensionsMap(m)
}
// encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensionsMap(m map[int32]Extension) error {
for k, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
p := NewBuffer(nil)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
if err := props.enc(p, props, toStructPointer(x)); err != nil {
return err
}
e.enc = p.buf
m[k] = e
}
return nil
}
func extensionsSize(e *XXX_InternalExtensions) (n int) {
m, mu := e.extensionsRead()
if m == nil {
return 0
}
mu.Lock()
defer mu.Unlock()
return extensionsMapSize(m)
}
func extensionsMapSize(m map[int32]Extension) (n int) {
for _, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
n += len(e.enc)
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
n += props.size(props, toStructPointer(x))
}
return
}
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb Message, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
epb, ok := extendable(pb)
if !ok {
epb, err := extendable(pb)
if err != nil {
return false
}
extmap, mu := epb.extensionsRead()
@@ -336,15 +275,15 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
return false
}
mu.Lock()
_, ok = extmap[extension.Field]
_, ok := extmap[extension.Field]
mu.Unlock()
return ok
}
// ClearExtension removes the given extension from pb.
func ClearExtension(pb Message, extension *ExtensionDesc) {
epb, ok := extendable(pb)
if !ok {
epb, err := extendable(pb)
if err != nil {
return
}
// TODO: Check types, field numbers, etc.?
@@ -352,16 +291,26 @@ func ClearExtension(pb Message, extension *ExtensionDesc) {
delete(extmap, extension.Field)
}
// GetExtension parses and returns the given extension of pb.
// If the extension is not present and has no default value it returns ErrMissingExtension.
// GetExtension retrieves a proto2 extended field from pb.
//
// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
// then GetExtension parses the encoded field and returns a Go value of the specified type.
// If the field is not present, then the default value is returned (if one is specified),
// otherwise ErrMissingExtension is reported.
//
// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
// then GetExtension returns the raw encoded bytes of the field extension.
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
epb, ok := extendable(pb)
if !ok {
return nil, errors.New("proto: not an extendable proto")
epb, err := extendable(pb)
if err != nil {
return nil, err
}
if err := checkExtensionTypes(epb, extension); err != nil {
return nil, err
if extension.ExtendedType != nil {
// can only check type if this is a complete descriptor
if err := checkExtensionTypes(epb, extension); err != nil {
return nil, err
}
}
emap, mu := epb.extensionsRead()
@@ -388,6 +337,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
return e.value, nil
}
if extension.ExtensionType == nil {
// incomplete descriptor
return e.enc, nil
}
v, err := decodeExtension(e.enc, extension)
if err != nil {
return nil, err
@@ -405,6 +359,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
// defaultExtensionValue returns the default value for extension.
// If no default for an extension is defined ErrMissingExtension is returned.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
if extension.ExtensionType == nil {
// incomplete descriptor, so no default
return nil, ErrMissingExtension
}
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
@@ -439,31 +398,28 @@ func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b)
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
unmarshal := typeUnmarshaler(t, extension.Tag)
// t is a pointer to a struct, pointer to basic type or a slice.
// Allocate a "field" to store the pointer/slice itself; the
// pointer/slice will be stored here. We pass
// the address of this field to props.dec.
// This passes a zero field and a *t and lets props.dec
// interpret it as a *struct{ x t }.
// Allocate space to store the pointer/slice.
value := reflect.New(t).Elem()
var err error
for {
// Discard wire type and field number varint. It isn't needed.
if _, err := o.DecodeVarint(); err != nil {
x, n := decodeVarint(b)
if n == 0 {
return nil, io.ErrUnexpectedEOF
}
b = b[n:]
wire := int(x) & 7
b, err = unmarshal(b, valToPointer(value.Addr()), wire)
if err != nil {
return nil, err
}
if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
return nil, err
}
if o.index >= len(o.buf) {
if len(b) == 0 {
break
}
}
@@ -473,9 +429,9 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := extendable(pb)
if !ok {
return nil, errors.New("proto: not an extendable proto")
epb, err := extendable(pb)
if err != nil {
return nil, err
}
extensions = make([]interface{}, len(es))
for i, e := range es {
@@ -494,9 +450,9 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
epb, ok := extendable(pb)
if !ok {
return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
epb, err := extendable(pb)
if err != nil {
return nil, err
}
registeredExtensions := RegisteredExtensions(pb)
@@ -523,9 +479,9 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
epb, ok := extendable(pb)
if !ok {
return errors.New("proto: not an extendable proto")
epb, err := extendable(pb)
if err != nil {
return err
}
if err := checkExtensionTypes(epb, extension); err != nil {
return err
@@ -550,8 +506,8 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
// ClearAllExtensions clears all extensions from pb.
func ClearAllExtensions(pb Message) {
epb, ok := extendable(pb)
if !ok {
epb, err := extendable(pb)
if err != nil {
return
}
m := epb.extensionsWrite()

View File

@@ -265,6 +265,7 @@ package proto
import (
"encoding/json"
"errors"
"fmt"
"log"
"reflect"
@@ -273,6 +274,8 @@ import (
"sync"
)
var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string")
// Message is implemented by generated protocol buffer messages.
type Message interface {
Reset()
@@ -309,16 +312,7 @@ type Buffer struct {
buf []byte // encode/decode byte stream
index int // read point
// pools of basic types to amortize allocation.
bools []bool
uint32s []uint32
uint64s []uint64
// extra pools, only used with pointer_reflect.go
int32s []int32
int64s []int64
float32s []float32
float64s []float64
deterministic bool
}
// NewBuffer allocates a new Buffer and initializes its internal data to
@@ -343,6 +337,30 @@ func (p *Buffer) SetBuf(s []byte) {
// Bytes returns the contents of the Buffer.
func (p *Buffer) Bytes() []byte { return p.buf }
// SetDeterministic sets whether to use deterministic serialization.
//
// Deterministic serialization guarantees that for a given binary, equal
// messages will always be serialized to the same bytes. This implies:
//
// - Repeated serialization of a message will return the same bytes.
// - Different processes of the same binary (which may be executing on
// different machines) will serialize equal messages to the same bytes.
//
// Note that the deterministic serialization is NOT canonical across
// languages. It is not guaranteed to remain stable over time. It is unstable
// across different builds with schema changes due to unknown fields.
// Users who need canonical serialization (e.g., persistent storage in a
// canonical form, fingerprinting, etc.) should define their own
// canonicalization specification and implement their own serializer rather
// than relying on this API.
//
// If deterministic serialization is requested, map entries will be sorted
// by keys in lexographical order. This is an implementation detail and
// subject to change.
func (p *Buffer) SetDeterministic(deterministic bool) {
p.deterministic = deterministic
}
/*
* Helper routines for simplifying the creation of optional fields of basic type.
*/
@@ -831,22 +849,12 @@ func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMes
return sf, false, nil
}
// mapKeys returns a sort.Interface to be used for sorting the map keys.
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
func mapKeys(vs []reflect.Value) sort.Interface {
s := mapKeySorter{
vs: vs,
// default Less function: textual comparison
less: func(a, b reflect.Value) bool {
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
},
}
s := mapKeySorter{vs: vs}
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
// numeric keys are sorted numerically.
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
if len(vs) == 0 {
return s
}
@@ -855,6 +863,12 @@ func mapKeys(vs []reflect.Value) sort.Interface {
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
case reflect.Uint32, reflect.Uint64:
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
case reflect.Bool:
s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
case reflect.String:
s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
default:
panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
}
return s
@@ -895,3 +909,13 @@ const ProtoPackageIsVersion2 = true
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package.
const ProtoPackageIsVersion1 = true
// InternalMessageInfo is a type used internally by generated .pb.go files.
// This type is not intended to be used by non-generated code.
// This type is not subject to any compatibility guarantee.
type InternalMessageInfo struct {
marshal *marshalInfo
unmarshal *unmarshalInfo
merge *mergeInfo
discard *discardInfo
}

View File

@@ -42,6 +42,7 @@ import (
"fmt"
"reflect"
"sort"
"sync"
)
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
@@ -94,10 +95,7 @@ func (ms *messageSet) find(pb Message) *_MessageSet_Item {
}
func (ms *messageSet) Has(pb Message) bool {
if ms.find(pb) != nil {
return true
}
return false
return ms.find(pb) != nil
}
func (ms *messageSet) Unmarshal(pb Message) error {
@@ -150,46 +148,42 @@ func skipVarint(buf []byte) []byte {
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSet(exts interface{}) ([]byte, error) {
var m map[int32]Extension
return marshalMessageSet(exts, false)
}
// marshaMessageSet implements above function, with the opt to turn on / off deterministic during Marshal.
func marshalMessageSet(exts interface{}, deterministic bool) ([]byte, error) {
switch exts := exts.(type) {
case *XXX_InternalExtensions:
if err := encodeExtensions(exts); err != nil {
return nil, err
}
m, _ = exts.extensionsRead()
var u marshalInfo
siz := u.sizeMessageSet(exts)
b := make([]byte, 0, siz)
return u.appendMessageSet(b, exts, deterministic)
case map[int32]Extension:
if err := encodeExtensionsMap(exts); err != nil {
return nil, err
// This is an old-style extension map.
// Wrap it in a new-style XXX_InternalExtensions.
ie := XXX_InternalExtensions{
p: &struct {
mu sync.Mutex
extensionMap map[int32]Extension
}{
extensionMap: exts,
},
}
m = exts
var u marshalInfo
siz := u.sizeMessageSet(&ie)
b := make([]byte, 0, siz)
return u.appendMessageSet(b, &ie, deterministic)
default:
return nil, errors.New("proto: not an extension map")
}
// Sort extension IDs to provide a deterministic encoding.
// See also enc_map in encode.go.
ids := make([]int, 0, len(m))
for id := range m {
ids = append(ids, int(id))
}
sort.Ints(ids)
ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
for _, id := range ids {
e := m[int32(id)]
// Remove the wire type and field number varint, as well as the length varint.
msg := skipVarint(skipVarint(e.enc))
ms.Item = append(ms.Item, &_MessageSet_Item{
TypeId: Int32(int32(id)),
Message: msg,
})
}
return Marshal(ms)
}
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
// It is called by Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, exts interface{}) error {
var m map[int32]Extension
switch exts := exts.(type) {
@@ -235,7 +229,15 @@ func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
var m map[int32]Extension
switch exts := exts.(type) {
case *XXX_InternalExtensions:
m, _ = exts.extensionsRead()
var mu sync.Locker
m, mu = exts.extensionsRead()
if m != nil {
// Keep the extensions map locked until we're done marshaling to prevent
// races between marshaling and unmarshaling the lazily-{en,de}coded
// values.
mu.Lock()
defer mu.Unlock()
}
case map[int32]Extension:
m = exts
default:
@@ -253,15 +255,16 @@ func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
for i, id := range ids {
ext := m[id]
if i > 0 {
b.WriteByte(',')
}
msd, ok := messageSetMap[id]
if !ok {
// Unknown type; we can't render it, so skip it.
continue
}
if i > 0 && b.Len() > 1 {
b.WriteByte(',')
}
fmt.Fprintf(&b, `"[%s]":`, msd.name)
x := ext.value

View File

@@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build appengine js
// +build purego appengine js
// This file contains an implementation of proto field accesses using package reflect.
// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can
@@ -38,32 +38,13 @@
package proto
import (
"math"
"reflect"
"sync"
)
// A structPointer is a pointer to a struct.
type structPointer struct {
v reflect.Value
}
const unsafeAllowed = false
// toStructPointer returns a structPointer equivalent to the given reflect value.
// The reflect value must itself be a pointer to a struct.
func toStructPointer(v reflect.Value) structPointer {
return structPointer{v}
}
// IsNil reports whether p is nil.
func structPointer_IsNil(p structPointer) bool {
return p.v.IsNil()
}
// Interface returns the struct pointer as an interface value.
func structPointer_Interface(p structPointer, _ reflect.Type) interface{} {
return p.v.Interface()
}
// A field identifies a field in a struct, accessible from a structPointer.
// A field identifies a field in a struct, accessible from a pointer.
// In this implementation, a field is identified by the sequence of field indices
// passed to reflect's FieldByIndex.
type field []int
@@ -76,409 +57,301 @@ func toField(f *reflect.StructField) field {
// invalidField is an invalid field identifier.
var invalidField = field(nil)
// zeroField is a noop when calling pointer.offset.
var zeroField = field([]int{})
// IsValid reports whether the field identifier is valid.
func (f field) IsValid() bool { return f != nil }
// field returns the given field in the struct as a reflect value.
func structPointer_field(p structPointer, f field) reflect.Value {
// Special case: an extension map entry with a value of type T
// passes a *T to the struct-handling code with a zero field,
// expecting that it will be treated as equivalent to *struct{ X T },
// which has the same memory layout. We have to handle that case
// specially, because reflect will panic if we call FieldByIndex on a
// non-struct.
if f == nil {
return p.v.Elem()
}
return p.v.Elem().FieldByIndex(f)
}
// ifield returns the given field in the struct as an interface value.
func structPointer_ifield(p structPointer, f field) interface{} {
return structPointer_field(p, f).Addr().Interface()
}
// Bytes returns the address of a []byte field in the struct.
func structPointer_Bytes(p structPointer, f field) *[]byte {
return structPointer_ifield(p, f).(*[]byte)
}
// BytesSlice returns the address of a [][]byte field in the struct.
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
return structPointer_ifield(p, f).(*[][]byte)
}
// Bool returns the address of a *bool field in the struct.
func structPointer_Bool(p structPointer, f field) **bool {
return structPointer_ifield(p, f).(**bool)
}
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return structPointer_ifield(p, f).(*bool)
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return structPointer_ifield(p, f).(*[]bool)
}
// String returns the address of a *string field in the struct.
func structPointer_String(p structPointer, f field) **string {
return structPointer_ifield(p, f).(**string)
}
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return structPointer_ifield(p, f).(*string)
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return structPointer_ifield(p, f).(*[]string)
}
// Extensions returns the address of an extension map field in the struct.
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
return structPointer_ifield(p, f).(*XXX_InternalExtensions)
}
// ExtMap returns the address of an extension map field in the struct.
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension)
}
// NewAt returns the reflect.Value for a pointer to a field in the struct.
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
return structPointer_field(p, f).Addr()
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
structPointer_field(p, f).Set(q.v)
}
// GetStructPointer reads a *struct field in the struct.
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
return structPointer{structPointer_field(p, f)}
}
// StructPointerSlice the address of a []*struct field in the struct.
func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice {
return structPointerSlice{structPointer_field(p, f)}
}
// A structPointerSlice represents the address of a slice of pointers to structs
// (themselves messages or groups). That is, v.Type() is *[]*struct{...}.
type structPointerSlice struct {
// The pointer type is for the table-driven decoder.
// The implementation here uses a reflect.Value of pointer type to
// create a generic pointer. In pointer_unsafe.go we use unsafe
// instead of reflect to implement the same (but faster) interface.
type pointer struct {
v reflect.Value
}
func (p structPointerSlice) Len() int { return p.v.Len() }
func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} }
func (p structPointerSlice) Append(q structPointer) {
p.v.Set(reflect.Append(p.v, q.v))
// toPointer converts an interface of pointer type to a pointer
// that points to the same target.
func toPointer(i *Message) pointer {
return pointer{v: reflect.ValueOf(*i)}
}
var (
int32Type = reflect.TypeOf(int32(0))
uint32Type = reflect.TypeOf(uint32(0))
float32Type = reflect.TypeOf(float32(0))
int64Type = reflect.TypeOf(int64(0))
uint64Type = reflect.TypeOf(uint64(0))
float64Type = reflect.TypeOf(float64(0))
)
// A word32 represents a field of type *int32, *uint32, *float32, or *enum.
// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable.
type word32 struct {
v reflect.Value
// toAddrPointer converts an interface to a pointer that points to
// the interface data.
func toAddrPointer(i *interface{}, isptr bool) pointer {
v := reflect.ValueOf(*i)
u := reflect.New(v.Type())
u.Elem().Set(v)
return pointer{v: u}
}
// IsNil reports whether p is nil.
func word32_IsNil(p word32) bool {
// valToPointer converts v to a pointer. v must be of pointer type.
func valToPointer(v reflect.Value) pointer {
return pointer{v: v}
}
// offset converts from a pointer to a structure to a pointer to
// one of its fields.
func (p pointer) offset(f field) pointer {
return pointer{v: p.v.Elem().FieldByIndex(f).Addr()}
}
func (p pointer) isNil() bool {
return p.v.IsNil()
}
// Set sets p to point at a newly allocated word with bits set to x.
func word32_Set(p word32, o *Buffer, x uint32) {
t := p.v.Type().Elem()
switch t {
case int32Type:
if len(o.int32s) == 0 {
o.int32s = make([]int32, uint32PoolSize)
}
o.int32s[0] = int32(x)
p.v.Set(reflect.ValueOf(&o.int32s[0]))
o.int32s = o.int32s[1:]
return
case uint32Type:
if len(o.uint32s) == 0 {
o.uint32s = make([]uint32, uint32PoolSize)
}
o.uint32s[0] = x
p.v.Set(reflect.ValueOf(&o.uint32s[0]))
o.uint32s = o.uint32s[1:]
return
case float32Type:
if len(o.float32s) == 0 {
o.float32s = make([]float32, uint32PoolSize)
}
o.float32s[0] = math.Float32frombits(x)
p.v.Set(reflect.ValueOf(&o.float32s[0]))
o.float32s = o.float32s[1:]
return
}
// must be enum
p.v.Set(reflect.New(t))
p.v.Elem().SetInt(int64(int32(x)))
}
// Get gets the bits pointed at by p, as a uint32.
func word32_Get(p word32) uint32 {
elem := p.v.Elem()
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32(p structPointer, f field) word32 {
return word32{structPointer_field(p, f)}
}
// A word32Val represents a field of type int32, uint32, float32, or enum.
// That is, v.Type() is int32, uint32, float32, or enum and v is assignable.
type word32Val struct {
v reflect.Value
}
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
switch p.v.Type() {
case int32Type:
p.v.SetInt(int64(x))
return
case uint32Type:
p.v.SetUint(uint64(x))
return
case float32Type:
p.v.SetFloat(float64(math.Float32frombits(x)))
return
}
// must be enum
p.v.SetInt(int64(int32(x)))
}
// Get gets the bits pointed at by p, as a uint32.
func word32Val_Get(p word32Val) uint32 {
elem := p.v
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
}
panic("unreachable")
}
// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val{structPointer_field(p, f)}
}
// A word32Slice is a slice of 32-bit values.
// That is, v.Type() is []int32, []uint32, []float32, or []enum.
type word32Slice struct {
v reflect.Value
}
func (p word32Slice) Append(x uint32) {
n, m := p.v.Len(), p.v.Cap()
// grow updates the slice s in place to make it one element longer.
// s must be addressable.
// Returns the (addressable) new element.
func grow(s reflect.Value) reflect.Value {
n, m := s.Len(), s.Cap()
if n < m {
p.v.SetLen(n + 1)
s.SetLen(n + 1)
} else {
t := p.v.Type().Elem()
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
s.Set(reflect.Append(s, reflect.Zero(s.Type().Elem())))
}
elem := p.v.Index(n)
switch elem.Kind() {
case reflect.Int32:
elem.SetInt(int64(int32(x)))
case reflect.Uint32:
elem.SetUint(uint64(x))
case reflect.Float32:
elem.SetFloat(float64(math.Float32frombits(x)))
return s.Index(n)
}
func (p pointer) toInt64() *int64 {
return p.v.Interface().(*int64)
}
func (p pointer) toInt64Ptr() **int64 {
return p.v.Interface().(**int64)
}
func (p pointer) toInt64Slice() *[]int64 {
return p.v.Interface().(*[]int64)
}
var int32ptr = reflect.TypeOf((*int32)(nil))
func (p pointer) toInt32() *int32 {
return p.v.Convert(int32ptr).Interface().(*int32)
}
// The toInt32Ptr/Slice methods don't work because of enums.
// Instead, we must use set/get methods for the int32ptr/slice case.
/*
func (p pointer) toInt32Ptr() **int32 {
return p.v.Interface().(**int32)
}
func (p pointer) toInt32Slice() *[]int32 {
return p.v.Interface().(*[]int32)
}
*/
func (p pointer) getInt32Ptr() *int32 {
if p.v.Type().Elem().Elem() == reflect.TypeOf(int32(0)) {
// raw int32 type
return p.v.Elem().Interface().(*int32)
}
// an enum
return p.v.Elem().Convert(int32PtrType).Interface().(*int32)
}
func (p pointer) setInt32Ptr(v int32) {
// Allocate value in a *int32. Possibly convert that to a *enum.
// Then assign it to a **int32 or **enum.
// Note: we can convert *int32 to *enum, but we can't convert
// **int32 to **enum!
p.v.Elem().Set(reflect.ValueOf(&v).Convert(p.v.Type().Elem()))
}
func (p word32Slice) Len() int {
return p.v.Len()
}
func (p word32Slice) Index(i int) uint32 {
elem := p.v.Index(i)
switch elem.Kind() {
case reflect.Int32:
return uint32(elem.Int())
case reflect.Uint32:
return uint32(elem.Uint())
case reflect.Float32:
return math.Float32bits(float32(elem.Float()))
// getInt32Slice copies []int32 from p as a new slice.
// This behavior differs from the implementation in pointer_unsafe.go.
func (p pointer) getInt32Slice() []int32 {
if p.v.Type().Elem().Elem() == reflect.TypeOf(int32(0)) {
// raw int32 type
return p.v.Elem().Interface().([]int32)
}
panic("unreachable")
// an enum
// Allocate a []int32, then assign []enum's values into it.
// Note: we can't convert []enum to []int32.
slice := p.v.Elem()
s := make([]int32, slice.Len())
for i := 0; i < slice.Len(); i++ {
s[i] = int32(slice.Index(i).Int())
}
return s
}
// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct.
func structPointer_Word32Slice(p structPointer, f field) word32Slice {
return word32Slice{structPointer_field(p, f)}
}
// word64 is like word32 but for 64-bit values.
type word64 struct {
v reflect.Value
}
func word64_Set(p word64, o *Buffer, x uint64) {
t := p.v.Type().Elem()
switch t {
case int64Type:
if len(o.int64s) == 0 {
o.int64s = make([]int64, uint64PoolSize)
}
o.int64s[0] = int64(x)
p.v.Set(reflect.ValueOf(&o.int64s[0]))
o.int64s = o.int64s[1:]
return
case uint64Type:
if len(o.uint64s) == 0 {
o.uint64s = make([]uint64, uint64PoolSize)
}
o.uint64s[0] = x
p.v.Set(reflect.ValueOf(&o.uint64s[0]))
o.uint64s = o.uint64s[1:]
return
case float64Type:
if len(o.float64s) == 0 {
o.float64s = make([]float64, uint64PoolSize)
}
o.float64s[0] = math.Float64frombits(x)
p.v.Set(reflect.ValueOf(&o.float64s[0]))
o.float64s = o.float64s[1:]
// setInt32Slice copies []int32 into p as a new slice.
// This behavior differs from the implementation in pointer_unsafe.go.
func (p pointer) setInt32Slice(v []int32) {
if p.v.Type().Elem().Elem() == reflect.TypeOf(int32(0)) {
// raw int32 type
p.v.Elem().Set(reflect.ValueOf(v))
return
}
panic("unreachable")
}
func word64_IsNil(p word64) bool {
return p.v.IsNil()
}
func word64_Get(p word64) uint64 {
elem := p.v.Elem()
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return elem.Uint()
case reflect.Float64:
return math.Float64bits(elem.Float())
// an enum
// Allocate a []enum, then assign []int32's values into it.
// Note: we can't convert []enum to []int32.
slice := reflect.MakeSlice(p.v.Type().Elem(), len(v), cap(v))
for i, x := range v {
slice.Index(i).SetInt(int64(x))
}
panic("unreachable")
p.v.Elem().Set(slice)
}
func (p pointer) appendInt32Slice(v int32) {
grow(p.v.Elem()).SetInt(int64(v))
}
func structPointer_Word64(p structPointer, f field) word64 {
return word64{structPointer_field(p, f)}
func (p pointer) toUint64() *uint64 {
return p.v.Interface().(*uint64)
}
func (p pointer) toUint64Ptr() **uint64 {
return p.v.Interface().(**uint64)
}
func (p pointer) toUint64Slice() *[]uint64 {
return p.v.Interface().(*[]uint64)
}
func (p pointer) toUint32() *uint32 {
return p.v.Interface().(*uint32)
}
func (p pointer) toUint32Ptr() **uint32 {
return p.v.Interface().(**uint32)
}
func (p pointer) toUint32Slice() *[]uint32 {
return p.v.Interface().(*[]uint32)
}
func (p pointer) toBool() *bool {
return p.v.Interface().(*bool)
}
func (p pointer) toBoolPtr() **bool {
return p.v.Interface().(**bool)
}
func (p pointer) toBoolSlice() *[]bool {
return p.v.Interface().(*[]bool)
}
func (p pointer) toFloat64() *float64 {
return p.v.Interface().(*float64)
}
func (p pointer) toFloat64Ptr() **float64 {
return p.v.Interface().(**float64)
}
func (p pointer) toFloat64Slice() *[]float64 {
return p.v.Interface().(*[]float64)
}
func (p pointer) toFloat32() *float32 {
return p.v.Interface().(*float32)
}
func (p pointer) toFloat32Ptr() **float32 {
return p.v.Interface().(**float32)
}
func (p pointer) toFloat32Slice() *[]float32 {
return p.v.Interface().(*[]float32)
}
func (p pointer) toString() *string {
return p.v.Interface().(*string)
}
func (p pointer) toStringPtr() **string {
return p.v.Interface().(**string)
}
func (p pointer) toStringSlice() *[]string {
return p.v.Interface().(*[]string)
}
func (p pointer) toBytes() *[]byte {
return p.v.Interface().(*[]byte)
}
func (p pointer) toBytesSlice() *[][]byte {
return p.v.Interface().(*[][]byte)
}
func (p pointer) toExtensions() *XXX_InternalExtensions {
return p.v.Interface().(*XXX_InternalExtensions)
}
func (p pointer) toOldExtensions() *map[int32]Extension {
return p.v.Interface().(*map[int32]Extension)
}
func (p pointer) getPointer() pointer {
return pointer{v: p.v.Elem()}
}
func (p pointer) setPointer(q pointer) {
p.v.Elem().Set(q.v)
}
func (p pointer) appendPointer(q pointer) {
grow(p.v.Elem()).Set(q.v)
}
// word64Val is like word32Val but for 64-bit values.
type word64Val struct {
v reflect.Value
// getPointerSlice copies []*T from p as a new []pointer.
// This behavior differs from the implementation in pointer_unsafe.go.
func (p pointer) getPointerSlice() []pointer {
if p.v.IsNil() {
return nil
}
n := p.v.Elem().Len()
s := make([]pointer, n)
for i := 0; i < n; i++ {
s[i] = pointer{v: p.v.Elem().Index(i)}
}
return s
}
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
switch p.v.Type() {
case int64Type:
p.v.SetInt(int64(x))
return
case uint64Type:
p.v.SetUint(x)
return
case float64Type:
p.v.SetFloat(math.Float64frombits(x))
// setPointerSlice copies []pointer into p as a new []*T.
// This behavior differs from the implementation in pointer_unsafe.go.
func (p pointer) setPointerSlice(v []pointer) {
if v == nil {
p.v.Elem().Set(reflect.New(p.v.Elem().Type()).Elem())
return
}
panic("unreachable")
}
func word64Val_Get(p word64Val) uint64 {
elem := p.v
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return elem.Uint()
case reflect.Float64:
return math.Float64bits(elem.Float())
s := reflect.MakeSlice(p.v.Elem().Type(), 0, len(v))
for _, p := range v {
s = reflect.Append(s, p.v)
}
panic("unreachable")
p.v.Elem().Set(s)
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val{structPointer_field(p, f)}
}
type word64Slice struct {
v reflect.Value
}
func (p word64Slice) Append(x uint64) {
n, m := p.v.Len(), p.v.Cap()
if n < m {
p.v.SetLen(n + 1)
} else {
t := p.v.Type().Elem()
p.v.Set(reflect.Append(p.v, reflect.Zero(t)))
}
elem := p.v.Index(n)
switch elem.Kind() {
case reflect.Int64:
elem.SetInt(int64(int64(x)))
case reflect.Uint64:
elem.SetUint(uint64(x))
case reflect.Float64:
elem.SetFloat(float64(math.Float64frombits(x)))
// getInterfacePointer returns a pointer that points to the
// interface data of the interface pointed by p.
func (p pointer) getInterfacePointer() pointer {
if p.v.Elem().IsNil() {
return pointer{v: p.v.Elem()}
}
return pointer{v: p.v.Elem().Elem().Elem().Field(0).Addr()} // *interface -> interface -> *struct -> struct
}
func (p word64Slice) Len() int {
return p.v.Len()
func (p pointer) asPointerTo(t reflect.Type) reflect.Value {
// TODO: check that p.v.Type().Elem() == t?
return p.v
}
func (p word64Slice) Index(i int) uint64 {
elem := p.v.Index(i)
switch elem.Kind() {
case reflect.Int64:
return uint64(elem.Int())
case reflect.Uint64:
return uint64(elem.Uint())
case reflect.Float64:
return math.Float64bits(float64(elem.Float()))
}
panic("unreachable")
func atomicLoadUnmarshalInfo(p **unmarshalInfo) *unmarshalInfo {
atomicLock.Lock()
defer atomicLock.Unlock()
return *p
}
func atomicStoreUnmarshalInfo(p **unmarshalInfo, v *unmarshalInfo) {
atomicLock.Lock()
defer atomicLock.Unlock()
*p = v
}
func atomicLoadMarshalInfo(p **marshalInfo) *marshalInfo {
atomicLock.Lock()
defer atomicLock.Unlock()
return *p
}
func atomicStoreMarshalInfo(p **marshalInfo, v *marshalInfo) {
atomicLock.Lock()
defer atomicLock.Unlock()
*p = v
}
func atomicLoadMergeInfo(p **mergeInfo) *mergeInfo {
atomicLock.Lock()
defer atomicLock.Unlock()
return *p
}
func atomicStoreMergeInfo(p **mergeInfo, v *mergeInfo) {
atomicLock.Lock()
defer atomicLock.Unlock()
*p = v
}
func atomicLoadDiscardInfo(p **discardInfo) *discardInfo {
atomicLock.Lock()
defer atomicLock.Unlock()
return *p
}
func atomicStoreDiscardInfo(p **discardInfo, v *discardInfo) {
atomicLock.Lock()
defer atomicLock.Unlock()
*p = v
}
func structPointer_Word64Slice(p structPointer, f field) word64Slice {
return word64Slice{structPointer_field(p, f)}
}
var atomicLock sync.Mutex

View File

@@ -29,7 +29,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// +build !appengine,!js
// +build !purego,!appengine,!js
// This file contains the implementation of the proto field accesses using package unsafe.
@@ -37,38 +37,13 @@ package proto
import (
"reflect"
"sync/atomic"
"unsafe"
)
// NOTE: These type_Foo functions would more idiomatically be methods,
// but Go does not allow methods on pointer types, and we must preserve
// some pointer type for the garbage collector. We use these
// funcs with clunky names as our poor approximation to methods.
//
// An alternative would be
// type structPointer struct { p unsafe.Pointer }
// but that does not registerize as well.
const unsafeAllowed = true
// A structPointer is a pointer to a struct.
type structPointer unsafe.Pointer
// toStructPointer returns a structPointer equivalent to the given reflect value.
func toStructPointer(v reflect.Value) structPointer {
return structPointer(unsafe.Pointer(v.Pointer()))
}
// IsNil reports whether p is nil.
func structPointer_IsNil(p structPointer) bool {
return p == nil
}
// Interface returns the struct pointer, assumed to have element type t,
// as an interface value.
func structPointer_Interface(p structPointer, t reflect.Type) interface{} {
return reflect.NewAt(t, unsafe.Pointer(p)).Interface()
}
// A field identifies a field in a struct, accessible from a structPointer.
// A field identifies a field in a struct, accessible from a pointer.
// In this implementation, a field is identified by its byte offset from the start of the struct.
type field uintptr
@@ -80,191 +55,254 @@ func toField(f *reflect.StructField) field {
// invalidField is an invalid field identifier.
const invalidField = ^field(0)
// zeroField is a noop when calling pointer.offset.
const zeroField = field(0)
// IsValid reports whether the field identifier is valid.
func (f field) IsValid() bool {
return f != ^field(0)
return f != invalidField
}
// Bytes returns the address of a []byte field in the struct.
func structPointer_Bytes(p structPointer, f field) *[]byte {
return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
// The pointer type below is for the new table-driven encoder/decoder.
// The implementation here uses unsafe.Pointer to create a generic pointer.
// In pointer_reflect.go we use reflect instead of unsafe to implement
// the same (but slower) interface.
type pointer struct {
p unsafe.Pointer
}
// BytesSlice returns the address of a [][]byte field in the struct.
func structPointer_BytesSlice(p structPointer, f field) *[][]byte {
return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f)))
// size of pointer
var ptrSize = unsafe.Sizeof(uintptr(0))
// toPointer converts an interface of pointer type to a pointer
// that points to the same target.
func toPointer(i *Message) pointer {
// Super-tricky - read pointer out of data word of interface value.
// Saves ~25ns over the equivalent:
// return valToPointer(reflect.ValueOf(*i))
return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
}
// Bool returns the address of a *bool field in the struct.
func structPointer_Bool(p structPointer, f field) **bool {
return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolVal returns the address of a bool field in the struct.
func structPointer_BoolVal(p structPointer, f field) *bool {
return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// BoolSlice returns the address of a []bool field in the struct.
func structPointer_BoolSlice(p structPointer, f field) *[]bool {
return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// String returns the address of a *string field in the struct.
func structPointer_String(p structPointer, f field) **string {
return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringVal returns the address of a string field in the struct.
func structPointer_StringVal(p structPointer, f field) *string {
return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StringSlice returns the address of a []string field in the struct.
func structPointer_StringSlice(p structPointer, f field) *[]string {
return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// ExtMap returns the address of an extension map field in the struct.
func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// NewAt returns the reflect.Value for a pointer to a field in the struct.
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
}
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
}
// GetStructPointer reads a *struct field in the struct.
func structPointer_GetStructPointer(p structPointer, f field) structPointer {
return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// StructPointerSlice the address of a []*struct field in the struct.
func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice {
return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups).
type structPointerSlice []structPointer
func (v *structPointerSlice) Len() int { return len(*v) }
func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] }
func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) }
// A word32 is the address of a "pointer to 32-bit value" field.
type word32 **uint32
// IsNil reports whether *v is nil.
func word32_IsNil(p word32) bool {
return *p == nil
}
// Set sets *v to point at a newly allocated word set to x.
func word32_Set(p word32, o *Buffer, x uint32) {
if len(o.uint32s) == 0 {
o.uint32s = make([]uint32, uint32PoolSize)
// toAddrPointer converts an interface to a pointer that points to
// the interface data.
func toAddrPointer(i *interface{}, isptr bool) pointer {
// Super-tricky - read or get the address of data word of interface value.
if isptr {
// The interface is of pointer type, thus it is a direct interface.
// The data word is the pointer data itself. We take its address.
return pointer{p: unsafe.Pointer(uintptr(unsafe.Pointer(i)) + ptrSize)}
}
o.uint32s[0] = x
*p = &o.uint32s[0]
o.uint32s = o.uint32s[1:]
// The interface is not of pointer type. The data word is the pointer
// to the data.
return pointer{p: (*[2]unsafe.Pointer)(unsafe.Pointer(i))[1]}
}
// Get gets the value pointed at by *v.
func word32_Get(p word32) uint32 {
return **p
// valToPointer converts v to a pointer. v must be of pointer type.
func valToPointer(v reflect.Value) pointer {
return pointer{p: unsafe.Pointer(v.Pointer())}
}
// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32(p structPointer, f field) word32 {
return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
// offset converts from a pointer to a structure to a pointer to
// one of its fields.
func (p pointer) offset(f field) pointer {
// For safety, we should panic if !f.IsValid, however calling panic causes
// this to no longer be inlineable, which is a serious performance cost.
/*
if !f.IsValid() {
panic("invalid field")
}
*/
return pointer{p: unsafe.Pointer(uintptr(p.p) + uintptr(f))}
}
// A word32Val is the address of a 32-bit value field.
type word32Val *uint32
// Set sets *p to x.
func word32Val_Set(p word32Val, x uint32) {
*p = x
func (p pointer) isNil() bool {
return p.p == nil
}
// Get gets the value pointed at by p.
func word32Val_Get(p word32Val) uint32 {
return *p
func (p pointer) toInt64() *int64 {
return (*int64)(p.p)
}
func (p pointer) toInt64Ptr() **int64 {
return (**int64)(p.p)
}
func (p pointer) toInt64Slice() *[]int64 {
return (*[]int64)(p.p)
}
func (p pointer) toInt32() *int32 {
return (*int32)(p.p)
}
// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct.
func structPointer_Word32Val(p structPointer, f field) word32Val {
return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f))))
}
// A word32Slice is a slice of 32-bit values.
type word32Slice []uint32
func (v *word32Slice) Append(x uint32) { *v = append(*v, x) }
func (v *word32Slice) Len() int { return len(*v) }
func (v *word32Slice) Index(i int) uint32 { return (*v)[i] }
// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct.
func structPointer_Word32Slice(p structPointer, f field) *word32Slice {
return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// word64 is like word32 but for 64-bit values.
type word64 **uint64
func word64_Set(p word64, o *Buffer, x uint64) {
if len(o.uint64s) == 0 {
o.uint64s = make([]uint64, uint64PoolSize)
// See pointer_reflect.go for why toInt32Ptr/Slice doesn't exist.
/*
func (p pointer) toInt32Ptr() **int32 {
return (**int32)(p.p)
}
o.uint64s[0] = x
*p = &o.uint64s[0]
o.uint64s = o.uint64s[1:]
func (p pointer) toInt32Slice() *[]int32 {
return (*[]int32)(p.p)
}
*/
func (p pointer) getInt32Ptr() *int32 {
return *(**int32)(p.p)
}
func (p pointer) setInt32Ptr(v int32) {
*(**int32)(p.p) = &v
}
func word64_IsNil(p word64) bool {
return *p == nil
// getInt32Slice loads a []int32 from p.
// The value returned is aliased with the original slice.
// This behavior differs from the implementation in pointer_reflect.go.
func (p pointer) getInt32Slice() []int32 {
return *(*[]int32)(p.p)
}
func word64_Get(p word64) uint64 {
return **p
// setInt32Slice stores a []int32 to p.
// The value set is aliased with the input slice.
// This behavior differs from the implementation in pointer_reflect.go.
func (p pointer) setInt32Slice(v []int32) {
*(*[]int32)(p.p) = v
}
func structPointer_Word64(p structPointer, f field) word64 {
return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
// TODO: Can we get rid of appendInt32Slice and use setInt32Slice instead?
func (p pointer) appendInt32Slice(v int32) {
s := (*[]int32)(p.p)
*s = append(*s, v)
}
// word64Val is like word32Val but for 64-bit values.
type word64Val *uint64
func word64Val_Set(p word64Val, o *Buffer, x uint64) {
*p = x
func (p pointer) toUint64() *uint64 {
return (*uint64)(p.p)
}
func (p pointer) toUint64Ptr() **uint64 {
return (**uint64)(p.p)
}
func (p pointer) toUint64Slice() *[]uint64 {
return (*[]uint64)(p.p)
}
func (p pointer) toUint32() *uint32 {
return (*uint32)(p.p)
}
func (p pointer) toUint32Ptr() **uint32 {
return (**uint32)(p.p)
}
func (p pointer) toUint32Slice() *[]uint32 {
return (*[]uint32)(p.p)
}
func (p pointer) toBool() *bool {
return (*bool)(p.p)
}
func (p pointer) toBoolPtr() **bool {
return (**bool)(p.p)
}
func (p pointer) toBoolSlice() *[]bool {
return (*[]bool)(p.p)
}
func (p pointer) toFloat64() *float64 {
return (*float64)(p.p)
}
func (p pointer) toFloat64Ptr() **float64 {
return (**float64)(p.p)
}
func (p pointer) toFloat64Slice() *[]float64 {
return (*[]float64)(p.p)
}
func (p pointer) toFloat32() *float32 {
return (*float32)(p.p)
}
func (p pointer) toFloat32Ptr() **float32 {
return (**float32)(p.p)
}
func (p pointer) toFloat32Slice() *[]float32 {
return (*[]float32)(p.p)
}
func (p pointer) toString() *string {
return (*string)(p.p)
}
func (p pointer) toStringPtr() **string {
return (**string)(p.p)
}
func (p pointer) toStringSlice() *[]string {
return (*[]string)(p.p)
}
func (p pointer) toBytes() *[]byte {
return (*[]byte)(p.p)
}
func (p pointer) toBytesSlice() *[][]byte {
return (*[][]byte)(p.p)
}
func (p pointer) toExtensions() *XXX_InternalExtensions {
return (*XXX_InternalExtensions)(p.p)
}
func (p pointer) toOldExtensions() *map[int32]Extension {
return (*map[int32]Extension)(p.p)
}
func word64Val_Get(p word64Val) uint64 {
return *p
// getPointerSlice loads []*T from p as a []pointer.
// The value returned is aliased with the original slice.
// This behavior differs from the implementation in pointer_reflect.go.
func (p pointer) getPointerSlice() []pointer {
// Super-tricky - p should point to a []*T where T is a
// message type. We load it as []pointer.
return *(*[]pointer)(p.p)
}
func structPointer_Word64Val(p structPointer, f field) word64Val {
return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f))))
// setPointerSlice stores []pointer into p as a []*T.
// The value set is aliased with the input slice.
// This behavior differs from the implementation in pointer_reflect.go.
func (p pointer) setPointerSlice(v []pointer) {
// Super-tricky - p should point to a []*T where T is a
// message type. We store it as []pointer.
*(*[]pointer)(p.p) = v
}
// word64Slice is like word32Slice but for 64-bit values.
type word64Slice []uint64
func (v *word64Slice) Append(x uint64) { *v = append(*v, x) }
func (v *word64Slice) Len() int { return len(*v) }
func (v *word64Slice) Index(i int) uint64 { return (*v)[i] }
func structPointer_Word64Slice(p structPointer, f field) *word64Slice {
return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f)))
// getPointer loads the pointer at p and returns it.
func (p pointer) getPointer() pointer {
return pointer{p: *(*unsafe.Pointer)(p.p)}
}
// setPointer stores the pointer q at p.
func (p pointer) setPointer(q pointer) {
*(*unsafe.Pointer)(p.p) = q.p
}
// append q to the slice pointed to by p.
func (p pointer) appendPointer(q pointer) {
s := (*[]unsafe.Pointer)(p.p)
*s = append(*s, q.p)
}
// getInterfacePointer returns a pointer that points to the
// interface data of the interface pointed by p.
func (p pointer) getInterfacePointer() pointer {
// Super-tricky - read pointer out of data word of interface value.
return pointer{p: (*(*[2]unsafe.Pointer)(p.p))[1]}
}
// asPointerTo returns a reflect.Value that is a pointer to an
// object of type t stored at p.
func (p pointer) asPointerTo(t reflect.Type) reflect.Value {
return reflect.NewAt(t, p.p)
}
func atomicLoadUnmarshalInfo(p **unmarshalInfo) *unmarshalInfo {
return (*unmarshalInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p))))
}
func atomicStoreUnmarshalInfo(p **unmarshalInfo, v *unmarshalInfo) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v))
}
func atomicLoadMarshalInfo(p **marshalInfo) *marshalInfo {
return (*marshalInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p))))
}
func atomicStoreMarshalInfo(p **marshalInfo, v *marshalInfo) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v))
}
func atomicLoadMergeInfo(p **mergeInfo) *mergeInfo {
return (*mergeInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p))))
}
func atomicStoreMergeInfo(p **mergeInfo, v *mergeInfo) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v))
}
func atomicLoadDiscardInfo(p **discardInfo) *discardInfo {
return (*discardInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(p))))
}
func atomicStoreDiscardInfo(p **discardInfo, v *discardInfo) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v))
}

View File

@@ -58,42 +58,6 @@ const (
WireFixed32 = 5
)
const startSize = 10 // initial slice/string sizes
// Encoders are defined in encode.go
// An encoder outputs the full representation of a field, including its
// tag and encoder type.
type encoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueEncoder encodes a single integer in a particular encoding.
type valueEncoder func(o *Buffer, x uint64) error
// Sizers are defined in encode.go
// A sizer returns the encoded size of a field, including its tag and encoder
// type.
type sizer func(prop *Properties, base structPointer) int
// A valueSizer returns the encoded size of a single integer in a particular
// encoding.
type valueSizer func(x uint64) int
// Decoders are defined in decode.go
// A decoder creates a value from its wire representation.
// Unrecognized subelements are saved in unrec.
type decoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
// A oneofMarshaler does the marshaling for all oneof fields in a message.
type oneofMarshaler func(Message, *Buffer) error
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
// A oneofSizer does the sizing for all oneof fields in a message.
type oneofSizer func(Message) int
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
@@ -140,13 +104,6 @@ type StructProperties struct {
decoderTags tagMap // map from proto tag to struct field number
decoderOrigNames map[string]int // map from original name to struct field number
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
oneofMarshaler oneofMarshaler
oneofUnmarshaler oneofUnmarshaler
oneofSizer oneofSizer
stype reflect.Type
// OneofTypes contains information about the oneof fields in this message.
// It is keyed by the original name of a field.
@@ -187,36 +144,19 @@ type Properties struct {
Default string // default value
HasDefault bool // whether an explicit default was provided
def_uint64 uint64
enc encoder
valEnc valueEncoder // set for bool and numeric types only
field field
tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
tagbuf [8]byte
stype reflect.Type // set for struct types only
sprop *StructProperties // set for struct types only
isMarshaler bool
isUnmarshaler bool
stype reflect.Type // set for struct types only
sprop *StructProperties // set for struct types only
mtype reflect.Type // set for map types only
mkeyprop *Properties // set for map types only
mvalprop *Properties // set for map types only
size sizer
valSize valueSizer // set for bool and numeric types only
dec decoder
valDec valueDecoder // set for bool and numeric types only
// If this is a packable field, this will be the decoder for the packed version of the field.
packedDec decoder
}
// String formats the properties in the protobuf struct field tag style.
func (p *Properties) String() string {
s := p.Wire
s = ","
s += ","
s += strconv.Itoa(p.Tag)
if p.Required {
s += ",req"
@@ -262,29 +202,14 @@ func (p *Properties) Parse(s string) {
switch p.Wire {
case "varint":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeVarint
p.valDec = (*Buffer).DecodeVarint
p.valSize = sizeVarint
case "fixed32":
p.WireType = WireFixed32
p.valEnc = (*Buffer).EncodeFixed32
p.valDec = (*Buffer).DecodeFixed32
p.valSize = sizeFixed32
case "fixed64":
p.WireType = WireFixed64
p.valEnc = (*Buffer).EncodeFixed64
p.valDec = (*Buffer).DecodeFixed64
p.valSize = sizeFixed64
case "zigzag32":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag32
p.valDec = (*Buffer).DecodeZigzag32
p.valSize = sizeZigzag32
case "zigzag64":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag64
p.valDec = (*Buffer).DecodeZigzag64
p.valSize = sizeZigzag64
case "bytes", "group":
p.WireType = WireBytes
// no numeric converter for non-numeric types
@@ -299,6 +224,7 @@ func (p *Properties) Parse(s string) {
return
}
outer:
for i := 2; i < len(fields); i++ {
f := fields[i]
switch {
@@ -326,229 +252,28 @@ func (p *Properties) Parse(s string) {
if i+1 < len(fields) {
// Commas aren't escaped, and def is always last.
p.Default += "," + strings.Join(fields[i+1:], ",")
break
break outer
}
}
}
}
func logNoSliceEnc(t1, t2 reflect.Type) {
fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2)
}
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
// Initialize the fields for encoding and decoding.
func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
p.enc = nil
p.dec = nil
p.size = nil
// setFieldProps initializes the field properties for submessages and maps.
func (p *Properties) setFieldProps(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
switch t1 := typ; t1.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1)
// proto3 scalar types
case reflect.Bool:
p.enc = (*Buffer).enc_proto3_bool
p.dec = (*Buffer).dec_proto3_bool
p.size = size_proto3_bool
case reflect.Int32:
p.enc = (*Buffer).enc_proto3_int32
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_int32
case reflect.Uint32:
p.enc = (*Buffer).enc_proto3_uint32
p.dec = (*Buffer).dec_proto3_int32 // can reuse
p.size = size_proto3_uint32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_proto3_int64
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.Float32:
p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int32
p.size = size_proto3_uint32
case reflect.Float64:
p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits
p.dec = (*Buffer).dec_proto3_int64
p.size = size_proto3_int64
case reflect.String:
p.enc = (*Buffer).enc_proto3_string
p.dec = (*Buffer).dec_proto3_string
p.size = size_proto3_string
case reflect.Ptr:
switch t2 := t1.Elem(); t2.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2)
break
case reflect.Bool:
p.enc = (*Buffer).enc_bool
p.dec = (*Buffer).dec_bool
p.size = size_bool
case reflect.Int32:
p.enc = (*Buffer).enc_int32
p.dec = (*Buffer).dec_int32
p.size = size_int32
case reflect.Uint32:
p.enc = (*Buffer).enc_uint32
p.dec = (*Buffer).dec_int32 // can reuse
p.size = size_uint32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_int64
p.dec = (*Buffer).dec_int64
p.size = size_int64
case reflect.Float32:
p.enc = (*Buffer).enc_uint32 // can just treat them as bits
p.dec = (*Buffer).dec_int32
p.size = size_uint32
case reflect.Float64:
p.enc = (*Buffer).enc_int64 // can just treat them as bits
p.dec = (*Buffer).dec_int64
p.size = size_int64
case reflect.String:
p.enc = (*Buffer).enc_string
p.dec = (*Buffer).dec_string
p.size = size_string
case reflect.Struct:
if t1.Elem().Kind() == reflect.Struct {
p.stype = t1.Elem()
p.isMarshaler = isMarshaler(t1)
p.isUnmarshaler = isUnmarshaler(t1)
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_struct_message
p.dec = (*Buffer).dec_struct_message
p.size = size_struct_message
} else {
p.enc = (*Buffer).enc_struct_group
p.dec = (*Buffer).dec_struct_group
p.size = size_struct_group
}
}
case reflect.Slice:
switch t2 := t1.Elem(); t2.Kind() {
default:
logNoSliceEnc(t1, t2)
break
case reflect.Bool:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_bool
p.size = size_slice_packed_bool
} else {
p.enc = (*Buffer).enc_slice_bool
p.size = size_slice_bool
}
p.dec = (*Buffer).dec_slice_bool
p.packedDec = (*Buffer).dec_slice_packed_bool
case reflect.Int32:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int32
p.size = size_slice_packed_int32
} else {
p.enc = (*Buffer).enc_slice_int32
p.size = size_slice_int32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case reflect.Uint32:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_uint32
p.size = size_slice_packed_uint32
} else {
p.enc = (*Buffer).enc_slice_uint32
p.size = size_slice_uint32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case reflect.Int64, reflect.Uint64:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
case reflect.Uint8:
p.dec = (*Buffer).dec_slice_byte
if p.proto3 {
p.enc = (*Buffer).enc_proto3_slice_byte
p.size = size_proto3_slice_byte
} else {
p.enc = (*Buffer).enc_slice_byte
p.size = size_slice_byte
}
case reflect.Float32, reflect.Float64:
switch t2.Bits() {
case 32:
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_uint32
p.size = size_slice_packed_uint32
} else {
p.enc = (*Buffer).enc_slice_uint32
p.size = size_slice_uint32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case 64:
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
default:
logNoSliceEnc(t1, t2)
break
}
case reflect.String:
p.enc = (*Buffer).enc_slice_string
p.dec = (*Buffer).dec_slice_string
p.size = size_slice_string
case reflect.Ptr:
switch t3 := t2.Elem(); t3.Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3)
break
case reflect.Struct:
p.stype = t2.Elem()
p.isMarshaler = isMarshaler(t2)
p.isUnmarshaler = isUnmarshaler(t2)
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_slice_struct_message
p.dec = (*Buffer).dec_slice_struct_message
p.size = size_slice_struct_message
} else {
p.enc = (*Buffer).enc_slice_struct_group
p.dec = (*Buffer).dec_slice_struct_group
p.size = size_slice_struct_group
}
}
case reflect.Slice:
switch t2.Elem().Kind() {
default:
fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem())
break
case reflect.Uint8:
p.enc = (*Buffer).enc_slice_slice_byte
p.dec = (*Buffer).dec_slice_slice_byte
p.size = size_slice_slice_byte
}
if t2 := t1.Elem(); t2.Kind() == reflect.Ptr && t2.Elem().Kind() == reflect.Struct {
p.stype = t2.Elem()
}
case reflect.Map:
p.enc = (*Buffer).enc_new_map
p.dec = (*Buffer).dec_new_map
p.size = size_new_map
p.mtype = t1
p.mkeyprop = &Properties{}
p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp)
@@ -562,20 +287,6 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock
p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp)
}
// precalculate tag code
wire := p.WireType
if p.Packed {
wire = WireBytes
}
x := uint32(p.Tag)<<3 | uint32(wire)
i := 0
for i = 0; x > 127; i++ {
p.tagbuf[i] = 0x80 | uint8(x&0x7F)
x >>= 7
}
p.tagbuf[i] = uint8(x)
p.tagcode = p.tagbuf[0 : i+1]
if p.stype != nil {
if lockGetProp {
p.sprop = GetProperties(p.stype)
@@ -586,32 +297,9 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
)
// isMarshaler reports whether type t implements Marshaler.
func isMarshaler(t reflect.Type) bool {
// We're checking for (likely) pointer-receiver methods
// so if t is not a pointer, something is very wrong.
// The calls above only invoke isMarshaler on pointer types.
if t.Kind() != reflect.Ptr {
panic("proto: misuse of isMarshaler")
}
return t.Implements(marshalerType)
}
// isUnmarshaler reports whether type t implements Unmarshaler.
func isUnmarshaler(t reflect.Type) bool {
// We're checking for (likely) pointer-receiver methods
// so if t is not a pointer, something is very wrong.
// The calls above only invoke isUnmarshaler on pointer types.
if t.Kind() != reflect.Ptr {
panic("proto: misuse of isUnmarshaler")
}
return t.Implements(unmarshalerType)
}
// Init populates the properties from a protocol buffer struct tag.
func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) {
p.init(typ, name, tag, f, true)
@@ -621,14 +309,11 @@ func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructF
// "bytes,49,opt,def=hello!"
p.Name = name
p.OrigName = name
if f != nil {
p.field = toField(f)
}
if tag == "" {
return
}
p.Parse(tag)
p.setEncAndDec(typ, f, lockGetProp)
p.setFieldProps(typ, f, lockGetProp)
}
var (
@@ -678,9 +363,6 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
propertiesMap[t] = prop
// build properties
prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) ||
reflect.PtrTo(t).Implements(extendableProtoV1Type)
prop.unrecField = invalidField
prop.Prop = make([]*Properties, t.NumField())
prop.order = make([]int, t.NumField())
@@ -690,17 +372,6 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
name := f.Name
p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
if f.Name == "XXX_InternalExtensions" { // special case
p.enc = (*Buffer).enc_exts
p.dec = nil // not needed
p.size = size_exts
} else if f.Name == "XXX_extensions" { // special case
p.enc = (*Buffer).enc_map
p.dec = nil // not needed
p.size = size_map
} else if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
oneof := f.Tag.Get("protobuf_oneof") // special case
if oneof != "" {
// Oneof fields don't use the traditional protobuf tag.
@@ -715,9 +386,6 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
}
print("\n")
}
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && oneof == "" {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
// Re-order prop.order.
@@ -728,8 +396,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
}
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok {
var oots []interface{}
prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofSizer, oots = om.XXX_OneofFuncs()
prop.stype = t
_, _, _, oots = om.XXX_OneofFuncs()
// Interpret oneof metadata.
prop.OneofTypes = make(map[string]*OneofProperties)
@@ -779,30 +446,6 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
return prop
}
// Return the Properties object for the x[0]'th field of the structure.
func propByIndex(t reflect.Type, x []int) *Properties {
if len(x) != 1 {
fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t)
return nil
}
prop := GetProperties(t)
return prop.Prop[x[0]]
}
// Get the address and type of a pointer to a struct from an interface.
func getbase(pb Message) (t reflect.Type, b structPointer, err error) {
if pb == nil {
err = ErrNil
return
}
// get the reflect type of the pointer to the struct.
t = reflect.TypeOf(pb)
// get the address of the struct.
value := reflect.ValueOf(pb)
b = toStructPointer(value)
return
}
// A global registry of enum types.
// The generated code will register the generated maps by calling RegisterEnum.
@@ -826,20 +469,42 @@ func EnumValueMap(enumType string) map[string]int32 {
// A registry of all linked message types.
// The string is a fully-qualified proto name ("pkg.Message").
var (
protoTypes = make(map[string]reflect.Type)
revProtoTypes = make(map[reflect.Type]string)
protoTypedNils = make(map[string]Message) // a map from proto names to typed nil pointers
protoMapTypes = make(map[string]reflect.Type) // a map from proto names to map types
revProtoTypes = make(map[reflect.Type]string)
)
// RegisterType is called from generated code and maps from the fully qualified
// proto name to the type (pointer to struct) of the protocol buffer.
func RegisterType(x Message, name string) {
if _, ok := protoTypes[name]; ok {
if _, ok := protoTypedNils[name]; ok {
// TODO: Some day, make this a panic.
log.Printf("proto: duplicate proto type registered: %s", name)
return
}
t := reflect.TypeOf(x)
protoTypes[name] = t
if v := reflect.ValueOf(x); v.Kind() == reflect.Ptr && v.Pointer() == 0 {
// Generated code always calls RegisterType with nil x.
// This check is just for extra safety.
protoTypedNils[name] = x
} else {
protoTypedNils[name] = reflect.Zero(t).Interface().(Message)
}
revProtoTypes[t] = name
}
// RegisterMapType is called from generated code and maps from the fully qualified
// proto name to the native map type of the proto map definition.
func RegisterMapType(x interface{}, name string) {
if reflect.TypeOf(x).Kind() != reflect.Map {
panic(fmt.Sprintf("RegisterMapType(%T, %q); want map", x, name))
}
if _, ok := protoMapTypes[name]; ok {
log.Printf("proto: duplicate proto type registered: %s", name)
return
}
t := reflect.TypeOf(x)
protoMapTypes[name] = t
revProtoTypes[t] = name
}
@@ -855,7 +520,14 @@ func MessageName(x Message) string {
}
// MessageType returns the message type (pointer to struct) for a named message.
func MessageType(name string) reflect.Type { return protoTypes[name] }
// The type is not guaranteed to implement proto.Message if the name refers to a
// map entry.
func MessageType(name string) reflect.Type {
if t, ok := protoTypedNils[name]; ok {
return reflect.TypeOf(t)
}
return protoMapTypes[name]
}
// A registry of all linked proto files.
var (

2681
vendor/github.com/golang/protobuf/proto/table_marshal.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

654
vendor/github.com/golang/protobuf/proto/table_merge.go generated vendored Normal file
View File

@@ -0,0 +1,654 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2016 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
)
// Merge merges the src message into dst.
// This assumes that dst and src of the same type and are non-nil.
func (a *InternalMessageInfo) Merge(dst, src Message) {
mi := atomicLoadMergeInfo(&a.merge)
if mi == nil {
mi = getMergeInfo(reflect.TypeOf(dst).Elem())
atomicStoreMergeInfo(&a.merge, mi)
}
mi.merge(toPointer(&dst), toPointer(&src))
}
type mergeInfo struct {
typ reflect.Type
initialized int32 // 0: only typ is valid, 1: everything is valid
lock sync.Mutex
fields []mergeFieldInfo
unrecognized field // Offset of XXX_unrecognized
}
type mergeFieldInfo struct {
field field // Offset of field, guaranteed to be valid
// isPointer reports whether the value in the field is a pointer.
// This is true for the following situations:
// * Pointer to struct
// * Pointer to basic type (proto2 only)
// * Slice (first value in slice header is a pointer)
// * String (first value in string header is a pointer)
isPointer bool
// basicWidth reports the width of the field assuming that it is directly
// embedded in the struct (as is the case for basic types in proto3).
// The possible values are:
// 0: invalid
// 1: bool
// 4: int32, uint32, float32
// 8: int64, uint64, float64
basicWidth int
// Where dst and src are pointers to the types being merged.
merge func(dst, src pointer)
}
var (
mergeInfoMap = map[reflect.Type]*mergeInfo{}
mergeInfoLock sync.Mutex
)
func getMergeInfo(t reflect.Type) *mergeInfo {
mergeInfoLock.Lock()
defer mergeInfoLock.Unlock()
mi := mergeInfoMap[t]
if mi == nil {
mi = &mergeInfo{typ: t}
mergeInfoMap[t] = mi
}
return mi
}
// merge merges src into dst assuming they are both of type *mi.typ.
func (mi *mergeInfo) merge(dst, src pointer) {
if dst.isNil() {
panic("proto: nil destination")
}
if src.isNil() {
return // Nothing to do.
}
if atomic.LoadInt32(&mi.initialized) == 0 {
mi.computeMergeInfo()
}
for _, fi := range mi.fields {
sfp := src.offset(fi.field)
// As an optimization, we can avoid the merge function call cost
// if we know for sure that the source will have no effect
// by checking if it is the zero value.
if unsafeAllowed {
if fi.isPointer && sfp.getPointer().isNil() { // Could be slice or string
continue
}
if fi.basicWidth > 0 {
switch {
case fi.basicWidth == 1 && !*sfp.toBool():
continue
case fi.basicWidth == 4 && *sfp.toUint32() == 0:
continue
case fi.basicWidth == 8 && *sfp.toUint64() == 0:
continue
}
}
}
dfp := dst.offset(fi.field)
fi.merge(dfp, sfp)
}
// TODO: Make this faster?
out := dst.asPointerTo(mi.typ).Elem()
in := src.asPointerTo(mi.typ).Elem()
if emIn, err := extendable(in.Addr().Interface()); err == nil {
emOut, _ := extendable(out.Addr().Interface())
mIn, muIn := emIn.extensionsRead()
if mIn != nil {
mOut := emOut.extensionsWrite()
muIn.Lock()
mergeExtension(mOut, mIn)
muIn.Unlock()
}
}
if mi.unrecognized.IsValid() {
if b := *src.offset(mi.unrecognized).toBytes(); len(b) > 0 {
*dst.offset(mi.unrecognized).toBytes() = append([]byte(nil), b...)
}
}
}
func (mi *mergeInfo) computeMergeInfo() {
mi.lock.Lock()
defer mi.lock.Unlock()
if mi.initialized != 0 {
return
}
t := mi.typ
n := t.NumField()
props := GetProperties(t)
for i := 0; i < n; i++ {
f := t.Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
mfi := mergeFieldInfo{field: toField(&f)}
tf := f.Type
// As an optimization, we can avoid the merge function call cost
// if we know for sure that the source will have no effect
// by checking if it is the zero value.
if unsafeAllowed {
switch tf.Kind() {
case reflect.Ptr, reflect.Slice, reflect.String:
// As a special case, we assume slices and strings are pointers
// since we know that the first field in the SliceSlice or
// StringHeader is a data pointer.
mfi.isPointer = true
case reflect.Bool:
mfi.basicWidth = 1
case reflect.Int32, reflect.Uint32, reflect.Float32:
mfi.basicWidth = 4
case reflect.Int64, reflect.Uint64, reflect.Float64:
mfi.basicWidth = 8
}
}
// Unwrap tf to get at its most basic type.
var isPointer, isSlice bool
if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
isSlice = true
tf = tf.Elem()
}
if tf.Kind() == reflect.Ptr {
isPointer = true
tf = tf.Elem()
}
if isPointer && isSlice && tf.Kind() != reflect.Struct {
panic("both pointer and slice for basic type in " + tf.Name())
}
switch tf.Kind() {
case reflect.Int32:
switch {
case isSlice: // E.g., []int32
mfi.merge = func(dst, src pointer) {
// NOTE: toInt32Slice is not defined (see pointer_reflect.go).
/*
sfsp := src.toInt32Slice()
if *sfsp != nil {
dfsp := dst.toInt32Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []int64{}
}
}
*/
sfs := src.getInt32Slice()
if sfs != nil {
dfs := dst.getInt32Slice()
dfs = append(dfs, sfs...)
if dfs == nil {
dfs = []int32{}
}
dst.setInt32Slice(dfs)
}
}
case isPointer: // E.g., *int32
mfi.merge = func(dst, src pointer) {
// NOTE: toInt32Ptr is not defined (see pointer_reflect.go).
/*
sfpp := src.toInt32Ptr()
if *sfpp != nil {
dfpp := dst.toInt32Ptr()
if *dfpp == nil {
*dfpp = Int32(**sfpp)
} else {
**dfpp = **sfpp
}
}
*/
sfp := src.getInt32Ptr()
if sfp != nil {
dfp := dst.getInt32Ptr()
if dfp == nil {
dst.setInt32Ptr(*sfp)
} else {
*dfp = *sfp
}
}
}
default: // E.g., int32
mfi.merge = func(dst, src pointer) {
if v := *src.toInt32(); v != 0 {
*dst.toInt32() = v
}
}
}
case reflect.Int64:
switch {
case isSlice: // E.g., []int64
mfi.merge = func(dst, src pointer) {
sfsp := src.toInt64Slice()
if *sfsp != nil {
dfsp := dst.toInt64Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []int64{}
}
}
}
case isPointer: // E.g., *int64
mfi.merge = func(dst, src pointer) {
sfpp := src.toInt64Ptr()
if *sfpp != nil {
dfpp := dst.toInt64Ptr()
if *dfpp == nil {
*dfpp = Int64(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., int64
mfi.merge = func(dst, src pointer) {
if v := *src.toInt64(); v != 0 {
*dst.toInt64() = v
}
}
}
case reflect.Uint32:
switch {
case isSlice: // E.g., []uint32
mfi.merge = func(dst, src pointer) {
sfsp := src.toUint32Slice()
if *sfsp != nil {
dfsp := dst.toUint32Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []uint32{}
}
}
}
case isPointer: // E.g., *uint32
mfi.merge = func(dst, src pointer) {
sfpp := src.toUint32Ptr()
if *sfpp != nil {
dfpp := dst.toUint32Ptr()
if *dfpp == nil {
*dfpp = Uint32(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., uint32
mfi.merge = func(dst, src pointer) {
if v := *src.toUint32(); v != 0 {
*dst.toUint32() = v
}
}
}
case reflect.Uint64:
switch {
case isSlice: // E.g., []uint64
mfi.merge = func(dst, src pointer) {
sfsp := src.toUint64Slice()
if *sfsp != nil {
dfsp := dst.toUint64Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []uint64{}
}
}
}
case isPointer: // E.g., *uint64
mfi.merge = func(dst, src pointer) {
sfpp := src.toUint64Ptr()
if *sfpp != nil {
dfpp := dst.toUint64Ptr()
if *dfpp == nil {
*dfpp = Uint64(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., uint64
mfi.merge = func(dst, src pointer) {
if v := *src.toUint64(); v != 0 {
*dst.toUint64() = v
}
}
}
case reflect.Float32:
switch {
case isSlice: // E.g., []float32
mfi.merge = func(dst, src pointer) {
sfsp := src.toFloat32Slice()
if *sfsp != nil {
dfsp := dst.toFloat32Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []float32{}
}
}
}
case isPointer: // E.g., *float32
mfi.merge = func(dst, src pointer) {
sfpp := src.toFloat32Ptr()
if *sfpp != nil {
dfpp := dst.toFloat32Ptr()
if *dfpp == nil {
*dfpp = Float32(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., float32
mfi.merge = func(dst, src pointer) {
if v := *src.toFloat32(); v != 0 {
*dst.toFloat32() = v
}
}
}
case reflect.Float64:
switch {
case isSlice: // E.g., []float64
mfi.merge = func(dst, src pointer) {
sfsp := src.toFloat64Slice()
if *sfsp != nil {
dfsp := dst.toFloat64Slice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []float64{}
}
}
}
case isPointer: // E.g., *float64
mfi.merge = func(dst, src pointer) {
sfpp := src.toFloat64Ptr()
if *sfpp != nil {
dfpp := dst.toFloat64Ptr()
if *dfpp == nil {
*dfpp = Float64(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., float64
mfi.merge = func(dst, src pointer) {
if v := *src.toFloat64(); v != 0 {
*dst.toFloat64() = v
}
}
}
case reflect.Bool:
switch {
case isSlice: // E.g., []bool
mfi.merge = func(dst, src pointer) {
sfsp := src.toBoolSlice()
if *sfsp != nil {
dfsp := dst.toBoolSlice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []bool{}
}
}
}
case isPointer: // E.g., *bool
mfi.merge = func(dst, src pointer) {
sfpp := src.toBoolPtr()
if *sfpp != nil {
dfpp := dst.toBoolPtr()
if *dfpp == nil {
*dfpp = Bool(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., bool
mfi.merge = func(dst, src pointer) {
if v := *src.toBool(); v {
*dst.toBool() = v
}
}
}
case reflect.String:
switch {
case isSlice: // E.g., []string
mfi.merge = func(dst, src pointer) {
sfsp := src.toStringSlice()
if *sfsp != nil {
dfsp := dst.toStringSlice()
*dfsp = append(*dfsp, *sfsp...)
if *dfsp == nil {
*dfsp = []string{}
}
}
}
case isPointer: // E.g., *string
mfi.merge = func(dst, src pointer) {
sfpp := src.toStringPtr()
if *sfpp != nil {
dfpp := dst.toStringPtr()
if *dfpp == nil {
*dfpp = String(**sfpp)
} else {
**dfpp = **sfpp
}
}
}
default: // E.g., string
mfi.merge = func(dst, src pointer) {
if v := *src.toString(); v != "" {
*dst.toString() = v
}
}
}
case reflect.Slice:
isProto3 := props.Prop[i].proto3
switch {
case isPointer:
panic("bad pointer in byte slice case in " + tf.Name())
case tf.Elem().Kind() != reflect.Uint8:
panic("bad element kind in byte slice case in " + tf.Name())
case isSlice: // E.g., [][]byte
mfi.merge = func(dst, src pointer) {
sbsp := src.toBytesSlice()
if *sbsp != nil {
dbsp := dst.toBytesSlice()
for _, sb := range *sbsp {
if sb == nil {
*dbsp = append(*dbsp, nil)
} else {
*dbsp = append(*dbsp, append([]byte{}, sb...))
}
}
if *dbsp == nil {
*dbsp = [][]byte{}
}
}
}
default: // E.g., []byte
mfi.merge = func(dst, src pointer) {
sbp := src.toBytes()
if *sbp != nil {
dbp := dst.toBytes()
if !isProto3 || len(*sbp) > 0 {
*dbp = append([]byte{}, *sbp...)
}
}
}
}
case reflect.Struct:
switch {
case !isPointer:
panic(fmt.Sprintf("message field %s without pointer", tf))
case isSlice: // E.g., []*pb.T
mi := getMergeInfo(tf)
mfi.merge = func(dst, src pointer) {
sps := src.getPointerSlice()
if sps != nil {
dps := dst.getPointerSlice()
for _, sp := range sps {
var dp pointer
if !sp.isNil() {
dp = valToPointer(reflect.New(tf))
mi.merge(dp, sp)
}
dps = append(dps, dp)
}
if dps == nil {
dps = []pointer{}
}
dst.setPointerSlice(dps)
}
}
default: // E.g., *pb.T
mi := getMergeInfo(tf)
mfi.merge = func(dst, src pointer) {
sp := src.getPointer()
if !sp.isNil() {
dp := dst.getPointer()
if dp.isNil() {
dp = valToPointer(reflect.New(tf))
dst.setPointer(dp)
}
mi.merge(dp, sp)
}
}
}
case reflect.Map:
switch {
case isPointer || isSlice:
panic("bad pointer or slice in map case in " + tf.Name())
default: // E.g., map[K]V
mfi.merge = func(dst, src pointer) {
sm := src.asPointerTo(tf).Elem()
if sm.Len() == 0 {
return
}
dm := dst.asPointerTo(tf).Elem()
if dm.IsNil() {
dm.Set(reflect.MakeMap(tf))
}
switch tf.Elem().Kind() {
case reflect.Ptr: // Proto struct (e.g., *T)
for _, key := range sm.MapKeys() {
val := sm.MapIndex(key)
val = reflect.ValueOf(Clone(val.Interface().(Message)))
dm.SetMapIndex(key, val)
}
case reflect.Slice: // E.g. Bytes type (e.g., []byte)
for _, key := range sm.MapKeys() {
val := sm.MapIndex(key)
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
dm.SetMapIndex(key, val)
}
default: // Basic type (e.g., string)
for _, key := range sm.MapKeys() {
val := sm.MapIndex(key)
dm.SetMapIndex(key, val)
}
}
}
}
case reflect.Interface:
// Must be oneof field.
switch {
case isPointer || isSlice:
panic("bad pointer or slice in interface case in " + tf.Name())
default: // E.g., interface{}
// TODO: Make this faster?
mfi.merge = func(dst, src pointer) {
su := src.asPointerTo(tf).Elem()
if !su.IsNil() {
du := dst.asPointerTo(tf).Elem()
typ := su.Elem().Type()
if du.IsNil() || du.Elem().Type() != typ {
du.Set(reflect.New(typ.Elem())) // Initialize interface if empty
}
sv := su.Elem().Elem().Field(0)
if sv.Kind() == reflect.Ptr && sv.IsNil() {
return
}
dv := du.Elem().Elem().Field(0)
if dv.Kind() == reflect.Ptr && dv.IsNil() {
dv.Set(reflect.New(sv.Type().Elem())) // Initialize proto message if empty
}
switch sv.Type().Kind() {
case reflect.Ptr: // Proto struct (e.g., *T)
Merge(dv.Interface().(Message), sv.Interface().(Message))
case reflect.Slice: // E.g. Bytes type (e.g., []byte)
dv.Set(reflect.ValueOf(append([]byte{}, sv.Bytes()...)))
default: // Basic type (e.g., string)
dv.Set(sv)
}
}
}
}
default:
panic(fmt.Sprintf("merger not found for type:%s", tf))
}
mi.fields = append(mi.fields, mfi)
}
mi.unrecognized = invalidField
if f, ok := t.FieldByName("XXX_unrecognized"); ok {
if f.Type != reflect.TypeOf([]byte{}) {
panic("expected XXX_unrecognized to be of type []byte")
}
mi.unrecognized = toField(&f)
}
atomic.StoreInt32(&mi.initialized, 1)
}

File diff suppressed because it is too large Load Diff

View File

@@ -50,7 +50,6 @@ import (
var (
newline = []byte("\n")
spaces = []byte(" ")
gtNewline = []byte(">\n")
endBraceNewline = []byte("}\n")
backslashN = []byte{'\\', 'n'}
backslashR = []byte{'\\', 'r'}
@@ -170,11 +169,6 @@ func writeName(w *textWriter, props *Properties) error {
return nil
}
// raw is the interface satisfied by RawMessage.
type raw interface {
Bytes() []byte
}
func requiresQuotes(u string) bool {
// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
for _, ch := range u {
@@ -269,6 +263,10 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
props := sprops.Prop[i]
name := st.Field(i).Name
if name == "XXX_NoUnkeyedLiteral" {
continue
}
if strings.HasPrefix(name, "XXX_") {
// There are two XXX_ fields:
// XXX_unrecognized []byte
@@ -436,12 +434,6 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
return err
}
}
if b, ok := fv.Interface().(raw); ok {
if err := writeRaw(w, b.Bytes()); err != nil {
return err
}
continue
}
// Enums have a String method, so writeAny will work fine.
if err := tm.writeAny(w, fv, props); err != nil {
@@ -455,7 +447,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
// Extensions (the XXX_extensions field).
pv := sv.Addr()
if _, ok := extendable(pv.Interface()); ok {
if _, err := extendable(pv.Interface()); err == nil {
if err := tm.writeExtensions(w, pv); err != nil {
return err
}
@@ -464,27 +456,6 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
return nil
}
// writeRaw writes an uninterpreted raw message.
func writeRaw(w *textWriter, b []byte) error {
if err := w.WriteByte('<'); err != nil {
return err
}
if !w.compact {
if err := w.WriteByte('\n'); err != nil {
return err
}
}
w.indent()
if err := writeUnknownStruct(w, b); err != nil {
return err
}
w.unindent()
if err := w.WriteByte('>'); err != nil {
return err
}
return nil
}
// writeAny writes an arbitrary field.
func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v)
@@ -535,6 +506,19 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert
}
}
w.indent()
if v.CanAddr() {
// Calling v.Interface on a struct causes the reflect package to
// copy the entire struct. This is racy with the new Marshaler
// since we atomically update the XXX_sizecache.
//
// Thus, we retrieve a pointer to the struct if possible to avoid
// a race since v.Interface on the pointer doesn't copy the struct.
//
// If v is not addressable, then we are not worried about a race
// since it implies that the binary Marshaler cannot possibly be
// mutating this value.
v = v.Addr()
}
if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := etm.MarshalText()
if err != nil {
@@ -543,8 +527,13 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert
if _, err = w.Write(text); err != nil {
return err
}
} else if err := tm.writeStruct(w, v); err != nil {
return err
} else {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if err := tm.writeStruct(w, v); err != nil {
return err
}
}
w.unindent()
if err := w.WriteByte(ket); err != nil {

View File

@@ -206,7 +206,6 @@ func (p *textParser) advance() {
var (
errBadUTF8 = errors.New("proto: bad UTF-8")
errBadHex = errors.New("proto: bad hexadecimal")
)
func unquoteC(s string, quote rune) (string, error) {
@@ -277,60 +276,47 @@ func unescape(s string) (ch string, tail string, err error) {
return "?", s, nil // trigraph workaround
case '\'', '"', '\\':
return string(r), s, nil
case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X':
case '0', '1', '2', '3', '4', '5', '6', '7':
if len(s) < 2 {
return "", "", fmt.Errorf(`\%c requires 2 following digits`, r)
}
base := 8
ss := s[:2]
ss := string(r) + s[:2]
s = s[2:]
if r == 'x' || r == 'X' {
base = 16
} else {
ss = string(r) + ss
}
i, err := strconv.ParseUint(ss, base, 8)
i, err := strconv.ParseUint(ss, 8, 8)
if err != nil {
return "", "", err
return "", "", fmt.Errorf(`\%s contains non-octal digits`, ss)
}
return string([]byte{byte(i)}), s, nil
case 'u', 'U':
n := 4
if r == 'U' {
case 'x', 'X', 'u', 'U':
var n int
switch r {
case 'x', 'X':
n = 2
case 'u':
n = 4
case 'U':
n = 8
}
if len(s) < n {
return "", "", fmt.Errorf(`\%c requires %d digits`, r, n)
}
bs := make([]byte, n/2)
for i := 0; i < n; i += 2 {
a, ok1 := unhex(s[i])
b, ok2 := unhex(s[i+1])
if !ok1 || !ok2 {
return "", "", errBadHex
}
bs[i/2] = a<<4 | b
return "", "", fmt.Errorf(`\%c requires %d following digits`, r, n)
}
ss := s[:n]
s = s[n:]
return string(bs), s, nil
i, err := strconv.ParseUint(ss, 16, 64)
if err != nil {
return "", "", fmt.Errorf(`\%c%s contains non-hexadecimal digits`, r, ss)
}
if r == 'x' || r == 'X' {
return string([]byte{byte(i)}), s, nil
}
if i > utf8.MaxRune {
return "", "", fmt.Errorf(`\%c%s is not a valid Unicode code point`, r, ss)
}
return string(i), s, nil
}
return "", "", fmt.Errorf(`unknown escape \%c`, r)
}
// Adapted from src/pkg/strconv/quote.go.
func unhex(b byte) (v byte, ok bool) {
switch {
case '0' <= b && b <= '9':
return b - '0', true
case 'a' <= b && b <= 'f':
return b - 'a' + 10, true
case 'A' <= b && b <= 'F':
return b - 'A' + 10, true
}
return 0, false
}
// Back off the parser by one token. Can only be done between calls to next().
// It makes the next advance() a no-op.
func (p *textParser) back() { p.backed = true }
@@ -728,6 +714,9 @@ func (p *textParser) consumeExtName() (string, error) {
if tok.err != nil {
return "", p.errorf("unrecognized type_url or extension name: %s", tok.err)
}
if p.done && tok.value != "]" {
return "", p.errorf("unclosed type_url or extension name")
}
}
return strings.Join(parts, ""), nil
}
@@ -865,7 +854,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
return p.readStruct(fv, terminator)
case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
fv.SetUint(x)
fv.SetUint(uint64(x))
return nil
}
case reflect.Uint64:
@@ -883,13 +872,9 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
// UnmarshalText returns *RequiredNotSetError.
func UnmarshalText(s string, pb Message) error {
if um, ok := pb.(encoding.TextUnmarshaler); ok {
err := um.UnmarshalText([]byte(s))
return err
return um.UnmarshalText([]byte(s))
}
pb.Reset()
v := reflect.ValueOf(pb)
if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil {
return pe
}
return nil
return newTextParser(s).readStruct(v.Elem(), "")
}

View File

@@ -7,13 +7,13 @@ matrix:
- go: 1.4
- go: 1.5
- go: 1.6
- go: 1.7
- go: tip
allow_failures:
- go: tip
install:
- go get golang.org/x/tools/cmd/vet
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go tool vet .
- go vet $(go list ./... | grep -v /vendor/)
- go test -v -race ./...

View File

@@ -4,4 +4,7 @@ context
gorilla/context is a general purpose registry for global request variables.
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context

View File

@@ -5,6 +5,12 @@
/*
Package context stores values shared during a request lifetime.
Note: gorilla/context, having been born well before `context.Context` existed,
does not play well > with the shallow copying of the request that
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext)
(added to net/http Go 1.7 onwards) performs. You should either use *just*
gorilla/context, or moving forward, the new `http.Request.Context()`.
For example, a router can set variables extracted from the URL and later
application handlers can access those values, or it can be used to store
sessions values to be saved at the end of a request. There are several

View File

@@ -3,11 +3,12 @@ sudo: false
matrix:
include:
- go: 1.5
- go: 1.6
- go: 1.7
- go: 1.8
- go: 1.9
- go: 1.5.x
- go: 1.6.x
- go: 1.7.x
- go: 1.8.x
- go: 1.9.x
- go: 1.10.x
- go: tip
allow_failures:
- go: tip

View File

@@ -1,5 +1,5 @@
gorilla/mux
===
# gorilla/mux
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
@@ -29,6 +29,7 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
* [Walking Routes](#walking-routes)
* [Graceful Shutdown](#graceful-shutdown)
* [Middleware](#middleware)
* [Testing Handlers](#testing-handlers)
* [Full Example](#full-example)
---
@@ -178,70 +179,13 @@ s.HandleFunc("/{key}/", ProductHandler)
// "/products/{key}/details"
s.HandleFunc("/{key}/details", ProductDetailsHandler)
```
### Listing Routes
Routes on a mux can be listed using the Router.Walk method—useful for generating documentation:
```go
package main
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
)
func handler(w http.ResponseWriter, r *http.Request) {
return
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
t, err := route.GetPathTemplate()
if err != nil {
return err
}
qt, err := route.GetQueriesTemplates()
if err != nil {
return err
}
// p will contain regular expression is compatible with regular expression in Perl, Python, and other languages.
// for instance the regular expression for path '/articles/{id}' will be '^/articles/(?P<v0>[^/]+)$'
p, err := route.GetPathRegexp()
if err != nil {
return err
}
// qr will contain a list of regular expressions with the same semantics as GetPathRegexp,
// just applied to the Queries pairs instead, e.g., 'Queries("surname", "{surname}") will return
// {"^surname=(?P<v0>.*)$}. Where each combined query pair will have an entry in the list.
qr, err := route.GetQueriesRegexp()
if err != nil {
return err
}
m, err := route.GetMethods()
if err != nil {
return err
}
fmt.Println(strings.Join(m, ","), strings.Join(qt, ","), strings.Join(qr, ","), t, p)
return nil
})
http.Handle("/", r)
}
```
### Static Files
Note that the path provided to `PathPrefix()` represents a "wildcard": calling
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any
request that matches "/static/*". This makes it easy to serve static files with mux:
request that matches "/static/\*". This makes it easy to serve static files with mux:
```go
func main() {
@@ -348,41 +292,58 @@ The `Walk` function on `mux.Router` can be used to visit all of the routes that
the following prints all of the registered routes:
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
t, err := route.GetPathTemplate()
if err != nil {
return err
}
qt, err := route.GetQueriesTemplates()
if err != nil {
return err
}
// p will contain a regular expression that is compatible with regular expressions in Perl, Python, and other languages.
// For example, the regular expression for path '/articles/{id}' will be '^/articles/(?P<v0>[^/]+)$'.
p, err := route.GetPathRegexp()
if err != nil {
return err
}
// qr will contain a list of regular expressions with the same semantics as GetPathRegexp,
// just applied to the Queries pairs instead, e.g., 'Queries("surname", "{surname}") will return
// {"^surname=(?P<v0>.*)$}. Where each combined query pair will have an entry in the list.
qr, err := route.GetQueriesRegexp()
if err != nil {
return err
}
m, err := route.GetMethods()
if err != nil {
return err
}
fmt.Println(strings.Join(m, ","), strings.Join(qt, ","), strings.Join(qr, ","), t, p)
return nil
})
package main
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
)
func handler(w http.ResponseWriter, r *http.Request) {
return
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.HandleFunc("/products", handler).Methods("POST")
r.HandleFunc("/articles", handler).Methods("GET")
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
err := r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
pathTemplate, err := route.GetPathTemplate()
if err == nil {
fmt.Println("ROUTE:", pathTemplate)
}
pathRegexp, err := route.GetPathRegexp()
if err == nil {
fmt.Println("Path regexp:", pathRegexp)
}
queriesTemplates, err := route.GetQueriesTemplates()
if err == nil {
fmt.Println("Queries templates:", strings.Join(queriesTemplates, ","))
}
queriesRegexps, err := route.GetQueriesRegexp()
if err == nil {
fmt.Println("Queries regexps:", strings.Join(queriesRegexps, ","))
}
methods, err := route.GetMethods()
if err == nil {
fmt.Println("Methods:", strings.Join(methods, ","))
}
fmt.Println()
return nil
})
if err != nil {
fmt.Println(err)
}
http.Handle("/", r)
}
```
### Graceful Shutdown
@@ -399,6 +360,7 @@ import (
"net/http"
"os"
"os/signal"
"time"
"github.com/gorilla/mux"
)
@@ -410,7 +372,7 @@ func main() {
r := mux.NewRouter()
// Add your routes as needed
srv := &http.Server{
Addr: "0.0.0.0:8080",
// Good practice to set timeouts to avoid Slowloris attacks.
@@ -426,7 +388,7 @@ func main() {
log.Println(err)
}
}()
c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
@@ -436,7 +398,8 @@ func main() {
<-c
// Create a deadline to wait for.
ctx, cancel := context.WithTimeout(ctx, wait)
ctx, cancel := context.WithTimeout(context.Background(), wait)
defer cancel()
// Doesn't block if no connections, but will otherwise wait
// until the timeout deadline.
srv.Shutdown(ctx)
@@ -464,7 +427,7 @@ Typically, the returned handler is a closure which does something with the http.
A very basic middleware which logs the URI of the request being handled could be written as:
```go
func simpleMw(next http.Handler) http.Handler {
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do stuff here
log.Println(r.RequestURI)
@@ -474,12 +437,12 @@ func simpleMw(next http.Handler) http.Handler {
}
```
Middlewares can be added to a router using `Router.AddMiddlewareFunc()`:
Middlewares can be added to a router using `Router.Use()`:
```go
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.AddMiddleware(simpleMw)
r.Use(loggingMiddleware)
```
A more complex authentication middleware, which maps session token to users, could be written as:
@@ -502,7 +465,7 @@ func (amw *authenticationMiddleware) Populate() {
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Session-Token")
if user, found := amw.tokenUsers[token]; found {
// We found the token in our map
log.Printf("Authenticated user %s\n", user)
@@ -510,7 +473,7 @@ func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler
next.ServeHTTP(w, r)
} else {
// Write an error and stop the handler chain
http.Error(w, "Forbidden", 403)
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}
@@ -523,10 +486,136 @@ r.HandleFunc("/", handler)
amw := authenticationMiddleware{}
amw.Populate()
r.AddMiddlewareFunc(amw.Middleware)
r.Use(amw.Middleware)
```
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares *should* write to `ResponseWriter` if they *are* going to terminate the request, and they *should not* write to `ResponseWriter` if they *are not* going to terminate it.
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
### Testing Handlers
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
First, our simple HTTP handler:
```go
// endpoints.go
package main
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// A very simple health check.
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
// In the future we could report back on the status of our DB, or our cache
// (e.g. Redis) by performing a simple PING, and include them in the response.
io.WriteString(w, `{"alive": true}`)
}
func main() {
r := mux.NewRouter()
r.HandleFunc("/health", HealthCheckHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test code:
```go
// endpoints_test.go
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestHealthCheckHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/health", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(HealthCheckHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
// Check the response body is what we expect.
expected := `{"alive": true}`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}
```
In the case that our routes have [variables](#examples), we can pass those in the request. We could write
[table-driven tests](https://dave.cheney.net/2013/06/09/writing-table-driven-tests-in-go) to test multiple
possible route variables as needed.
```go
// endpoints.go
func main() {
r := mux.NewRouter()
// A route with a route variable:
r.HandleFunc("/metrics/{type}", MetricsHandler)
log.Fatal(http.ListenAndServe("localhost:8080", r))
}
```
Our test file, with a table-driven test of `routeVariables`:
```go
// endpoints_test.go
func TestMetricsHandler(t *testing.T) {
tt := []struct{
routeVariable string
shouldPass bool
}{
{"goroutines", true},
{"heap", true},
{"counters", true},
{"queries", true},
{"adhadaeqm3k", false},
}
for _, tc := range tt {
path := fmt.Sprintf("/metrics/%s", tc.routeVariable)
req, err := http.NewRequest("GET", path, nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
// Need to create a router that we can pass the request through so that the vars will be added to the context
router := mux.NewRouter()
router.HandleFunc("/metrics/{type}", MetricsHandler)
router.ServeHTTP(rr, req)
// In this case, our MetricsHandler returns a non-200 response
// for a route variable it doesn't know about.
if rr.Code == http.StatusOK && !tc.shouldPass {
t.Errorf("handler should have failed on routeVariable %s: got %v want %v",
tc.routeVariable, rr.Code, http.StatusOK)
}
}
}
```
## Full Example

View File

@@ -239,8 +239,7 @@ as well:
"category", "technology",
"id", "42")
Since **vX.Y.Z**, mux supports the addition of middlewares to a [Router](https://godoc.org/github.com/gorilla/mux#Router), which are executed if a
match is found (including subrouters). Middlewares are defined using the de facto standard type:
Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking.
type MiddlewareFunc func(http.Handler) http.Handler
@@ -261,7 +260,7 @@ Middlewares can be added to a router using `Router.Use()`:
r := mux.NewRouter()
r.HandleFunc("/", handler)
r.AddMiddleware(simpleMw)
r.Use(simpleMw)
A more complex authentication middleware, which maps session token to users, could be written as:
@@ -288,7 +287,7 @@ A more complex authentication middleware, which maps session token to users, cou
log.Printf("Authenticated user %s\n", user)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Forbidden", 403)
http.Error(w, "Forbidden", http.StatusForbidden)
}
})
}

View File

@@ -1,6 +1,9 @@
package mux
import "net/http"
import (
"net/http"
"strings"
)
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
@@ -12,17 +15,58 @@ type middleware interface {
Middleware(handler http.Handler) http.Handler
}
// MiddlewareFunc also implements the middleware interface.
// Middleware allows MiddlewareFunc to implement the middleware interface.
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
return mw(handler)
}
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) Use(mwf MiddlewareFunc) {
r.middlewares = append(r.middlewares, mwf)
func (r *Router) Use(mwf ...MiddlewareFunc) {
for _, fn := range mwf {
r.middlewares = append(r.middlewares, fn)
}
}
// useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
// returning without calling the next http handler.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
if req.Method == "OPTIONS" {
return
}
}
next.ServeHTTP(w, req)
})
}
}

View File

@@ -13,8 +13,11 @@ import (
)
var (
// ErrMethodMismatch is returned when the method in the request does not match
// the method defined against the route.
ErrMethodMismatch = errors.New("method is not allowed")
ErrNotFound = errors.New("no matching route was found")
// ErrNotFound is returned when no route match is found.
ErrNotFound = errors.New("no matching route was found")
)
// NewRouter returns a new router instance.
@@ -95,9 +98,9 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
if r.MethodNotAllowedHandler != nil {
match.Handler = r.MethodNotAllowedHandler
return true
} else {
return false
}
return false
}
// Closest match for a router (includes sub-routers)

View File

@@ -43,6 +43,8 @@ type Route struct {
buildVarsFunc BuildVarsFunc
}
// SkipClean reports whether path cleaning is enabled for this route via
// Router.SkipClean.
func (r *Route) SkipClean() bool {
return r.skipClean
}
@@ -622,7 +624,7 @@ func (r *Route) GetPathRegexp() (string, error) {
// route queries.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not have queries.
// An error will be returned if the route does not have queries.
func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil {
return nil, r.err
@@ -641,7 +643,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
// query matching.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not define queries.
// An error will be returned if the route does not define queries.
func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil {
return nil, r.err
@@ -659,7 +661,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
// GetMethods returns the methods the route matches against
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if route does not have methods.
// An error will be returned if route does not have methods.
func (r *Route) GetMethods() ([]string, error) {
if r.err != nil {
return nil, r.err
@@ -669,7 +671,7 @@ func (r *Route) GetMethods() ([]string, error) {
return []string(methods), nil
}
}
return nil, nil
return nil, errors.New("mux: route doesn't have methods")
}
// GetHostTemplate returns the template used to build the

View File

@@ -7,7 +7,8 @@ package mux
import "net/http"
// SetURLVars sets the URL variables for the given request, to be accessed via
// mux.Vars for testing route behaviour.
// mux.Vars for testing route behaviour. Arguments are not modified, a shallow
// copy is returned.
//
// This API should only be used for testing purposes; it provides a way to
// inject variables into the request context. Alternatively, URL variables

View File

@@ -201,7 +201,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
}
var deadline time.Time

View File

@@ -14,7 +14,7 @@ import (
"strings"
)
type netDialerFunc func(netowrk, addr string) (net.Conn, error)
type netDialerFunc func(network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)

View File

@@ -103,7 +103,7 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
// application negotiated subprotocol (Sec-WebSocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
@@ -127,7 +127,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
checkOrigin := u.CheckOrigin
@@ -140,7 +140,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
@@ -190,12 +190,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {

View File

@@ -33,7 +33,7 @@ The above seems very easy to read, unless your duration looks like this:
package main
import (
"fmt"
"fmt"
"github.com/hako/durafmt"
)
@@ -47,6 +47,28 @@ func main() {
}
```
### durafmt.ParseStringShort()
Version of `durafmt.ParseString()` that only returns the first part of the duration string.
```go
package main
import (
"fmt"
"github.com/hako/durafmt"
)
func main() {
duration, err := durafmt.ParseStringShort("354h22m3.24s")
if err != nil {
fmt.Println(err)
}
fmt.Println(duration) // 2 weeks
// duration.String() // String short representation. "2 weeks"
}
```
### durafmt.Parse()
```go
@@ -65,6 +87,26 @@ func main() {
}
```
### durafmt.ParseShort()
Version of `durafmt.Parse()` that only returns the first part of the duration string.
```go
package main
import (
"fmt"
"time"
"github.com/hako/durafmt"
)
func main() {
timeduration := (354 * time.Hour) + (22 * time.Minute) + (3 * time.Second)
duration := durafmt.ParseShort(timeduration).String()
fmt.Println(duration) // 2 weeks
}
```
# Contributing
Contributions are welcome! Fork this repo and add your changes and submit a PR.

View File

@@ -16,12 +16,19 @@ var (
type Durafmt struct {
duration time.Duration
input string // Used as reference.
short bool
}
// Parse creates a new *Durafmt struct, returns error if input is invalid.
func Parse(dinput time.Duration) *Durafmt {
input := dinput.String()
return &Durafmt{dinput, input}
return &Durafmt{dinput, input, false}
}
// ParseShort creates a new *Durafmt struct, short form, returns error if input is invalid.
func ParseShort(dinput time.Duration) *Durafmt {
input := dinput.String()
return &Durafmt{dinput, input, true}
}
// ParseString creates a new *Durafmt struct from a string.
@@ -34,7 +41,20 @@ func ParseString(input string) (*Durafmt, error) {
if err != nil {
return nil, err
}
return &Durafmt{duration, input}, nil
return &Durafmt{duration, input, false}, nil
}
// ParseStringShort creates a new *Durafmt struct from a string, short form
// returns an error if input is invalid.
func ParseStringShort(input string) (*Durafmt, error) {
if input == "0" || input == "-0" {
return nil, errors.New("durafmt: missing unit in duration " + input)
}
duration, err := time.ParseDuration(input)
if err != nil {
return nil, err
}
return &Durafmt{duration, input, true}, nil
}
// String parses d *Durafmt into a human readable duration.
@@ -116,5 +136,12 @@ func (d *Durafmt) String() string {
}
// trim any remaining spaces.
duration = strings.TrimSpace(duration)
// if more than 2 spaces present return the first 2 strings
// if short version is requested
if d.short {
duration = strings.Join(strings.Split(duration, " ")[:2], " ")
}
return duration
}

View File

@@ -6,4 +6,5 @@ go:
- 1.7.x
- 1.8.x
- 1.9.x
- "1.10.x"
- tip

View File

@@ -1,5 +1,13 @@
## Changelog
### [1.8](https://github.com/magiconair/properties/tree/v1.8) - 15 May 2018
* [PR #26](https://github.com/magiconair/properties/pull/26): Disable expansion during loading
This adds the option to disable property expansion during loading.
Thanks to [@kmala](https://github.com/kmala) for the patch.
### [1.7.6](https://github.com/magiconair/properties/tree/v1.7.6) - 14 Feb 2018
* [PR #29](https://github.com/magiconair/properties/pull/29): Reworked expansion logic to handle more complex cases.

View File

@@ -1,5 +1,6 @@
[![](https://img.shields.io/github/tag/magiconair/properties.svg?style=flat-square&label=release)](https://github.com/magiconair/properties/releases)
[![Build Status](https://img.shields.io/travis/magiconair/properties.svg?branch=master&style=flat-square)](https://travis-ci.org/magiconair/properties)
[![Travis CI Status](https://img.shields.io/travis/magiconair/properties.svg?branch=master&style=flat-square&label=travis)](https://travis-ci.org/magiconair/properties)
[![Codeship CI Status](https://img.shields.io/codeship/16aaf660-f615-0135-b8f0-7e33b70920c0/master.svg?label=codeship&style=flat-square)](https://app.codeship.com/projects/274177")
[![License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg?style=flat-square)](https://raw.githubusercontent.com/magiconair/properties/master/LICENSE)
[![GoDoc](http://img.shields.io/badge/godoc-reference-5272B4.svg?style=flat-square)](http://godoc.org/github.com/magiconair/properties)

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -73,7 +73,7 @@
// # refers to the users' home dir
// home = ${HOME}
//
// # local key takes precendence over env var: u = foo
// # local key takes precedence over env var: u = foo
// USER = foo
// u = ${USER}
//
@@ -102,7 +102,7 @@
// v = p.GetString("key", "def")
// v = p.GetDuration("key", 999)
//
// As an alterantive properties may be applied with the standard
// As an alternative properties may be applied with the standard
// library's flag implementation at any time.
//
// # Standard configuration

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -16,21 +16,157 @@ import (
type Encoding uint
const (
// utf8Default is a private placeholder for the zero value of Encoding to
// ensure that it has the correct meaning. UTF8 is the default encoding but
// was assigned a non-zero value which cannot be changed without breaking
// existing code. Clients should continue to use the public constants.
utf8Default Encoding = iota
// UTF8 interprets the input data as UTF-8.
UTF8 Encoding = 1 << iota
UTF8
// ISO_8859_1 interprets the input data as ISO-8859-1.
ISO_8859_1
)
type Loader struct {
// Encoding determines how the data from files and byte buffers
// is interpreted. For URLs the Content-Type header is used
// to determine the encoding of the data.
Encoding Encoding
// DisableExpansion configures the property expansion of the
// returned property object. When set to true, the property values
// will not be expanded and the Property object will not be checked
// for invalid expansion expressions.
DisableExpansion bool
// IgnoreMissing configures whether missing files or URLs which return
// 404 are reported as errors. When set to true, missing files and 404
// status codes are not reported as errors.
IgnoreMissing bool
}
// Load reads a buffer into a Properties struct.
func (l *Loader) LoadBytes(buf []byte) (*Properties, error) {
return l.loadBytes(buf, l.Encoding)
}
// LoadAll reads the content of multiple URLs or files in the given order into
// a Properties struct. If IgnoreMissing is true then a 404 status code or
// missing file will not be reported as error. Encoding sets the encoding for
// files. For the URLs see LoadURL for the Content-Type header and the
// encoding.
func (l *Loader) LoadAll(names []string) (*Properties, error) {
all := NewProperties()
for _, name := range names {
n, err := expandName(name)
if err != nil {
return nil, err
}
var p *Properties
switch {
case strings.HasPrefix(n, "http://"):
p, err = l.LoadURL(n)
case strings.HasPrefix(n, "https://"):
p, err = l.LoadURL(n)
default:
p, err = l.LoadFile(n)
}
if err != nil {
return nil, err
}
all.Merge(p)
}
all.DisableExpansion = l.DisableExpansion
if all.DisableExpansion {
return all, nil
}
return all, all.check()
}
// LoadFile reads a file into a Properties struct.
// If IgnoreMissing is true then a missing file will not be
// reported as error.
func (l *Loader) LoadFile(filename string) (*Properties, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
if l.IgnoreMissing && os.IsNotExist(err) {
LogPrintf("properties: %s not found. skipping", filename)
return NewProperties(), nil
}
return nil, err
}
return l.loadBytes(data, l.Encoding)
}
// LoadURL reads the content of the URL into a Properties struct.
//
// The encoding is determined via the Content-Type header which
// should be set to 'text/plain'. If the 'charset' parameter is
// missing, 'iso-8859-1' or 'latin1' the encoding is set to
// ISO-8859-1. If the 'charset' parameter is set to 'utf-8' the
// encoding is set to UTF-8. A missing content type header is
// interpreted as 'text/plain; charset=utf-8'.
func (l *Loader) LoadURL(url string) (*Properties, error) {
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("properties: error fetching %q. %s", url, err)
}
if resp.StatusCode == 404 && l.IgnoreMissing {
LogPrintf("properties: %s returned %d. skipping", url, resp.StatusCode)
return NewProperties(), nil
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("properties: %s returned %d", url, resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("properties: %s error reading response. %s", url, err)
}
defer resp.Body.Close()
ct := resp.Header.Get("Content-Type")
var enc Encoding
switch strings.ToLower(ct) {
case "text/plain", "text/plain; charset=iso-8859-1", "text/plain; charset=latin1":
enc = ISO_8859_1
case "", "text/plain; charset=utf-8":
enc = UTF8
default:
return nil, fmt.Errorf("properties: invalid content type %s", ct)
}
return l.loadBytes(body, enc)
}
func (l *Loader) loadBytes(buf []byte, enc Encoding) (*Properties, error) {
p, err := parse(convert(buf, enc))
if err != nil {
return nil, err
}
p.DisableExpansion = l.DisableExpansion
if p.DisableExpansion {
return p, nil
}
return p, p.check()
}
// Load reads a buffer into a Properties struct.
func Load(buf []byte, enc Encoding) (*Properties, error) {
return loadBuf(buf, enc)
l := &Loader{Encoding: enc}
return l.LoadBytes(buf)
}
// LoadString reads an UTF8 string into a properties struct.
func LoadString(s string) (*Properties, error) {
return loadBuf([]byte(s), UTF8)
l := &Loader{Encoding: UTF8}
return l.LoadBytes([]byte(s))
}
// LoadMap creates a new Properties struct from a string map.
@@ -44,34 +180,32 @@ func LoadMap(m map[string]string) *Properties {
// LoadFile reads a file into a Properties struct.
func LoadFile(filename string, enc Encoding) (*Properties, error) {
return loadAll([]string{filename}, enc, false)
l := &Loader{Encoding: enc}
return l.LoadAll([]string{filename})
}
// LoadFiles reads multiple files in the given order into
// a Properties struct. If 'ignoreMissing' is true then
// non-existent files will not be reported as error.
func LoadFiles(filenames []string, enc Encoding, ignoreMissing bool) (*Properties, error) {
return loadAll(filenames, enc, ignoreMissing)
l := &Loader{Encoding: enc, IgnoreMissing: ignoreMissing}
return l.LoadAll(filenames)
}
// LoadURL reads the content of the URL into a Properties struct.
//
// The encoding is determined via the Content-Type header which
// should be set to 'text/plain'. If the 'charset' parameter is
// missing, 'iso-8859-1' or 'latin1' the encoding is set to
// ISO-8859-1. If the 'charset' parameter is set to 'utf-8' the
// encoding is set to UTF-8. A missing content type header is
// interpreted as 'text/plain; charset=utf-8'.
// See Loader#LoadURL for details.
func LoadURL(url string) (*Properties, error) {
return loadAll([]string{url}, UTF8, false)
l := &Loader{Encoding: UTF8}
return l.LoadAll([]string{url})
}
// LoadURLs reads the content of multiple URLs in the given order into a
// Properties struct. If 'ignoreMissing' is true then a 404 status code will
// not be reported as error. See LoadURL for the Content-Type header
// Properties struct. If IgnoreMissing is true then a 404 status code will
// not be reported as error. See Loader#LoadURL for the Content-Type header
// and the encoding.
func LoadURLs(urls []string, ignoreMissing bool) (*Properties, error) {
return loadAll(urls, UTF8, ignoreMissing)
l := &Loader{Encoding: UTF8, IgnoreMissing: ignoreMissing}
return l.LoadAll(urls)
}
// LoadAll reads the content of multiple URLs or files in the given order into a
@@ -79,7 +213,8 @@ func LoadURLs(urls []string, ignoreMissing bool) (*Properties, error) {
// not be reported as error. Encoding sets the encoding for files. For the URLs please see
// LoadURL for the Content-Type header and the encoding.
func LoadAll(names []string, enc Encoding, ignoreMissing bool) (*Properties, error) {
return loadAll(names, enc, ignoreMissing)
l := &Loader{Encoding: enc, IgnoreMissing: ignoreMissing}
return l.LoadAll(names)
}
// MustLoadString reads an UTF8 string into a Properties struct and
@@ -122,90 +257,6 @@ func MustLoadAll(names []string, enc Encoding, ignoreMissing bool) *Properties {
return must(LoadAll(names, enc, ignoreMissing))
}
func loadBuf(buf []byte, enc Encoding) (*Properties, error) {
p, err := parse(convert(buf, enc))
if err != nil {
return nil, err
}
return p, p.check()
}
func loadAll(names []string, enc Encoding, ignoreMissing bool) (*Properties, error) {
result := NewProperties()
for _, name := range names {
n, err := expandName(name)
if err != nil {
return nil, err
}
var p *Properties
if strings.HasPrefix(n, "http://") || strings.HasPrefix(n, "https://") {
p, err = loadURL(n, ignoreMissing)
} else {
p, err = loadFile(n, enc, ignoreMissing)
}
if err != nil {
return nil, err
}
result.Merge(p)
}
return result, result.check()
}
func loadFile(filename string, enc Encoding, ignoreMissing bool) (*Properties, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
if ignoreMissing && os.IsNotExist(err) {
LogPrintf("properties: %s not found. skipping", filename)
return NewProperties(), nil
}
return nil, err
}
p, err := parse(convert(data, enc))
if err != nil {
return nil, err
}
return p, nil
}
func loadURL(url string, ignoreMissing bool) (*Properties, error) {
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("properties: error fetching %q. %s", url, err)
}
if resp.StatusCode == 404 && ignoreMissing {
LogPrintf("properties: %s returned %d. skipping", url, resp.StatusCode)
return NewProperties(), nil
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("properties: %s returned %d", url, resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("properties: %s error reading response. %s", url, err)
}
if err = resp.Body.Close(); err != nil {
return nil, fmt.Errorf("properties: %s error reading response. %s", url, err)
}
ct := resp.Header.Get("Content-Type")
var enc Encoding
switch strings.ToLower(ct) {
case "text/plain", "text/plain; charset=iso-8859-1", "text/plain; charset=latin1":
enc = ISO_8859_1
case "", "text/plain; charset=utf-8":
enc = UTF8
default:
return nil, fmt.Errorf("properties: invalid content type %s", ct)
}
p, err := parse(convert(body, enc))
if err != nil {
return nil, err
}
return p, nil
}
func must(p *Properties, err error) *Properties {
if err != nil {
ErrorHandler(err)
@@ -226,7 +277,7 @@ func expandName(name string) (string, error) {
// first 256 unicode code points cover ISO-8859-1.
func convert(buf []byte, enc Encoding) string {
switch enc {
case UTF8:
case utf8Default, UTF8:
return string(buf)
case ISO_8859_1:
runes := make([]rune, len(buf))

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -83,6 +83,17 @@ func NewProperties() *Properties {
}
}
// Load reads a buffer into the given Properties struct.
func (p *Properties) Load(buf []byte, enc Encoding) error {
l := &Loader{Encoding: enc, DisableExpansion: p.DisableExpansion}
newProperties, err := l.LoadBytes(buf)
if err != nil {
return err
}
p.Merge(newProperties)
return nil
}
// Get returns the expanded value for the given key if exists.
// Otherwise, ok is false.
func (p *Properties) Get(key string) (value string, ok bool) {

View File

@@ -1,4 +1,4 @@
// Copyright 2017 Frank Schroeder. All rights reserved.
// Copyright 2018 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

View File

@@ -7,8 +7,12 @@ import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
)
@@ -16,6 +20,8 @@ import (
const dnsTimeout time.Duration = 2 * time.Second
const tcpIdleTimeout time.Duration = 8 * time.Second
const dohMimeType = "application/dns-udpwireformat"
// A Conn represents a connection to a DNS server.
type Conn struct {
net.Conn // a net.Conn holding the connection
@@ -37,6 +43,7 @@ type Client struct {
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
@@ -134,6 +141,11 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
// attribute appropriately
func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
if c.Net == "https" {
// TODO(tmthrgd): pipe timeouts into exchangeDOH
return c.exchangeDOH(context.TODO(), m, address)
}
return c.exchange(m, address)
}
@@ -146,6 +158,11 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
if c.Net == "https" {
// TODO(tmthrgd): pipe timeouts into exchangeDOH
return c.exchangeDOH(context.TODO(), m, address)
}
return c.exchange(m, address)
})
if r != nil && shared {
@@ -191,6 +208,77 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
return r, rtt, err
}
func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
p, err := m.Pack()
if err != nil {
return nil, 0, err
}
// TODO(tmthrgd): Allow the path to be customised?
u := &url.URL{
Scheme: "https",
Host: a,
Path: "/.well-known/dns-query",
}
if u.Port() == "443" {
u.Host = u.Hostname()
}
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(p))
if err != nil {
return nil, 0, err
}
req.Header.Set("Content-Type", dohMimeType)
req.Header.Set("Accept", dohMimeType)
t := time.Now()
hc := http.DefaultClient
if c.HTTPClient != nil {
hc = c.HTTPClient
}
if ctx != context.Background() && ctx != context.TODO() {
req = req.WithContext(ctx)
}
resp, err := hc.Do(req)
if err != nil {
return nil, 0, err
}
defer closeHTTPBody(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status)
}
if ct := resp.Header.Get("Content-Type"); ct != dohMimeType {
return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType)
}
p, err = ioutil.ReadAll(resp.Body)
if err != nil {
return nil, 0, err
}
rtt = time.Since(t)
r = new(Msg)
if err := r.Unpack(p); err != nil {
return r, 0, err
}
// TODO: TSIG? Is it even supported over DoH?
return r, rtt, nil
}
func closeHTTPBody(r io.ReadCloser) error {
io.Copy(ioutil.Discard, io.LimitReader(r, 8<<20))
return r.Close()
}
// ReadMsg reads a message from the connection co.
// If the received message contains a TSIG record the transaction signature
// is verified. This method always tries to return the message, however if an
@@ -490,6 +578,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
// context, if present. If there is both a context deadline and a configured
// timeout on the client, the earliest of the two takes effect.
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight && c.Net == "https" {
return c.exchangeDOH(ctx, m, a)
}
var timeout time.Duration
if deadline, ok := ctx.Deadline(); !ok {
timeout = 0
@@ -498,6 +590,7 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg,
}
// not passing the context to the underlying calls, as the API does not support
// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
// TODO(tmthrgd): this is a race condition
c.Dialer = &net.Dialer{Timeout: timeout}
return c.Exchange(m, a)
}

View File

@@ -101,7 +101,8 @@ Names:
// compressionLenHelperType - all types that have domain-name/cdomain-name can be used for compressing names
fmt.Fprint(b, "func compressionLenHelperType(c map[string]int, r RR) {\n")
fmt.Fprint(b, "func compressionLenHelperType(c map[string]int, r RR, initLen int) int {\n")
fmt.Fprint(b, "currentLen := initLen\n")
fmt.Fprint(b, "switch x := r.(type) {\n")
for _, name := range domainTypes {
o := scope.Lookup(name)
@@ -109,7 +110,10 @@ Names:
fmt.Fprintf(b, "case *%s:\n", name)
for i := 1; i < st.NumFields(); i++ {
out := func(s string) { fmt.Fprintf(b, "compressionLenHelper(c, x.%s)\n", st.Field(i).Name()) }
out := func(s string) {
fmt.Fprintf(b, "currentLen -= len(x.%s) + 1\n", st.Field(i).Name())
fmt.Fprintf(b, "currentLen += compressionLenHelper(c, x.%s, currentLen)\n", st.Field(i).Name())
}
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
@@ -118,8 +122,12 @@ Names:
case `dns:"cdomain-name"`:
// For HIP we need to slice over the elements in this slice.
fmt.Fprintf(b, `for i := range x.%s {
compressionLenHelper(c, x.%s[i])
}
currentLen -= len(x.%s[i]) + 1
}
`, st.Field(i).Name(), st.Field(i).Name())
fmt.Fprintf(b, `for i := range x.%s {
currentLen += compressionLenHelper(c, x.%s[i], currentLen)
}
`, st.Field(i).Name(), st.Field(i).Name())
}
continue
@@ -133,11 +141,11 @@ Names:
}
}
}
fmt.Fprintln(b, "}\n}\n\n")
fmt.Fprintln(b, "}\nreturn currentLen - initLen\n}\n\n")
// compressionLenSearchType - search cdomain-tags types for compressible names.
fmt.Fprint(b, "func compressionLenSearchType(c map[string]int, r RR) (int, bool) {\n")
fmt.Fprint(b, "func compressionLenSearchType(c map[string]int, r RR) (int, bool, int) {\n")
fmt.Fprint(b, "switch x := r.(type) {\n")
for _, name := range cdomainTypes {
o := scope.Lookup(name)
@@ -147,7 +155,7 @@ Names:
j := 1
for i := 1; i < st.NumFields(); i++ {
out := func(s string, j int) {
fmt.Fprintf(b, "k%d, ok%d := compressionLenSearch(c, x.%s)\n", j, j, st.Field(i).Name())
fmt.Fprintf(b, "k%d, ok%d, sz%d := compressionLenSearch(c, x.%s)\n", j, j, j, st.Field(i).Name())
}
// There are no slice types with names that can be compressed.
@@ -160,13 +168,15 @@ Names:
}
k := "k1"
ok := "ok1"
sz := "sz1"
for i := 2; i < j; i++ {
k += fmt.Sprintf(" + k%d", i)
ok += fmt.Sprintf(" && ok%d", i)
sz += fmt.Sprintf(" + sz%d", i)
}
fmt.Fprintf(b, "return %s, %s\n", k, ok)
fmt.Fprintf(b, "return %s, %s, %s\n", k, ok, sz)
}
fmt.Fprintln(b, "}\nreturn 0, false\n}\n\n")
fmt.Fprintln(b, "}\nreturn 0, false, 0\n}\n\n")
// gofmt
res, err := format.Source(b.Bytes())

10
vendor/github.com/miekg/dns/dns.go generated vendored
View File

@@ -55,16 +55,6 @@ func (h *RR_Header) Header() *RR_Header { return h }
// Just to implement the RR interface.
func (h *RR_Header) copy() RR { return nil }
func (h *RR_Header) copyHeader() *RR_Header {
r := new(RR_Header)
r.Name = h.Name
r.Rrtype = h.Rrtype
r.Class = h.Class
r.Ttl = h.Ttl
r.Rdlength = h.Rdlength
return r
}
func (h *RR_Header) String() string {
var s string

View File

@@ -73,6 +73,7 @@ var StringToAlgorithm = reverseInt8(AlgorithmToString)
// AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's.
var AlgorithmToHash = map[uint8]crypto.Hash{
RSAMD5: crypto.MD5, // Deprecated in RFC 6725
DSA: crypto.SHA1,
RSASHA1: crypto.SHA1,
RSASHA1NSEC3SHA1: crypto.SHA1,
RSASHA256: crypto.SHA256,
@@ -239,7 +240,7 @@ func (k *DNSKEY) ToDS(h uint8) *DS {
// ToCDNSKEY converts a DNSKEY record to a CDNSKEY record.
func (k *DNSKEY) ToCDNSKEY() *CDNSKEY {
c := &CDNSKEY{DNSKEY: *k}
c.Hdr = *k.Hdr.copyHeader()
c.Hdr = k.Hdr
c.Hdr.Rrtype = TypeCDNSKEY
return c
}
@@ -247,7 +248,7 @@ func (k *DNSKEY) ToCDNSKEY() *CDNSKEY {
// ToCDS converts a DS record to a CDS record.
func (d *DS) ToCDS() *CDS {
c := &CDS{DS: *d}
c.Hdr = *d.Hdr.copyHeader()
c.Hdr = d.Hdr
c.Hdr.Rrtype = TypeCDS
return c
}

4
vendor/github.com/miekg/dns/doc.go generated vendored
View File

@@ -73,11 +73,11 @@ and port to use for the connection:
Port: 12345,
Zone: "",
}
d := net.Dialer{
c.Dialer := &net.Dialer{
Timeout: 200 * time.Millisecond,
LocalAddr: &laddr,
}
in, rtt, err := c.ExchangeWithDialer(&d, m1, "8.8.8.8:53")
in, rtt, err := c.Exchange(m1, "8.8.8.8:53")
If these "advanced" features are not needed, a simple UDP query can be sent,
with:

127
vendor/github.com/miekg/dns/msg.go generated vendored
View File

@@ -691,18 +691,20 @@ func (dns *Msg) Pack() (msg []byte, err error) {
return dns.PackBuffer(nil)
}
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small
// a new buffer is allocated.
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
// We use a similar function in tsig.go's stripTsig.
var (
dh Header
compression map[string]int
)
var compression map[string]int
if dns.Compress {
compression = make(map[string]int) // Compression pointer mappings
compression = make(map[string]int) // Compression pointer mappings.
}
return dns.packBufferWithCompressionMap(buf, compression)
}
// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]int) (msg []byte, err error) {
// We use a similar function in tsig.go's stripTsig.
var dh Header
if dns.Rcode < 0 || dns.Rcode > 0xFFF {
return nil, ErrRcode
@@ -714,12 +716,11 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
return nil, ErrExtendedRcode
}
opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
dns.Rcode &= 0xF
}
// Convert convenient Msg into wire-like Header.
dh.Id = dns.Id
dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
if dns.Response {
dh.Bits |= _QR
}
@@ -922,23 +923,27 @@ func (dns *Msg) String() string {
// than packing it, measuring the size and discarding the buffer.
func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) }
func compressedLenWithCompressionMap(dns *Msg, compression map[string]int) int {
l := 12 // Message header is always 12 bytes
for _, r := range dns.Question {
compressionLenHelper(compression, r.Name, l)
l += r.len()
}
l += compressionLenSlice(l, compression, dns.Answer)
l += compressionLenSlice(l, compression, dns.Ns)
l += compressionLenSlice(l, compression, dns.Extra)
return l
}
// compressedLen returns the message length when in compressed wire format
// when compress is true, otherwise the uncompressed length is returned.
func compressedLen(dns *Msg, compress bool) int {
// We always return one more than needed.
l := 12 // Message header is always 12 bytes
if compress {
compression := map[string]int{}
for _, r := range dns.Question {
l += r.len()
compressionLenHelper(compression, r.Name)
}
l += compressionLenSlice(l, compression, dns.Answer)
l += compressionLenSlice(l, compression, dns.Ns)
l += compressionLenSlice(l, compression, dns.Extra)
return l
return compressedLenWithCompressionMap(dns, compression)
}
l := 12 // Message header is always 12 bytes
for _, r := range dns.Question {
l += r.len()
@@ -962,70 +967,94 @@ func compressedLen(dns *Msg, compress bool) int {
return l
}
func compressionLenSlice(len int, c map[string]int, rs []RR) int {
var l int
func compressionLenSlice(lenp int, c map[string]int, rs []RR) int {
initLen := lenp
for _, r := range rs {
if r == nil {
continue
}
// track this length, and the global length in len, while taking compression into account for both.
// TmpLen is to track len of record at 14bits boudaries
tmpLen := lenp
x := r.len()
l += x
len += x
k, ok := compressionLenSearch(c, r.Header().Name)
// track this length, and the global length in len, while taking compression into account for both.
k, ok, _ := compressionLenSearch(c, r.Header().Name)
if ok {
l += 1 - k
len += 1 - k
// Size of x is reduced by k, but we add 1 since k includes the '.' and label descriptor take 2 bytes
// so, basically x:= x - k - 1 + 2
x += 1 - k
}
if len < maxCompressionOffset {
compressionLenHelper(c, r.Header().Name)
}
k, ok = compressionLenSearchType(c, r)
tmpLen += compressionLenHelper(c, r.Header().Name, tmpLen)
k, ok, _ = compressionLenSearchType(c, r)
if ok {
l += 1 - k
len += 1 - k
x += 1 - k
}
lenp += x
tmpLen = lenp
tmpLen += compressionLenHelperType(c, r, tmpLen)
if len < maxCompressionOffset {
compressionLenHelperType(c, r)
}
}
return l
return lenp - initLen
}
// Put the parts of the name in the compression map.
func compressionLenHelper(c map[string]int, s string) {
// Put the parts of the name in the compression map, return the size in bytes added in payload
func compressionLenHelper(c map[string]int, s string, currentLen int) int {
if currentLen > maxCompressionOffset {
// We won't be able to add any label that could be re-used later anyway
return 0
}
if _, ok := c[s]; ok {
return 0
}
initLen := currentLen
pref := ""
prev := s
lbs := Split(s)
for j := len(lbs) - 1; j >= 0; j-- {
for j := 0; j < len(lbs); j++ {
pref = s[lbs[j]:]
currentLen += len(prev) - len(pref)
prev = pref
if _, ok := c[pref]; !ok {
c[pref] = len(pref)
// If first byte label is within the first 14bits, it might be re-used later
if currentLen < maxCompressionOffset {
c[pref] = currentLen
}
} else {
added := currentLen - initLen
if j > 0 {
// We added a new PTR
added += 2
}
return added
}
}
return currentLen - initLen
}
// Look for each part in the compression map and returns its length,
// keep on searching so we get the longest match.
func compressionLenSearch(c map[string]int, s string) (int, bool) {
// Will return the size of compression found, whether a match has been
// found and the size of record if added in payload
func compressionLenSearch(c map[string]int, s string) (int, bool, int) {
off := 0
end := false
if s == "" { // don't bork on bogus data
return 0, false
return 0, false, 0
}
fullSize := 0
for {
if _, ok := c[s[off:]]; ok {
return len(s[off:]), true
return len(s[off:]), true, fullSize + off
}
if end {
break
}
// Each label descriptor takes 2 bytes, add it
fullSize += 2
off, end = NextLabel(s, off)
}
return 0, false
return 0, false, fullSize + len(s)
}
// Copy returns a new RR which is a deep-copy of r.

View File

@@ -56,8 +56,7 @@ func (r *PrivateRR) len() int { return r.Hdr.len() + r.Data.Len() }
func (r *PrivateRR) copy() RR {
// make new RR like this:
rr := mkPrivateRR(r.Hdr.Rrtype)
newh := r.Hdr.copyHeader()
rr.Hdr = *newh
rr.Hdr = r.Hdr
err := r.Data.Copy(rr.Data)
if err != nil {

View File

@@ -1255,8 +1255,10 @@ func setNSEC3(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
if len(l.token) == 0 || l.err {
return nil, &ParseError{f, "bad NSEC3 Salt", l}, ""
}
rr.SaltLength = uint8(len(l.token)) / 2
rr.Salt = l.token
if l.token != "-" {
rr.SaltLength = uint8(len(l.token)) / 2
rr.Salt = l.token
}
<-c
l = <-c
@@ -1321,8 +1323,10 @@ func setNSEC3PARAM(h RR_Header, c chan lex, o, f string) (RR, *ParseError, strin
rr.Iterations = uint16(i)
<-c
l = <-c
rr.SaltLength = uint8(len(l.token))
rr.Salt = l.token
if l.token != "-" {
rr.SaltLength = uint8(len(l.token))
rr.Salt = l.token
}
return rr, nil, ""
}

199
vendor/github.com/miekg/dns/server.go generated vendored
View File

@@ -9,12 +9,19 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// Maximum number of TCP queries before we close the socket.
// Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
// Interval for stop worker if no load
const idleWorkerTimeout = 10 * time.Second
// Maximum number of workers
const maxWorkersCount = 10000
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
@@ -43,6 +50,7 @@ type ResponseWriter interface {
}
type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool
@@ -51,7 +59,6 @@ type response struct {
udp *net.UDPConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
remoteAddr net.Addr // address of the client
writer Writer // writer to output the raw DNS bits
}
@@ -296,12 +303,63 @@ type Server struct {
DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter
// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
MaxTCPQueries int
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
// Shutdown handling
lock sync.RWMutex
started bool
}
func (srv *Server) worker(w *response) {
srv.serve(w)
for {
count := atomic.LoadInt32(&srv.workersCount)
if count > maxWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
break
}
}
defer atomic.AddInt32(&srv.workersCount, -1)
inUse := false
timeout := time.NewTimer(idleWorkerTimeout)
defer timeout.Stop()
LOOP:
for {
select {
case w, ok := <-srv.queue:
if !ok {
break LOOP
}
inUse = true
srv.serve(w)
case <-timeout.C:
if !inUse {
break LOOP
}
inUse = false
timeout.Reset(idleWorkerTimeout)
}
}
}
func (srv *Server) spawnWorker(w *response) {
select {
case srv.queue <- w:
default:
go srv.worker(w)
}
}
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
@@ -309,6 +367,7 @@ func (srv *Server) ListenAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
addr := srv.Addr
if addr == "" {
addr = ":domain"
@@ -316,6 +375,8 @@ func (srv *Server) ListenAndServe() error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
srv.queue = make(chan *response)
defer close(srv.queue)
switch srv.Net {
case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr)
@@ -380,8 +441,11 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
pConn := srv.PacketConn
l := srv.Listener
srv.queue = make(chan *response)
defer close(srv.queue)
if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
@@ -439,7 +503,6 @@ func (srv *Server) getReadTimeout() time.Duration {
}
// serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()
@@ -447,17 +510,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
rw, err := l.Accept()
srv.lock.RLock()
@@ -472,19 +524,11 @@ func (srv *Server) serveTCP(l net.Listener) error {
}
return err
}
go func() {
m, err := reader.ReadTCP(rw, rtimeout)
if err != nil {
rw.Close()
return
}
srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
}()
srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
}
}
// serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
@@ -497,10 +541,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
@@ -520,80 +560,98 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if len(m) < headerSize {
continue
}
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
}
}
// Serve a new connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
func (srv *Server) serve(w *response) {
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
q := 0 // counter for the amount of TCP queries we get
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
return
}
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
Redo:
defer func() {
if !w.hijacked {
w.Close()
}
}()
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
timeout := srv.getReadTimeout()
limit := srv.MaxTCPQueries
if limit == 0 {
limit = maxTCPQueries
}
for q := 0; q < limit || limit == -1; q++ {
var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(w)
if w.tcp == nil {
break // Close() was called
}
if w.hijacked {
break // client will call Close() themselves
}
// The first read uses the read timeout, the rest use the
// idle timeout.
timeout = idleTimeout
}
}
func (srv *Server) serveDNS(w *response) {
req := new(Msg)
err := req.Unpack(m)
err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
goto Exit
return
}
if !srv.Unsafe && req.Response {
goto Exit
return
}
w.tsigStatus = nil
if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(w.msg, secret, "", false)
} else {
w.tsigStatus = ErrSecret
}
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
Exit:
if w.tcp == nil {
return
}
// TODO(miek): make this number configurable?
if q > maxTCPQueries { // close socket after this many queries
w.Close()
return
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
}
w.Close()
return
handler.ServeDNS(w, req) // Writes back to the client
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
@@ -696,7 +754,12 @@ func (w *response) LocalAddr() net.Addr {
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
func (w *response) RemoteAddr() net.Addr {
if w.tcp != nil {
return w.tcp.RemoteAddr()
}
return w.udpSession.RemoteAddr()
}
// TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus }

View File

@@ -226,7 +226,7 @@ func main() {
continue
}
fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
fields := []string{"*rr.Hdr.copyHeader()"}
fields := []string{"rr.Hdr"}
for i := 1; i < st.NumFields(); i++ {
f := st.Field(i).Name()
if sl, ok := st.Field(i).Type().(*types.Slice); ok {

33
vendor/github.com/miekg/dns/udp.go generated vendored
View File

@@ -9,6 +9,22 @@ import (
"golang.org/x/net/ipv6"
)
// This is the required size of the OOB buffer to pass to ReadMsgUDP.
var udpOOBSize = func() int {
// We can't know whether we'll get an IPv4 control message or an
// IPv6 control message ahead of time. To get around this, we size
// the buffer equal to the largest of the two.
oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)
oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)
if len(oob4) > len(oob6) {
return len(oob4)
}
return len(oob6)
}()
// SessionUDP holds the remote address and the associated
// out-of-band data.
type SessionUDP struct {
@@ -22,7 +38,7 @@ func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr }
// ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
// net.UDPAddr.
func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) {
oob := make([]byte, 40)
oob := make([]byte, udpOOBSize)
n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob)
if err != nil {
return n, nil, err
@@ -53,18 +69,15 @@ func parseDstFromOOB(oob []byte) net.IP {
// Start with IPv6 and then fallback to IPv4
// TODO(fastest963): Figure out a way to prefer one or the other. Looking at
// the lvl of the header for a 0 or 41 isn't cross-platform.
var dst net.IP
cm6 := new(ipv6.ControlMessage)
if cm6.Parse(oob) == nil {
dst = cm6.Dst
if cm6.Parse(oob) == nil && cm6.Dst != nil {
return cm6.Dst
}
if dst == nil {
cm4 := new(ipv4.ControlMessage)
if cm4.Parse(oob) == nil {
dst = cm4.Dst
}
cm4 := new(ipv4.ControlMessage)
if cm4.Parse(oob) == nil && cm4.Dst != nil {
return cm4.Dst
}
return dst
return nil
}
// correctSource takes oob data and returns new oob data with the Src equal to the Dst

View File

@@ -3,7 +3,7 @@ package dns
import "fmt"
// Version is current version of this library.
var Version = V{1, 0, 5}
var Version = V{1, 0, 7}
// V holds the version of this library.
type V struct {

View File

@@ -2,117 +2,154 @@
package dns
func compressionLenHelperType(c map[string]int, r RR) {
func compressionLenHelperType(c map[string]int, r RR, initLen int) int {
currentLen := initLen
switch x := r.(type) {
case *AFSDB:
compressionLenHelper(c, x.Hostname)
currentLen -= len(x.Hostname) + 1
currentLen += compressionLenHelper(c, x.Hostname, currentLen)
case *CNAME:
compressionLenHelper(c, x.Target)
currentLen -= len(x.Target) + 1
currentLen += compressionLenHelper(c, x.Target, currentLen)
case *DNAME:
compressionLenHelper(c, x.Target)
currentLen -= len(x.Target) + 1
currentLen += compressionLenHelper(c, x.Target, currentLen)
case *HIP:
for i := range x.RendezvousServers {
compressionLenHelper(c, x.RendezvousServers[i])
currentLen -= len(x.RendezvousServers[i]) + 1
}
for i := range x.RendezvousServers {
currentLen += compressionLenHelper(c, x.RendezvousServers[i], currentLen)
}
case *KX:
compressionLenHelper(c, x.Exchanger)
currentLen -= len(x.Exchanger) + 1
currentLen += compressionLenHelper(c, x.Exchanger, currentLen)
case *LP:
compressionLenHelper(c, x.Fqdn)
currentLen -= len(x.Fqdn) + 1
currentLen += compressionLenHelper(c, x.Fqdn, currentLen)
case *MB:
compressionLenHelper(c, x.Mb)
currentLen -= len(x.Mb) + 1
currentLen += compressionLenHelper(c, x.Mb, currentLen)
case *MD:
compressionLenHelper(c, x.Md)
currentLen -= len(x.Md) + 1
currentLen += compressionLenHelper(c, x.Md, currentLen)
case *MF:
compressionLenHelper(c, x.Mf)
currentLen -= len(x.Mf) + 1
currentLen += compressionLenHelper(c, x.Mf, currentLen)
case *MG:
compressionLenHelper(c, x.Mg)
currentLen -= len(x.Mg) + 1
currentLen += compressionLenHelper(c, x.Mg, currentLen)
case *MINFO:
compressionLenHelper(c, x.Rmail)
compressionLenHelper(c, x.Email)
currentLen -= len(x.Rmail) + 1
currentLen += compressionLenHelper(c, x.Rmail, currentLen)
currentLen -= len(x.Email) + 1
currentLen += compressionLenHelper(c, x.Email, currentLen)
case *MR:
compressionLenHelper(c, x.Mr)
currentLen -= len(x.Mr) + 1
currentLen += compressionLenHelper(c, x.Mr, currentLen)
case *MX:
compressionLenHelper(c, x.Mx)
currentLen -= len(x.Mx) + 1
currentLen += compressionLenHelper(c, x.Mx, currentLen)
case *NAPTR:
compressionLenHelper(c, x.Replacement)
currentLen -= len(x.Replacement) + 1
currentLen += compressionLenHelper(c, x.Replacement, currentLen)
case *NS:
compressionLenHelper(c, x.Ns)
currentLen -= len(x.Ns) + 1
currentLen += compressionLenHelper(c, x.Ns, currentLen)
case *NSAPPTR:
compressionLenHelper(c, x.Ptr)
currentLen -= len(x.Ptr) + 1
currentLen += compressionLenHelper(c, x.Ptr, currentLen)
case *NSEC:
compressionLenHelper(c, x.NextDomain)
currentLen -= len(x.NextDomain) + 1
currentLen += compressionLenHelper(c, x.NextDomain, currentLen)
case *PTR:
compressionLenHelper(c, x.Ptr)
currentLen -= len(x.Ptr) + 1
currentLen += compressionLenHelper(c, x.Ptr, currentLen)
case *PX:
compressionLenHelper(c, x.Map822)
compressionLenHelper(c, x.Mapx400)
currentLen -= len(x.Map822) + 1
currentLen += compressionLenHelper(c, x.Map822, currentLen)
currentLen -= len(x.Mapx400) + 1
currentLen += compressionLenHelper(c, x.Mapx400, currentLen)
case *RP:
compressionLenHelper(c, x.Mbox)
compressionLenHelper(c, x.Txt)
currentLen -= len(x.Mbox) + 1
currentLen += compressionLenHelper(c, x.Mbox, currentLen)
currentLen -= len(x.Txt) + 1
currentLen += compressionLenHelper(c, x.Txt, currentLen)
case *RRSIG:
compressionLenHelper(c, x.SignerName)
currentLen -= len(x.SignerName) + 1
currentLen += compressionLenHelper(c, x.SignerName, currentLen)
case *RT:
compressionLenHelper(c, x.Host)
currentLen -= len(x.Host) + 1
currentLen += compressionLenHelper(c, x.Host, currentLen)
case *SIG:
compressionLenHelper(c, x.SignerName)
currentLen -= len(x.SignerName) + 1
currentLen += compressionLenHelper(c, x.SignerName, currentLen)
case *SOA:
compressionLenHelper(c, x.Ns)
compressionLenHelper(c, x.Mbox)
currentLen -= len(x.Ns) + 1
currentLen += compressionLenHelper(c, x.Ns, currentLen)
currentLen -= len(x.Mbox) + 1
currentLen += compressionLenHelper(c, x.Mbox, currentLen)
case *SRV:
compressionLenHelper(c, x.Target)
currentLen -= len(x.Target) + 1
currentLen += compressionLenHelper(c, x.Target, currentLen)
case *TALINK:
compressionLenHelper(c, x.PreviousName)
compressionLenHelper(c, x.NextName)
currentLen -= len(x.PreviousName) + 1
currentLen += compressionLenHelper(c, x.PreviousName, currentLen)
currentLen -= len(x.NextName) + 1
currentLen += compressionLenHelper(c, x.NextName, currentLen)
case *TKEY:
compressionLenHelper(c, x.Algorithm)
currentLen -= len(x.Algorithm) + 1
currentLen += compressionLenHelper(c, x.Algorithm, currentLen)
case *TSIG:
compressionLenHelper(c, x.Algorithm)
currentLen -= len(x.Algorithm) + 1
currentLen += compressionLenHelper(c, x.Algorithm, currentLen)
}
return currentLen - initLen
}
func compressionLenSearchType(c map[string]int, r RR) (int, bool) {
func compressionLenSearchType(c map[string]int, r RR) (int, bool, int) {
switch x := r.(type) {
case *AFSDB:
k1, ok1 := compressionLenSearch(c, x.Hostname)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Hostname)
return k1, ok1, sz1
case *CNAME:
k1, ok1 := compressionLenSearch(c, x.Target)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Target)
return k1, ok1, sz1
case *MB:
k1, ok1 := compressionLenSearch(c, x.Mb)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Mb)
return k1, ok1, sz1
case *MD:
k1, ok1 := compressionLenSearch(c, x.Md)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Md)
return k1, ok1, sz1
case *MF:
k1, ok1 := compressionLenSearch(c, x.Mf)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Mf)
return k1, ok1, sz1
case *MG:
k1, ok1 := compressionLenSearch(c, x.Mg)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Mg)
return k1, ok1, sz1
case *MINFO:
k1, ok1 := compressionLenSearch(c, x.Rmail)
k2, ok2 := compressionLenSearch(c, x.Email)
return k1 + k2, ok1 && ok2
k1, ok1, sz1 := compressionLenSearch(c, x.Rmail)
k2, ok2, sz2 := compressionLenSearch(c, x.Email)
return k1 + k2, ok1 && ok2, sz1 + sz2
case *MR:
k1, ok1 := compressionLenSearch(c, x.Mr)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Mr)
return k1, ok1, sz1
case *MX:
k1, ok1 := compressionLenSearch(c, x.Mx)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Mx)
return k1, ok1, sz1
case *NS:
k1, ok1 := compressionLenSearch(c, x.Ns)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Ns)
return k1, ok1, sz1
case *PTR:
k1, ok1 := compressionLenSearch(c, x.Ptr)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Ptr)
return k1, ok1, sz1
case *RT:
k1, ok1 := compressionLenSearch(c, x.Host)
return k1, ok1
k1, ok1, sz1 := compressionLenSearch(c, x.Host)
return k1, ok1, sz1
case *SOA:
k1, ok1 := compressionLenSearch(c, x.Ns)
k2, ok2 := compressionLenSearch(c, x.Mbox)
return k1 + k2, ok1 && ok2
k1, ok1, sz1 := compressionLenSearch(c, x.Ns)
k2, ok2, sz2 := compressionLenSearch(c, x.Mbox)
return k1 + k2, ok1 && ok2, sz1 + sz2
}
return 0, false
return 0, false, 0
}

130
vendor/github.com/miekg/dns/ztypes.go generated vendored
View File

@@ -649,215 +649,215 @@ func (rr *X25) len() int {
// copy() functions
func (rr *A) copy() RR {
return &A{*rr.Hdr.copyHeader(), copyIP(rr.A)}
return &A{rr.Hdr, copyIP(rr.A)}
}
func (rr *AAAA) copy() RR {
return &AAAA{*rr.Hdr.copyHeader(), copyIP(rr.AAAA)}
return &AAAA{rr.Hdr, copyIP(rr.AAAA)}
}
func (rr *AFSDB) copy() RR {
return &AFSDB{*rr.Hdr.copyHeader(), rr.Subtype, rr.Hostname}
return &AFSDB{rr.Hdr, rr.Subtype, rr.Hostname}
}
func (rr *ANY) copy() RR {
return &ANY{*rr.Hdr.copyHeader()}
return &ANY{rr.Hdr}
}
func (rr *AVC) copy() RR {
Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt)
return &AVC{*rr.Hdr.copyHeader(), Txt}
return &AVC{rr.Hdr, Txt}
}
func (rr *CAA) copy() RR {
return &CAA{*rr.Hdr.copyHeader(), rr.Flag, rr.Tag, rr.Value}
return &CAA{rr.Hdr, rr.Flag, rr.Tag, rr.Value}
}
func (rr *CERT) copy() RR {
return &CERT{*rr.Hdr.copyHeader(), rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
return &CERT{rr.Hdr, rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
}
func (rr *CNAME) copy() RR {
return &CNAME{*rr.Hdr.copyHeader(), rr.Target}
return &CNAME{rr.Hdr, rr.Target}
}
func (rr *CSYNC) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap)
return &CSYNC{*rr.Hdr.copyHeader(), rr.Serial, rr.Flags, TypeBitMap}
return &CSYNC{rr.Hdr, rr.Serial, rr.Flags, TypeBitMap}
}
func (rr *DHCID) copy() RR {
return &DHCID{*rr.Hdr.copyHeader(), rr.Digest}
return &DHCID{rr.Hdr, rr.Digest}
}
func (rr *DNAME) copy() RR {
return &DNAME{*rr.Hdr.copyHeader(), rr.Target}
return &DNAME{rr.Hdr, rr.Target}
}
func (rr *DNSKEY) copy() RR {
return &DNSKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
return &DNSKEY{rr.Hdr, rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
}
func (rr *DS) copy() RR {
return &DS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
return &DS{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *EID) copy() RR {
return &EID{*rr.Hdr.copyHeader(), rr.Endpoint}
return &EID{rr.Hdr, rr.Endpoint}
}
func (rr *EUI48) copy() RR {
return &EUI48{*rr.Hdr.copyHeader(), rr.Address}
return &EUI48{rr.Hdr, rr.Address}
}
func (rr *EUI64) copy() RR {
return &EUI64{*rr.Hdr.copyHeader(), rr.Address}
return &EUI64{rr.Hdr, rr.Address}
}
func (rr *GID) copy() RR {
return &GID{*rr.Hdr.copyHeader(), rr.Gid}
return &GID{rr.Hdr, rr.Gid}
}
func (rr *GPOS) copy() RR {
return &GPOS{*rr.Hdr.copyHeader(), rr.Longitude, rr.Latitude, rr.Altitude}
return &GPOS{rr.Hdr, rr.Longitude, rr.Latitude, rr.Altitude}
}
func (rr *HINFO) copy() RR {
return &HINFO{*rr.Hdr.copyHeader(), rr.Cpu, rr.Os}
return &HINFO{rr.Hdr, rr.Cpu, rr.Os}
}
func (rr *HIP) copy() RR {
RendezvousServers := make([]string, len(rr.RendezvousServers))
copy(RendezvousServers, rr.RendezvousServers)
return &HIP{*rr.Hdr.copyHeader(), rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
return &HIP{rr.Hdr, rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
}
func (rr *KX) copy() RR {
return &KX{*rr.Hdr.copyHeader(), rr.Preference, rr.Exchanger}
return &KX{rr.Hdr, rr.Preference, rr.Exchanger}
}
func (rr *L32) copy() RR {
return &L32{*rr.Hdr.copyHeader(), rr.Preference, copyIP(rr.Locator32)}
return &L32{rr.Hdr, rr.Preference, copyIP(rr.Locator32)}
}
func (rr *L64) copy() RR {
return &L64{*rr.Hdr.copyHeader(), rr.Preference, rr.Locator64}
return &L64{rr.Hdr, rr.Preference, rr.Locator64}
}
func (rr *LOC) copy() RR {
return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude}
return &LOC{rr.Hdr, rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude}
}
func (rr *LP) copy() RR {
return &LP{*rr.Hdr.copyHeader(), rr.Preference, rr.Fqdn}
return &LP{rr.Hdr, rr.Preference, rr.Fqdn}
}
func (rr *MB) copy() RR {
return &MB{*rr.Hdr.copyHeader(), rr.Mb}
return &MB{rr.Hdr, rr.Mb}
}
func (rr *MD) copy() RR {
return &MD{*rr.Hdr.copyHeader(), rr.Md}
return &MD{rr.Hdr, rr.Md}
}
func (rr *MF) copy() RR {
return &MF{*rr.Hdr.copyHeader(), rr.Mf}
return &MF{rr.Hdr, rr.Mf}
}
func (rr *MG) copy() RR {
return &MG{*rr.Hdr.copyHeader(), rr.Mg}
return &MG{rr.Hdr, rr.Mg}
}
func (rr *MINFO) copy() RR {
return &MINFO{*rr.Hdr.copyHeader(), rr.Rmail, rr.Email}
return &MINFO{rr.Hdr, rr.Rmail, rr.Email}
}
func (rr *MR) copy() RR {
return &MR{*rr.Hdr.copyHeader(), rr.Mr}
return &MR{rr.Hdr, rr.Mr}
}
func (rr *MX) copy() RR {
return &MX{*rr.Hdr.copyHeader(), rr.Preference, rr.Mx}
return &MX{rr.Hdr, rr.Preference, rr.Mx}
}
func (rr *NAPTR) copy() RR {
return &NAPTR{*rr.Hdr.copyHeader(), rr.Order, rr.Preference, rr.Flags, rr.Service, rr.Regexp, rr.Replacement}
return &NAPTR{rr.Hdr, rr.Order, rr.Preference, rr.Flags, rr.Service, rr.Regexp, rr.Replacement}
}
func (rr *NID) copy() RR {
return &NID{*rr.Hdr.copyHeader(), rr.Preference, rr.NodeID}
return &NID{rr.Hdr, rr.Preference, rr.NodeID}
}
func (rr *NIMLOC) copy() RR {
return &NIMLOC{*rr.Hdr.copyHeader(), rr.Locator}
return &NIMLOC{rr.Hdr, rr.Locator}
}
func (rr *NINFO) copy() RR {
ZSData := make([]string, len(rr.ZSData))
copy(ZSData, rr.ZSData)
return &NINFO{*rr.Hdr.copyHeader(), ZSData}
return &NINFO{rr.Hdr, ZSData}
}
func (rr *NS) copy() RR {
return &NS{*rr.Hdr.copyHeader(), rr.Ns}
return &NS{rr.Hdr, rr.Ns}
}
func (rr *NSAPPTR) copy() RR {
return &NSAPPTR{*rr.Hdr.copyHeader(), rr.Ptr}
return &NSAPPTR{rr.Hdr, rr.Ptr}
}
func (rr *NSEC) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap)
return &NSEC{*rr.Hdr.copyHeader(), rr.NextDomain, TypeBitMap}
return &NSEC{rr.Hdr, rr.NextDomain, TypeBitMap}
}
func (rr *NSEC3) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap)
return &NSEC3{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt, rr.HashLength, rr.NextDomain, TypeBitMap}
return &NSEC3{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt, rr.HashLength, rr.NextDomain, TypeBitMap}
}
func (rr *NSEC3PARAM) copy() RR {
return &NSEC3PARAM{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt}
return &NSEC3PARAM{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt}
}
func (rr *OPENPGPKEY) copy() RR {
return &OPENPGPKEY{*rr.Hdr.copyHeader(), rr.PublicKey}
return &OPENPGPKEY{rr.Hdr, rr.PublicKey}
}
func (rr *OPT) copy() RR {
Option := make([]EDNS0, len(rr.Option))
copy(Option, rr.Option)
return &OPT{*rr.Hdr.copyHeader(), Option}
return &OPT{rr.Hdr, Option}
}
func (rr *PTR) copy() RR {
return &PTR{*rr.Hdr.copyHeader(), rr.Ptr}
return &PTR{rr.Hdr, rr.Ptr}
}
func (rr *PX) copy() RR {
return &PX{*rr.Hdr.copyHeader(), rr.Preference, rr.Map822, rr.Mapx400}
return &PX{rr.Hdr, rr.Preference, rr.Map822, rr.Mapx400}
}
func (rr *RFC3597) copy() RR {
return &RFC3597{*rr.Hdr.copyHeader(), rr.Rdata}
return &RFC3597{rr.Hdr, rr.Rdata}
}
func (rr *RKEY) copy() RR {
return &RKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
return &RKEY{rr.Hdr, rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
}
func (rr *RP) copy() RR {
return &RP{*rr.Hdr.copyHeader(), rr.Mbox, rr.Txt}
return &RP{rr.Hdr, rr.Mbox, rr.Txt}
}
func (rr *RRSIG) copy() RR {
return &RRSIG{*rr.Hdr.copyHeader(), rr.TypeCovered, rr.Algorithm, rr.Labels, rr.OrigTtl, rr.Expiration, rr.Inception, rr.KeyTag, rr.SignerName, rr.Signature}
return &RRSIG{rr.Hdr, rr.TypeCovered, rr.Algorithm, rr.Labels, rr.OrigTtl, rr.Expiration, rr.Inception, rr.KeyTag, rr.SignerName, rr.Signature}
}
func (rr *RT) copy() RR {
return &RT{*rr.Hdr.copyHeader(), rr.Preference, rr.Host}
return &RT{rr.Hdr, rr.Preference, rr.Host}
}
func (rr *SMIMEA) copy() RR {
return &SMIMEA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
return &SMIMEA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
}
func (rr *SOA) copy() RR {
return &SOA{*rr.Hdr.copyHeader(), rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl}
return &SOA{rr.Hdr, rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl}
}
func (rr *SPF) copy() RR {
Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt)
return &SPF{*rr.Hdr.copyHeader(), Txt}
return &SPF{rr.Hdr, Txt}
}
func (rr *SRV) copy() RR {
return &SRV{*rr.Hdr.copyHeader(), rr.Priority, rr.Weight, rr.Port, rr.Target}
return &SRV{rr.Hdr, rr.Priority, rr.Weight, rr.Port, rr.Target}
}
func (rr *SSHFP) copy() RR {
return &SSHFP{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Type, rr.FingerPrint}
return &SSHFP{rr.Hdr, rr.Algorithm, rr.Type, rr.FingerPrint}
}
func (rr *TA) copy() RR {
return &TA{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
return &TA{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
}
func (rr *TALINK) copy() RR {
return &TALINK{*rr.Hdr.copyHeader(), rr.PreviousName, rr.NextName}
return &TALINK{rr.Hdr, rr.PreviousName, rr.NextName}
}
func (rr *TKEY) copy() RR {
return &TKEY{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Inception, rr.Expiration, rr.Mode, rr.Error, rr.KeySize, rr.Key, rr.OtherLen, rr.OtherData}
return &TKEY{rr.Hdr, rr.Algorithm, rr.Inception, rr.Expiration, rr.Mode, rr.Error, rr.KeySize, rr.Key, rr.OtherLen, rr.OtherData}
}
func (rr *TLSA) copy() RR {
return &TLSA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
return &TLSA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
}
func (rr *TSIG) copy() RR {
return &TSIG{*rr.Hdr.copyHeader(), rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData}
return &TSIG{rr.Hdr, rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData}
}
func (rr *TXT) copy() RR {
Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt)
return &TXT{*rr.Hdr.copyHeader(), Txt}
return &TXT{rr.Hdr, Txt}
}
func (rr *UID) copy() RR {
return &UID{*rr.Hdr.copyHeader(), rr.Uid}
return &UID{rr.Hdr, rr.Uid}
}
func (rr *UINFO) copy() RR {
return &UINFO{*rr.Hdr.copyHeader(), rr.Uinfo}
return &UINFO{rr.Hdr, rr.Uinfo}
}
func (rr *URI) copy() RR {
return &URI{*rr.Hdr.copyHeader(), rr.Priority, rr.Weight, rr.Target}
return &URI{rr.Hdr, rr.Priority, rr.Weight, rr.Target}
}
func (rr *X25) copy() RR {
return &X25{*rr.Hdr.copyHeader(), rr.PSDNAddress}
return &X25{rr.Hdr, rr.PSDNAddress}
}

View File

@@ -1,6 +1,6 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage
* Copyright 2017 Minio, Inc.
* Copyright 2017, 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@ package minio
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
@@ -343,11 +345,12 @@ func (c Client) uploadPartCopy(ctx context.Context, bucket, object, uploadID str
return p, nil
}
// ComposeObject - creates an object using server-side copying of
// ComposeObjectWithProgress - creates an object using server-side copying of
// existing objects. It takes a list of source objects (with optional
// offsets) and concatenates them into a new object using only
// server-side copying operations.
func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
// server-side copying operations. Optionally takes progress reader hook
// for applications to look at current progress.
func (c Client) ComposeObjectWithProgress(dst DestinationInfo, srcs []SourceInfo, progress io.Reader) error {
if len(srcs) < 1 || len(srcs) > maxPartsCount {
return ErrInvalidArgument("There must be as least one and up to 10000 source objects.")
}
@@ -421,7 +424,7 @@ func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
// involved, it is being copied wholly and at most 5GiB in
// size, emptyfiles are also supported).
if (totalParts == 1 && srcs[0].start == -1 && totalSize <= maxPartSize) || (totalSize == 0) {
return c.CopyObject(dst, srcs[0])
return c.CopyObjectWithProgress(dst, srcs[0], progress)
}
// Now, handle multipart-copy cases.
@@ -476,6 +479,9 @@ func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
if err != nil {
return err
}
if progress != nil {
io.CopyN(ioutil.Discard, progress, start+end-1)
}
objParts = append(objParts, complPart)
partIndex++
}
@@ -490,10 +496,20 @@ func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
return nil
}
// partsRequired is ceiling(size / copyPartSize)
// ComposeObject - creates an object using server-side copying of
// existing objects. It takes a list of source objects (with optional
// offsets) and concatenates them into a new object using only
// server-side copying operations.
func (c Client) ComposeObject(dst DestinationInfo, srcs []SourceInfo) error {
return c.ComposeObjectWithProgress(dst, srcs, nil)
}
// partsRequired is maximum parts possible with
// max part size of ceiling(maxMultipartPutObjectSize / (maxPartsCount - 1))
func partsRequired(size int64) int64 {
r := size / copyPartSize
if size%copyPartSize > 0 {
maxPartSize := maxMultipartPutObjectSize / (maxPartsCount - 1)
r := size / int64(maxPartSize)
if size%int64(maxPartSize) > 0 {
r++
}
return r

View File

@@ -99,13 +99,17 @@ func (c Client) MakeBucket(bucketName string, location string) (err error) {
}
// SetBucketPolicy set the access permissions on an existing bucket.
//
func (c Client) SetBucketPolicy(bucketName, policy string) error {
// Input validation.
if err := s3utils.CheckValidBucketName(bucketName); err != nil {
return err
}
// If policy is empty then delete the bucket policy.
if policy == "" {
return c.removeBucketPolicy(bucketName)
}
// Save the updated policies.
return c.putBucketPolicy(bucketName, policy)
}

View File

@@ -1,6 +1,6 @@
/*
* Minio Go Library for Amazon S3 Compatible Cloud Storage
* Copyright 2017 Minio, Inc.
* Copyright 2017, 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@ package minio
import (
"context"
"io"
"io/ioutil"
"net/http"
"github.com/minio/minio-go/pkg/encrypt"
@@ -26,13 +28,31 @@ import (
// CopyObject - copy a source object into a new object
func (c Client) CopyObject(dst DestinationInfo, src SourceInfo) error {
return c.CopyObjectWithProgress(dst, src, nil)
}
// CopyObjectWithProgress - copy a source object into a new object, optionally takes
// progress bar input to notify current progress.
func (c Client) CopyObjectWithProgress(dst DestinationInfo, src SourceInfo, progress io.Reader) error {
header := make(http.Header)
for k, v := range src.Headers {
header[k] = v
}
var err error
var size int64
// If progress bar is specified, size should be requested as well initiate a StatObject request.
if progress != nil {
size, _, _, err = src.getProps(c)
if err != nil {
return err
}
}
if src.encryption != nil {
encrypt.SSECopy(src.encryption).Marshal(header)
}
if dst.encryption != nil {
dst.encryption.Marshal(header)
}
@@ -53,5 +73,11 @@ func (c Client) CopyObject(dst DestinationInfo, src SourceInfo) error {
if resp.StatusCode != http.StatusOK {
return httpRespToErrorResponse(resp, dst.bucket, dst.object)
}
// Update the progress properly after successful copy.
if progress != nil {
io.CopyN(ioutil.Discard, progress, size)
}
return nil
}

View File

@@ -33,16 +33,17 @@ import (
// PutObjectOptions represents options specified by user for PutObject call
type PutObjectOptions struct {
UserMetadata map[string]string
Progress io.Reader
ContentType string
ContentEncoding string
ContentDisposition string
ContentLanguage string
CacheControl string
ServerSideEncryption encrypt.ServerSide
NumThreads uint
StorageClass string
UserMetadata map[string]string
Progress io.Reader
ContentType string
ContentEncoding string
ContentDisposition string
ContentLanguage string
CacheControl string
ServerSideEncryption encrypt.ServerSide
NumThreads uint
StorageClass string
WebsiteRedirectLocation string
}
// getNumThreads - gets the number of threads to be used in the multipart
@@ -84,6 +85,9 @@ func (opts PutObjectOptions) Header() (header http.Header) {
if opts.StorageClass != "" {
header[amzStorageClass] = []string{opts.StorageClass}
}
if opts.WebsiteRedirectLocation != "" {
header[amzWebsiteRedirectLocation] = []string{opts.WebsiteRedirectLocation}
}
for k, v := range opts.UserMetadata {
if !isAmzHeader(k) && !isStandardHeader(k) && !isStorageClassHeader(k) {
header["X-Amz-Meta-"+k] = []string{v}

View File

@@ -66,6 +66,9 @@ var defaultFilterKeys = []string{
"x-amz-bucket-region",
"x-amz-request-id",
"x-amz-id-2",
"Content-Security-Policy",
"X-Xss-Protection",
// Add new headers to be ignored.
}

View File

@@ -99,7 +99,7 @@ type Options struct {
// Global constants.
const (
libraryName = "minio-go"
libraryVersion = "6.0.0"
libraryVersion = "6.0.1"
)
// User Agent should always following the below style.

View File

@@ -27,10 +27,6 @@ const absMinPartSize = 1024 * 1024 * 5
// putObject behaves internally as multipart.
const minPartSize = 1024 * 1024 * 64
// copyPartSize - default (and maximum) part size to copy in a
// copy-object request (5GiB)
const copyPartSize = 1024 * 1024 * 1024 * 5
// maxPartsCount - maximum number of parts for a single multipart session.
const maxPartsCount = 10000
@@ -61,3 +57,6 @@ const (
// Storage class header constant.
const amzStorageClass = "X-Amz-Storage-Class"
// Website redirect location header constant
const amzWebsiteRedirectLocation = "X-Amz-Website-Redirect-Location"

View File

@@ -82,6 +82,8 @@ func (c Core) PutObject(bucket, object string, data io.Reader, size int64, md5Ba
opts.ContentType = v
} else if strings.ToLower(k) == "cache-control" {
opts.CacheControl = v
} else if strings.ToLower(k) == strings.ToLower(amzWebsiteRedirectLocation) {
opts.WebsiteRedirectLocation = v
} else {
m[k] = metadata[k]
}

View File

@@ -3474,15 +3474,13 @@ func testFunctional() {
function = "SetBucketPolicy(bucketName, readOnlyPolicy)"
functionAll += ", " + function
readOnlyPolicy := `{"Version":"2012-10-17","Statement":[{"Action":["s3:ListBucket"],"Effect":"Allow","Principal":{"AWS":["*"]},"Resource":["arn:aws:s3:::` + bucketName + `"],"Sid":""}]}`
readOnlyPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":["*"]},"Action":["s3:ListBucket"],"Resource":["arn:aws:s3:::` + bucketName + `"]}]}`
args = map[string]interface{}{
"bucketName": bucketName,
"bucketPolicy": readOnlyPolicy,
}
err = c.SetBucketPolicy(bucketName, readOnlyPolicy)
if err != nil {
logError(testName, function, args, startTime, "", "SetBucketPolicy failed", err)
return
@@ -3494,14 +3492,12 @@ func testFunctional() {
"bucketName": bucketName,
}
readOnlyPolicyRet, err := c.GetBucketPolicy(bucketName)
if err != nil {
logError(testName, function, args, startTime, "", "GetBucketPolicy failed", err)
return
}
if strings.Compare(readOnlyPolicyRet, readOnlyPolicy) != 0 {
logError(testName, function, args, startTime, "", "policy should be set to readonly", err)
if readOnlyPolicyRet == "" {
logError(testName, function, args, startTime, "", "policy should be set", err)
return
}
@@ -3509,8 +3505,7 @@ func testFunctional() {
function = "SetBucketPolicy(bucketName, writeOnlyPolicy)"
functionAll += ", " + function
writeOnlyPolicy := `{"Version":"2012-10-17","Statement":[{"Action":["s3:ListBucketMultipartUploads"],"Effect":"Allow","Principal":{"AWS":["*"]},"Resource":["arn:aws:s3:::` + bucketName + `"],"Sid":""}]}`
writeOnlyPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":["*"]},"Action":["s3:ListBucketMultipartUploads"],"Resource":["arn:aws:s3:::` + bucketName + `"]}]}`
args = map[string]interface{}{
"bucketName": bucketName,
"bucketPolicy": writeOnlyPolicy,
@@ -3534,8 +3529,8 @@ func testFunctional() {
return
}
if strings.Compare(writeOnlyPolicyRet, writeOnlyPolicy) != 0 {
logError(testName, function, args, startTime, "", "policy should be set to writeonly", err)
if writeOnlyPolicyRet == "" {
logError(testName, function, args, startTime, "", "policy should be set", err)
return
}
@@ -3543,7 +3538,7 @@ func testFunctional() {
function = "SetBucketPolicy(bucketName, readWritePolicy)"
functionAll += ", " + function
readWritePolicy := `{"Version":"2012-10-17","Statement":[{"Action":["s3:ListBucket","s3:ListBucketMultipartUploads"],"Effect":"Allow","Principal":{"AWS":["*"]},"Resource":["arn:aws:s3:::` + bucketName + `"],"Sid":""}]}`
readWritePolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":["*"]},"Action":["s3:ListBucket","s3:ListBucketMultipartUploads"],"Resource":["arn:aws:s3:::` + bucketName + `"]}]}`
args = map[string]interface{}{
"bucketName": bucketName,
@@ -3567,8 +3562,8 @@ func testFunctional() {
return
}
if strings.Compare(readWritePolicyRet, readWritePolicy) != 0 {
logError(testName, function, args, startTime, "", "policy should be set to readwrite", err)
if readWritePolicyRet == "" {
logError(testName, function, args, startTime, "", "policy should be set", err)
return
}
@@ -5995,7 +5990,11 @@ func testStorageClassMetadataPutObject() {
logError(testName, function, args, startTime, "", "Metadata verification failed, STANDARD storage class should not be a part of response metadata", err)
return
}
// Delete all objects and buckets
if err = cleanupBucket(bucketName, c); err != nil {
logError(testName, function, args, startTime, "", "Cleanup failed", err)
return
}
successLogger(testName, function, args, startTime).Info()
}
@@ -6036,7 +6035,11 @@ func testStorageClassInvalidMetadataPutObject() {
logError(testName, function, args, startTime, "", "PutObject with invalid storage class passed, was expected to fail", err)
return
}
// Delete all objects and buckets
if err = cleanupBucket(bucketName, c); err != nil {
logError(testName, function, args, startTime, "", "Cleanup failed", err)
return
}
successLogger(testName, function, args, startTime).Info()
}
@@ -6136,7 +6139,11 @@ func testStorageClassMetadataCopyObject() {
logError(testName, function, args, startTime, "", "Metadata verification failed, STANDARD storage class should not be a part of response metadata", err)
return
}
// Delete all objects and buckets
if err = cleanupBucket(bucketName, c); err != nil {
logError(testName, function, args, startTime, "", "Cleanup failed", err)
return
}
successLogger(testName, function, args, startTime).Info()
}
@@ -6497,7 +6504,7 @@ func testFunctionalV2() {
function = "SetBucketPolicy(bucketName, bucketPolicy)"
functionAll += ", " + function
readWritePolicy := `{"Version": "2012-10-17","Statement": [{"Action": ["s3:ListBucketMultipartUploads,s3:ListBucket"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::` + bucketName + `/*"],"Sid": ""}]}`
readWritePolicy := `{"Version": "2012-10-17","Statement": [{"Action": ["s3:ListBucketMultipartUploads", "s3:ListBucket"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::` + bucketName + `"],"Sid": ""}]}`
args = map[string]interface{}{
"bucketName": bucketName,
@@ -6558,9 +6565,9 @@ func testFunctionalV2() {
return
}
objectName_noLength := objectName + "-nolength"
args["objectName"] = objectName_noLength
n, err = c.PutObject(bucketName, objectName_noLength, bytes.NewReader(buf), int64(len(buf)), minio.PutObjectOptions{ContentType: "binary/octet-stream"})
objectNameNoLength := objectName + "-nolength"
args["objectName"] = objectNameNoLength
n, err = c.PutObject(bucketName, objectNameNoLength, bytes.NewReader(buf), int64(len(buf)), minio.PutObjectOptions{ContentType: "binary/octet-stream"})
if err != nil {
logError(testName, function, args, startTime, "", "PutObject failed", err)
return

View File

@@ -131,6 +131,7 @@ var retryableS3Codes = map[string]struct{}{
"InternalError": {},
"ExpiredToken": {},
"ExpiredTokenException": {},
"SlowDown": {},
// Add more AWS S3 codes here.
}

View File

@@ -222,6 +222,7 @@ var supportedHeaders = []string{
"content-encoding",
"content-disposition",
"content-language",
"x-amz-website-redirect-location",
// Add more supported headers here.
}

View File

@@ -644,16 +644,28 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
return fmt.Errorf("cannot assign type '%s' to map value field of type '%s'", v.Type(), valMap.Type().Elem())
}
tagValue := f.Tag.Get(d.config.TagName)
tagParts := strings.Split(tagValue, ",")
// Determine the name of the key in the map
keyName := f.Name
tagValue := f.Tag.Get(d.config.TagName)
tagValue = strings.SplitN(tagValue, ",", 2)[0]
if tagValue != "" {
if tagValue == "-" {
if tagParts[0] != "" {
if tagParts[0] == "-" {
continue
}
keyName = tagParts[0]
}
keyName = tagValue
// If "squash" is specified in the tag, we squash the field down.
squash := false
for _, tag := range tagParts[1:] {
if tag == "squash" {
squash = true
break
}
}
if squash && v.Kind() != reflect.Struct {
return fmt.Errorf("cannot squash non-struct type '%s'", v.Type())
}
switch v.Kind() {
@@ -673,7 +685,13 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re
return err
}
valMap.SetMapIndex(reflect.ValueOf(keyName), vMap)
if squash {
for _, k := range vMap.MapKeys() {
valMap.SetMapIndex(k, vMap.MapIndex(k))
}
} else {
valMap.SetMapIndex(reflect.ValueOf(keyName), vMap)
}
default:
valMap.SetMapIndex(reflect.ValueOf(keyName), v)

View File

@@ -251,6 +251,14 @@ __%[1]s_handle_word()
__%[1]s_handle_command
elif [[ $c -eq 0 ]]; then
__%[1]s_handle_command
elif __%[1]s_contains_word "${words[c]}" "${command_aliases[@]}"; then
# aliashash variable is an associative array which is only supported in bash > 3.
if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then
words[c]=${aliashash[${words[c]}]}
__%[1]s_handle_command
else
__%[1]s_handle_noun
fi
else
__%[1]s_handle_noun
fi
@@ -266,6 +274,7 @@ func writePostscript(buf *bytes.Buffer, name string) {
buf.WriteString(fmt.Sprintf(`{
local cur prev words cword
declare -A flaghash 2>/dev/null || :
declare -A aliashash 2>/dev/null || :
if declare -F _init_completion >/dev/null 2>&1; then
_init_completion -s || return
else
@@ -305,6 +314,7 @@ func writeCommands(buf *bytes.Buffer, cmd *Command) {
continue
}
buf.WriteString(fmt.Sprintf(" commands+=(%q)\n", c.Name()))
writeCmdAliases(buf, c)
}
buf.WriteString("\n")
}
@@ -443,6 +453,21 @@ func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) {
}
}
func writeCmdAliases(buf *bytes.Buffer, cmd *Command) {
if len(cmd.Aliases) == 0 {
return
}
sort.Sort(sort.StringSlice(cmd.Aliases))
buf.WriteString(fmt.Sprint(` if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then`, "\n"))
for _, value := range cmd.Aliases {
buf.WriteString(fmt.Sprintf(" command_aliases+=(%q)\n", value))
buf.WriteString(fmt.Sprintf(" aliashash[%q]=%q\n", value, cmd.Name()))
}
buf.WriteString(` fi`)
buf.WriteString("\n")
}
func writeArgAliases(buf *bytes.Buffer, cmd *Command) {
buf.WriteString(" noun_aliases=()\n")
sort.Sort(sort.StringSlice(cmd.ArgAliases))
@@ -469,6 +494,10 @@ func gen(buf *bytes.Buffer, cmd *Command) {
}
buf.WriteString(fmt.Sprintf(" last_command=%q\n", commandName))
buf.WriteString("\n")
buf.WriteString(" command_aliases=()\n")
buf.WriteString("\n")
writeCommands(buf, cmd)
writeFlags(buf, cmd)
writeRequiredFlag(buf, cmd)

View File

@@ -181,7 +181,7 @@ a custom flag completion function with cobra.BashCompCustom:
```go
annotation := make(map[string][]string)
annotation[cobra.BashCompFilenameExt] = []string{"__kubectl_get_namespaces"}
annotation[cobra.BashCompCustom] = []string{"__kubectl_get_namespaces"}
flag := &pflag.Flag{
Name: "namespace",

View File

@@ -27,6 +27,9 @@ import (
flag "github.com/spf13/pflag"
)
// FParseErrWhitelist configures Flag parse errors to be ignored
type FParseErrWhitelist flag.ParseErrorsWhitelist
// Command is just that, a command for your application.
// E.g. 'go run ...' - 'run' is the command. Cobra requires
// you to define the usage and description as part of your command
@@ -137,6 +140,9 @@ type Command struct {
// TraverseChildren parses flags on all parents before executing child command.
TraverseChildren bool
//FParseErrWhitelist flag parse errors to be ignored
FParseErrWhitelist FParseErrWhitelist
// commands is the list of commands supported by this program.
commands []*Command
// parent is a parent command for this command.
@@ -1463,6 +1469,10 @@ func (c *Command) ParseFlags(args []string) error {
}
beforeErrorBufLen := c.flagErrorBuf.Len()
c.mergePersistentFlags()
//do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
err := c.Flags().Parse(args)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {