vendor: update grpc, gogo/protobuf

This commit is contained in:
Gyu-Ho Lee
2016-04-25 14:10:58 -07:00
parent 4b31acf0e0
commit 12d01bb1eb
19 changed files with 384 additions and 241 deletions

26
cmd/Godeps/Godeps.json generated
View File

@ -1,7 +1,7 @@
{ {
"ImportPath": "github.com/coreos/etcd", "ImportPath": "github.com/coreos/etcd",
"GoVersion": "go1.6", "GoVersion": "go1.6",
"GodepVersion": "v60", "GodepVersion": "v62",
"Packages": [ "Packages": [
"./..." "./..."
], ],
@ -75,8 +75,8 @@
}, },
{ {
"ImportPath": "github.com/gogo/protobuf/proto", "ImportPath": "github.com/gogo/protobuf/proto",
"Comment": "v0.1-118-ge8904f5", "Comment": "v0.2-13-gc3995ae",
"Rev": "e8904f58e872a473a5b91bc9bf3377d223555263" "Rev": "c3995ae437bb78d1189f4f147dfe5f87ad3596e4"
}, },
{ {
"ImportPath": "github.com/golang/glog", "ImportPath": "github.com/golang/glog",
@ -109,7 +109,7 @@
}, },
{ {
"ImportPath": "github.com/mattn/go-runewidth", "ImportPath": "github.com/mattn/go-runewidth",
"Comment": "travisish-46-gd6bea18", "Comment": "v0.0.1",
"Rev": "d6bea18f789704b5f83375793155289da36a3c7f" "Rev": "d6bea18f789704b5f83375793155289da36a3c7f"
}, },
{ {
@ -209,39 +209,39 @@
}, },
{ {
"ImportPath": "google.golang.org/grpc", "ImportPath": "google.golang.org/grpc",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/codes", "ImportPath": "google.golang.org/grpc/codes",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/credentials", "ImportPath": "google.golang.org/grpc/credentials",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/grpclog", "ImportPath": "google.golang.org/grpc/grpclog",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/internal", "ImportPath": "google.golang.org/grpc/internal",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/metadata", "ImportPath": "google.golang.org/grpc/metadata",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/naming", "ImportPath": "google.golang.org/grpc/naming",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/peer", "ImportPath": "google.golang.org/grpc/peer",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "google.golang.org/grpc/transport", "ImportPath": "google.golang.org/grpc/transport",
"Rev": "9ac074585f926c8506b6351bfdc396d2b19b1cb1" "Rev": "b062a3c003c22bfef58fa99d689e6a892b408f9d"
}, },
{ {
"ImportPath": "gopkg.in/cheggaaa/pb.v1", "ImportPath": "gopkg.in/cheggaaa/pb.v1",

View File

@ -115,14 +115,8 @@ func setCustomType(base structPointer, f field, value interface{}) {
oldHeader.Len = v.Len() oldHeader.Len = v.Len()
oldHeader.Cap = v.Cap() oldHeader.Cap = v.Cap()
default: default:
l := 1
size := reflect.TypeOf(value).Elem().Size() size := reflect.TypeOf(value).Elem().Size()
if kind == reflect.Array { structPointer_Copy(toStructPointer(reflect.ValueOf(value)), structPointer_Add(base, f), int(size))
l = reflect.TypeOf(value).Elem().Len()
size = reflect.TypeOf(value).Size()
}
total := int(size) * l
structPointer_Copy(toStructPointer(reflect.ValueOf(value)), structPointer_Add(base, f), total)
} }
} }

View File

@ -105,6 +105,11 @@ func (p *Buffer) EncodeVarint(x uint64) error {
return nil return nil
} }
// SizeVarint returns the varint encoding size of an integer.
func SizeVarint(x uint64) int {
return sizeVarint(x)
}
func sizeVarint(x uint64) (n int) { func sizeVarint(x uint64) (n int) {
for { for {
n++ n++
@ -1248,24 +1253,9 @@ func size_struct(prop *StructProperties, base structPointer) (n int) {
} }
// Factor in any oneof fields. // Factor in any oneof fields.
// TODO: This could be faster and use less reflection. if prop.oneofSizer != nil {
if prop.oneofMarshaler != nil { m := structPointer_Interface(base, prop.stype).(Message)
sv := reflect.ValueOf(structPointer_Interface(base, prop.stype)).Elem() n += prop.oneofSizer(m)
for i := 0; i < prop.stype.NumField(); i++ {
fv := sv.Field(i)
if fv.Kind() != reflect.Interface || fv.IsNil() {
continue
}
if prop.stype.Field(i).Tag.Get("protobuf_oneof") == "" {
continue
}
spv := fv.Elem() // interface -> *T
sv := spv.Elem() // *T -> T
sf := sv.Type().Field(0) // StructField inside T
var prop Properties
prop.Init(sf.Type, "whatever", sf.Tag.Get("protobuf"), &sf)
n += prop.size(&prop, toStructPointer(spv))
}
} }
return return

View File

@ -50,7 +50,9 @@ Equality is defined in this way:
are equal, and extensions sets are equal. are equal, and extensions sets are equal.
- Two set scalar fields are equal iff their values are equal. - Two set scalar fields are equal iff their values are equal.
If the fields are of a floating-point type, remember that If the fields are of a floating-point type, remember that
NaN != x for all x, including NaN. NaN != x for all x, including NaN. If the message is defined
in a proto3 .proto file, fields are not "set"; specifically,
zero length proto3 "bytes" fields are equal (nil == {}).
- Two repeated fields are equal iff their lengths are the same, - Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal (a "bytes" field, and their corresponding elements are equal (a "bytes" field,
although represented by []byte, is not a repeated field) although represented by []byte, is not a repeated field)
@ -88,6 +90,7 @@ func Equal(a, b Message) bool {
// v1 and v2 are known to have the same type. // v1 and v2 are known to have the same type.
func equalStruct(v1, v2 reflect.Value) bool { func equalStruct(v1, v2 reflect.Value) bool {
sprop := GetProperties(v1.Type())
for i := 0; i < v1.NumField(); i++ { for i := 0; i < v1.NumField(); i++ {
f := v1.Type().Field(i) f := v1.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") { if strings.HasPrefix(f.Name, "XXX_") {
@ -113,7 +116,7 @@ func equalStruct(v1, v2 reflect.Value) bool {
} }
f1, f2 = f1.Elem(), f2.Elem() f1, f2 = f1.Elem(), f2.Elem()
} }
if !equalAny(f1, f2) { if !equalAny(f1, f2, sprop.Prop[i]) {
return false return false
} }
} }
@ -140,7 +143,8 @@ func equalStruct(v1, v2 reflect.Value) bool {
} }
// v1 and v2 are known to have the same type. // v1 and v2 are known to have the same type.
func equalAny(v1, v2 reflect.Value) bool { // prop may be nil.
func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
if v1.Type() == protoMessageType { if v1.Type() == protoMessageType {
m1, _ := v1.Interface().(Message) m1, _ := v1.Interface().(Message)
m2, _ := v2.Interface().(Message) m2, _ := v2.Interface().(Message)
@ -163,7 +167,7 @@ func equalAny(v1, v2 reflect.Value) bool {
if e1.Type() != e2.Type() { if e1.Type() != e2.Type() {
return false return false
} }
return equalAny(e1, e2) return equalAny(e1, e2, nil)
case reflect.Map: case reflect.Map:
if v1.Len() != v2.Len() { if v1.Len() != v2.Len() {
return false return false
@ -174,16 +178,22 @@ func equalAny(v1, v2 reflect.Value) bool {
// This key was not found in the second map. // This key was not found in the second map.
return false return false
} }
if !equalAny(v1.MapIndex(key), val2) { if !equalAny(v1.MapIndex(key), val2, nil) {
return false return false
} }
} }
return true return true
case reflect.Ptr: case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem()) return equalAny(v1.Elem(), v2.Elem(), prop)
case reflect.Slice: case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 { if v1.Type().Elem().Kind() == reflect.Uint8 {
// short circuit: []byte // short circuit: []byte
// Edge case: if this is in a proto3 message, a zero length
// bytes field is considered the zero value.
if prop != nil && prop.proto3 && v1.Len() == 0 && v2.Len() == 0 {
return true
}
if v1.IsNil() != v2.IsNil() { if v1.IsNil() != v2.IsNil() {
return false return false
} }
@ -194,7 +204,7 @@ func equalAny(v1, v2 reflect.Value) bool {
return false return false
} }
for i := 0; i < v1.Len(); i++ { for i := 0; i < v1.Len(); i++ {
if !equalAny(v1.Index(i), v2.Index(i)) { if !equalAny(v1.Index(i), v2.Index(i), prop) {
return false return false
} }
} }
@ -229,7 +239,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
if m1 != nil && m2 != nil { if m1 != nil && m2 != nil {
// Both are unencoded. // Both are unencoded.
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
return false return false
} }
continue continue
@ -257,7 +267,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err) log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
return false return false
} }
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
return false return false
} }
} }

View File

@ -403,7 +403,6 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b) o := NewBuffer(b)
t := reflect.TypeOf(extension.ExtensionType) t := reflect.TypeOf(extension.ExtensionType)
rep := extension.repeated()
props := extensionProperties(extension) props := extensionProperties(extension)
@ -425,7 +424,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
return nil, err return nil, err
} }
if !rep || o.index >= len(o.buf) { if o.index >= len(o.buf) {
break break
} }
} }

View File

@ -185,6 +185,17 @@ func NewExtension(e []byte) Extension {
return ee return ee
} }
func AppendExtension(e extendableProto, tag int32, buf []byte) {
if ee, eok := e.(extensionsMap); eok {
ext := ee.ExtensionMap()[int32(tag)] // may be missing
ext.enc = append(ext.enc, buf...)
ee.ExtensionMap()[int32(tag)] = ext
} else if ee, eok := e.(extensionsBytes); eok {
ext := ee.GetExtensions()
*ext = append(*ext, buf...)
}
}
func (this Extension) GoString() string { func (this Extension) GoString() string {
if this.enc == nil { if this.enc == nil {
if err := encodeExtension(&this); err != nil { if err := encodeExtension(&this); err != nil {

View File

@ -70,6 +70,12 @@ for a protocol buffer variable v:
with distinguished wrapper types for each possible field value. with distinguished wrapper types for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format. - Marshal and Unmarshal are functions to encode and decode the wire format.
When the .proto file specifies `syntax="proto3"`, there are some differences:
- Non-repeated fields of non-message type are values instead of pointers.
- Getters are only generated for message and oneof fields.
- Enum types do not get an Enum method.
The simplest way to describe this is to see an example. The simplest way to describe this is to see an example.
Given file test.proto, containing Given file test.proto, containing
@ -229,6 +235,7 @@ To create and play with a Test object:
test := &pb.Test{ test := &pb.Test{
Label: proto.String("hello"), Label: proto.String("hello"),
Type: proto.Int32(17), Type: proto.Int32(17),
Reps: []int64{1, 2, 3},
Optionalgroup: &pb.Test_OptionalGroup{ Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"), RequiredField: proto.String("good bye"),
}, },
@ -441,7 +448,7 @@ func (p *Buffer) DebugPrint(s string, b []byte) {
var u uint64 var u uint64
obuf := p.buf obuf := p.buf
index := p.index sindex := p.index
p.buf = b p.buf = b
p.index = 0 p.index = 0
depth := 0 depth := 0
@ -536,7 +543,7 @@ out:
fmt.Printf("\n") fmt.Printf("\n")
p.buf = obuf p.buf = obuf
p.index = index p.index = sindex
} }
// SetDefaults sets unset protocol buffer fields to their default values. // SetDefaults sets unset protocol buffer fields to their default values.
@ -881,3 +888,7 @@ func isProto3Zero(v reflect.Value) bool {
} }
return false return false
} }
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package.
const GoGoProtoPackageIsVersion1 = true

View File

@ -96,6 +96,9 @@ type oneofMarshaler func(Message, *Buffer) error
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message. // A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error) type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
// A oneofSizer does the sizing for all oneof fields in a message.
type oneofSizer func(Message) int
// tagMap is an optimization over map[int]int for typical protocol buffer // tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag // use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers. // numbers.
@ -147,6 +150,7 @@ type StructProperties struct {
oneofMarshaler oneofMarshaler oneofMarshaler oneofMarshaler
oneofUnmarshaler oneofUnmarshaler oneofUnmarshaler oneofUnmarshaler
oneofSizer oneofSizer
stype reflect.Type stype reflect.Type
// OneofTypes contains information about the oneof fields in this message. // OneofTypes contains information about the oneof fields in this message.
@ -174,6 +178,7 @@ func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order
type Properties struct { type Properties struct {
Name string // name of the field, for error messages Name string // name of the field, for error messages
OrigName string // original name before protocol compiler (always set) OrigName string // original name before protocol compiler (always set)
JSONName string // name to use for JSON; determined by protoc
Wire string Wire string
WireType int WireType int
Tag int Tag int
@ -233,8 +238,9 @@ func (p *Properties) String() string {
if p.Packed { if p.Packed {
s += ",packed" s += ",packed"
} }
if p.OrigName != p.Name { s += ",name=" + p.OrigName
s += ",name=" + p.OrigName if p.JSONName != p.OrigName {
s += ",json=" + p.JSONName
} }
if p.proto3 { if p.proto3 {
s += ",proto3" s += ",proto3"
@ -314,6 +320,8 @@ func (p *Properties) Parse(s string) {
p.Packed = true p.Packed = true
case strings.HasPrefix(f, "name="): case strings.HasPrefix(f, "name="):
p.OrigName = f[5:] p.OrigName = f[5:]
case strings.HasPrefix(f, "json="):
p.JSONName = f[5:]
case strings.HasPrefix(f, "enum="): case strings.HasPrefix(f, "enum="):
p.Enum = f[5:] p.Enum = f[5:]
case f == "proto3": case f == "proto3":
@ -784,11 +792,11 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
sort.Sort(prop) sort.Sort(prop)
type oneofMessage interface { type oneofMessage interface {
XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{}) XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), func(Message) int, []interface{})
} }
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); isOneofMessage && ok { if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); isOneofMessage && ok {
var oots []interface{} var oots []interface{}
prop.oneofMarshaler, prop.oneofUnmarshaler, oots = om.XXX_OneofFuncs() prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofSizer, oots = om.XXX_OneofFuncs()
prop.stype = t prop.stype = t
// Interpret oneof metadata. // Interpret oneof metadata.

View File

@ -573,12 +573,12 @@ func writeUnknownStruct(w *textWriter, data []byte) (err error) {
return ferr return ferr
} }
if wire != WireStartGroup { if wire != WireStartGroup {
if err := w.WriteByte(':'); err != nil { if err = w.WriteByte(':'); err != nil {
return err return err
} }
} }
if !w.compact || wire == WireStartGroup { if !w.compact || wire == WireStartGroup {
if err := w.WriteByte(' '); err != nil { if err = w.WriteByte(' '); err != nil {
return err return err
} }
} }

View File

@ -124,6 +124,14 @@ func isWhitespace(c byte) bool {
return false return false
} }
func isQuote(c byte) bool {
switch c {
case '"', '\'':
return true
}
return false
}
func (p *textParser) skipWhitespace() { func (p *textParser) skipWhitespace() {
i := 0 i := 0
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
@ -338,13 +346,13 @@ func (p *textParser) next() *token {
p.advance() p.advance()
if p.done { if p.done {
p.cur.value = "" p.cur.value = ""
} else if len(p.cur.value) > 0 && p.cur.value[0] == '"' { } else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) {
// Look for multiple quoted strings separated by whitespace, // Look for multiple quoted strings separated by whitespace,
// and concatenate them. // and concatenate them.
cat := p.cur cat := p.cur
for { for {
p.skipWhitespace() p.skipWhitespace()
if p.done || p.s[0] != '"' { if p.done || !isQuote(p.s[0]) {
break break
} }
p.advance() p.advance()
@ -724,15 +732,15 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
if err != nil { if err != nil {
return err return err
} }
tok := p.next() ntok := p.next()
if tok.err != nil { if ntok.err != nil {
return tok.err return ntok.err
} }
if tok.value == "]" { if ntok.value == "]" {
break break
} }
if tok.value != "," { if ntok.value != "," {
return p.errorf("Expected ']' or ',' found %q", tok.value) return p.errorf("Expected ']' or ',' found %q", ntok.value)
} }
} }
return nil return nil

View File

@ -8,7 +8,7 @@ import (
// DefaultBackoffConfig uses values specified for backoff in // DefaultBackoffConfig uses values specified for backoff in
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md. // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
var ( var (
DefaultBackoffConfig = &BackoffConfig{ DefaultBackoffConfig = BackoffConfig{
MaxDelay: 120 * time.Second, MaxDelay: 120 * time.Second,
baseDelay: 1.0 * time.Second, baseDelay: 1.0 * time.Second,
factor: 1.6, factor: 1.6,
@ -33,7 +33,10 @@ type BackoffConfig struct {
// MaxDelay is the upper bound of backoff delay. // MaxDelay is the upper bound of backoff delay.
MaxDelay time.Duration MaxDelay time.Duration
// TODO(stevvooe): The following fields are not exported, as allowing changes // TODO(stevvooe): The following fields are not exported, as allowing
// changes would violate the current GRPC specification for backoff. If
// GRPC decides to allow more interesting backoff strategies, these fields
// may be opened up in the future.
// baseDelay is the amount of time to wait before retrying after the first // baseDelay is the amount of time to wait before retrying after the first
// failure. // failure.
@ -46,7 +49,16 @@ type BackoffConfig struct {
jitter float64 jitter float64
} }
func (bc *BackoffConfig) backoff(retries int) (t time.Duration) { func setDefaults(bc *BackoffConfig) {
md := bc.MaxDelay
*bc = DefaultBackoffConfig
if md > 0 {
bc.MaxDelay = md
}
}
func (bc BackoffConfig) backoff(retries int) (t time.Duration) {
if retries == 0 { if retries == 0 {
return bc.baseDelay return bc.baseDelay
} }

View File

@ -115,9 +115,21 @@ func WithPicker(p Picker) DialOption {
} }
} }
// WithBackoffMaxDelay configures the dialer to use the provided maximum delay
// when backing off after failed connection attempts.
func WithBackoffMaxDelay(md time.Duration) DialOption {
return WithBackoffConfig(BackoffConfig{MaxDelay: md})
}
// WithBackoffConfig configures the dialer to use the provided backoff // WithBackoffConfig configures the dialer to use the provided backoff
// parameters after connection failures. // parameters after connection failures.
func WithBackoffConfig(b *BackoffConfig) DialOption { //
// Use WithBackoffMaxDelay until more parameters on BackoffConfig are opened up
// for use.
func WithBackoffConfig(b BackoffConfig) DialOption {
// Set defaults to ensure that provided BackoffConfig is valid and
// unexported fields get default values.
setDefaults(&b)
return withBackoff(b) return withBackoff(b)
} }

74
cmd/vendor/google.golang.org/grpc/interceptor.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"golang.org/x/net/context"
)
// UnaryServerInfo consists of various information about a unary RPC on
// server side. All per-rpc information may be mutated by the interceptor.
type UnaryServerInfo struct {
// Server is the service implementation the user provides. This is read-only.
Server interface{}
// FullMethod is the full RPC method string, i.e., /package.service/method.
FullMethod string
}
// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
// execution of a unary RPC.
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
// contains all the information of this RPC the interceptor can operate on. And handler is the wrapper
// of the service method implementation. It is the responsibility of the interceptor to invoke handler
// to complete the RPC.
type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
// StreamServerInfo consists of various information about a streaming RPC on
// server side. All per-rpc information may be mutated by the interceptor.
type StreamServerInfo struct {
// FullMethod is the full RPC method string, i.e., /package.service/method.
FullMethod string
// IsClientStream indicates whether the RPC is a client streaming RPC.
IsClientStream bool
// IsServerStream indicates whether the RPC is a server streaming RPC.
IsServerStream bool
}
// StreamServerInterceptor provides a hook to intercept the execution of a streaming RPC on the server.
// info contains all the information of this RPC the interceptor can operate on. And handler is the
// service method implementation. It is the responsibility of the interceptor to invoke handler to
// complete the RPC.
type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

View File

@ -409,10 +409,10 @@ func convertCode(err error) codes.Code {
return codes.Unknown return codes.Unknown
} }
// SupportPackageIsVersion1 is referenced from generated protocol buffer files // SupportPackageIsVersion2 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the grpc package. // to assert that that code is compatible with this version of the grpc package.
// //
// This constant may be renamed in the future if a change in the generated code // This constant may be renamed in the future if a change in the generated code
// requires a synchronised update of grpc-go and protoc-gen-go. This constant // requires a synchronised update of grpc-go and protoc-gen-go. This constant
// should not be referenced from any other code. // should not be referenced from any other code.
const SupportPackageIsVersion1 = true const SupportPackageIsVersion2 = true

View File

@ -57,7 +57,7 @@ import (
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error) (interface{}, error) type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
// MethodDesc represents an RPC service's method specification. // MethodDesc represents an RPC service's method specification.
type MethodDesc struct { type MethodDesc struct {
@ -99,6 +99,8 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
@ -140,6 +142,29 @@ func Creds(c credentials.Credentials) ServerOption {
} }
} }
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
// server. Only one unary interceptor can be installed. The construction of multiple
// interceptors (e.g., chaining) can be implemented at the caller.
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
return func(o *options) {
if o.unaryInt != nil {
panic("The unary server interceptor has been set.")
}
o.unaryInt = i
}
}
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed.
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
return func(o *options) {
if o.streamInt != nil {
panic("The stream server interceptor has been set.")
}
o.streamInt = i
}
}
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
@ -494,7 +519,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
return nil return nil
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df) reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(rpcError); ok {
statusCode = err.code statusCode = err.code
@ -572,7 +597,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ss.mu.Unlock() ss.mu.Unlock()
}() }()
} }
if appErr := sd.Handler(srv.server, ss); appErr != nil { var appErr error
if s.opts.streamInt == nil {
appErr = sd.Handler(srv.server, ss)
} else {
info := &StreamServerInfo{
FullMethod: stream.Method(),
IsClientStream: sd.ClientStreams,
IsServerStream: sd.ServerStreams,
}
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
}
if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(rpcError); ok {
ss.statusCode = err.code ss.statusCode = err.code
ss.statusDesc = err.desc ss.statusDesc = err.desc

View File

@ -47,12 +47,14 @@ import (
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
type streamHandler func(srv interface{}, stream ServerStream) error // StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC.
type StreamHandler func(srv interface{}, stream ServerStream) error
// StreamDesc represents a streaming RPC service's method specification. // StreamDesc represents a streaming RPC service's method specification.
type StreamDesc struct { type StreamDesc struct {
StreamName string StreamName string
Handler streamHandler Handler StreamHandler
// At least one of these is true. // At least one of these is true.
ServerStreams bool ServerStreams bool

View File

@ -162,10 +162,6 @@ func (qb *quotaPool) acquire() <-chan int {
type inFlow struct { type inFlow struct {
// The inbound flow control limit for pending data. // The inbound flow control limit for pending data.
limit uint32 limit uint32
// conn points to the shared connection-level inFlow that is shared
// by all streams on that conn. It is nil for the inFlow on the conn
// directly.
conn *inFlow
mu sync.Mutex mu sync.Mutex
// pendingData is the overall data which have been received but not been // pendingData is the overall data which have been received but not been
@ -176,97 +172,39 @@ type inFlow struct {
pendingUpdate uint32 pendingUpdate uint32
} }
// onData is invoked when some data frame is received. It increments not only its // onData is invoked when some data frame is received. It updates pendingData.
// own pendingData but also that of the associated connection-level flow.
func (f *inFlow) onData(n uint32) error { func (f *inFlow) onData(n uint32) error {
if n == 0 {
return nil
}
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.pendingData+f.pendingUpdate+n > f.limit {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit)
}
if f.conn != nil {
if err := f.conn.onData(n); err != nil {
return ConnectionErrorf("%v", err)
}
}
f.pendingData += n f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
}
return nil return nil
} }
// adjustConnPendingUpdate increments the connection level pending updates by n. // onRead is invoked when the application reads the data. It returns the window size
// This is called to make the proper connection level window updates when // to be sent to the peer.
// receiving data frame targeting the canceled RPCs. func (f *inFlow) onRead(n uint32) uint32 {
func (f *inFlow) adjustConnPendingUpdate(n uint32) (uint32, error) {
if n == 0 || f.conn != nil {
return 0, nil
}
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.pendingData+f.pendingUpdate+n > f.limit { if f.pendingData == 0 {
return 0, ConnectionErrorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit)
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
ret := f.pendingUpdate
f.pendingUpdate = 0
return ret, nil
}
return 0, nil
}
// connOnRead updates the connection level states when the application consumes data.
func (f *inFlow) connOnRead(n uint32) uint32 {
if n == 0 || f.conn != nil {
return 0 return 0
} }
f.mu.Lock()
defer f.mu.Unlock()
f.pendingData -= n f.pendingData -= n
f.pendingUpdate += n f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 { if f.pendingUpdate >= f.limit/4 {
ret := f.pendingUpdate wu := f.pendingUpdate
f.pendingUpdate = 0 f.pendingUpdate = 0
return ret return wu
} }
return 0 return 0
} }
// onRead is invoked when the application reads the data. It returns the window updates func (f *inFlow) resetPendingData() uint32 {
// for both stream and connection level.
func (f *inFlow) onRead(n uint32) (swu, cwu uint32) {
if n == 0 {
return
}
f.mu.Lock()
defer f.mu.Unlock()
if f.pendingData == 0 {
// pendingData has been adjusted by restoreConn.
return
}
f.pendingData -= n
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
swu = f.pendingUpdate
f.pendingUpdate = 0
}
cwu = f.conn.connOnRead(n)
return
}
// restoreConn is invoked when a stream is terminated. It removes its stake in
// the connection-level flow and resets its own state.
func (f *inFlow) restoreConn() uint32 {
if f.conn == nil {
return 0
}
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
n := f.pendingData n := f.pendingData
f.pendingData = 0 f.pendingData = 0
f.pendingUpdate = 0 return n
return f.conn.connOnRead(n)
} }

View File

@ -140,29 +140,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
conn.Close() conn.Close()
} }
}() }()
// Send connection preface to server.
n, err := conn.Write(clientPreface)
if err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
if n != len(clientPreface) {
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
framer := newFramer(conn)
if initialWindowSize != defaultWindowSize {
err = framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
} else {
err = framer.writeSettings(true)
}
if err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
}
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua ua = opts.UserAgent + " " + ua
@ -178,7 +155,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
framer: framer, framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
controlBuf: newRecvBuffer(), controlBuf: newRecvBuffer(),
@ -191,28 +168,49 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
maxStreams: math.MaxInt32, maxStreams: math.MaxInt32,
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
} }
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
go t.reader()
// Send connection preface to server.
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
if n != len(clientPreface) {
t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
} else {
err = t.framer.writeSettings(true)
}
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
}
go t.controller() go t.controller()
t.writableChan <- 0 t.writableChan <- 0
// Start the reader goroutine for incoming message. The threading model
// on receiving is that each transport has a dedicated goroutine which
// reads HTTP2 frame from network. Then it dispatches the frame to the
// corresponding stream entity.
go t.reader()
return t, nil return t, nil
} }
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
fc := &inFlow{
limit: initialWindowSize,
conn: t.fc,
}
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(), buf: newRecvBuffer(),
fc: fc, fc: &inFlow{limit: initialWindowSize},
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
headerChan: make(chan struct{}), headerChan: make(chan struct{}),
} }
@ -237,8 +235,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
timeout = dl.Sub(time.Now()) timeout = dl.Sub(time.Now())
} }
if err := ctx.Err(); err != nil { select {
return nil, ContextErr(err) case <-ctx.Done():
return nil, ContextErr(ctx.Err())
default:
} }
pr := &peer.Peer{ pr := &peer.Peer{
Addr: t.conn.RemoteAddr(), Addr: t.conn.RemoteAddr(),
@ -404,8 +404,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// other goroutines. // other goroutines.
s.cancel() s.cancel()
s.mu.Lock() s.mu.Lock()
if q := s.fc.restoreConn(); q > 0 { if q := s.fc.resetPendingData(); q > 0 {
t.controlBuf.put(&windowUpdate{0, q}) if n := t.fc.onRead(q); n > 0 {
t.controlBuf.put(&windowUpdate{0, n})
}
} }
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
@ -427,6 +429,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return errors.New("transport: Close() was already called") return errors.New("transport: Close() was already called")
@ -505,6 +510,10 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport. // Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok {
// Return the connection quota back.
t.sendQuotaPool.add(len(p))
}
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
// This writer is the last one in this batch and has the // This writer is the last one in this batch and has the
// responsibility to flush the buffered frames. It queues // responsibility to flush the buffered frames. It queues
@ -514,6 +523,16 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
} }
return err return err
} }
select {
case <-s.ctx.Done():
t.sendQuotaPool.add(len(p))
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
return ContextErr(s.ctx.Err())
default:
}
if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 {
// Do a force flush iff this is last frame for the entire gRPC message // Do a force flush iff this is last frame for the entire gRPC message
// and the caller is the only writer at this moment. // and the caller is the only writer at this moment.
@ -560,41 +579,39 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when // Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold. // the cumulative quota exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) updateWindow(s *Stream, n uint32) {
swu, cwu := s.fc.onRead(n) if w := t.fc.onRead(n); w > 0 {
if swu > 0 { t.controlBuf.put(&windowUpdate{0, w})
t.controlBuf.put(&windowUpdate{s.id, swu})
} }
if cwu > 0 { if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&windowUpdate{0, cwu}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
// Select the right stream to dispatch.
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf("%v", err))
return
}
// Select the right stream to dispatch.
s, ok := t.getStream(f) s, ok := t.getStream(f)
if !ok { if !ok {
cwu, err := t.fc.adjustConnPendingUpdate(uint32(size)) if w := t.fc.onRead(uint32(size)); w > 0 {
if err != nil { t.controlBuf.put(&windowUpdate{0, w})
t.notifyError(err)
return
}
if cwu > 0 {
t.controlBuf.put(&windowUpdate{0, cwu})
} }
return return
} }
if size > 0 { if size > 0 {
s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()
// The stream has been closed. Release the corresponding quota.
if w := t.fc.onRead(uint32(size)); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
}
return
}
if err := s.fc.onData(uint32(size)); err != nil { if err := s.fc.onData(uint32(size)); err != nil {
if _, ok := err.(ConnectionError); ok {
t.notifyError(err)
return
}
s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()
return
}
s.state = streamDone s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = err.Error() s.statusDesc = err.Error()
@ -603,6 +620,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return return
} }
s.mu.Unlock()
// TODO(bradfitz, zhaoq): A copy is required here because there is no // TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?

View File

@ -139,15 +139,11 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
// operateHeader takes action on the decoded headers. // operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
buf := newRecvBuffer() buf := newRecvBuffer()
fc := &inFlow{
limit: initialWindowSize,
conn: t.fc,
}
s := &Stream{ s := &Stream{
id: frame.Header().StreamID, id: frame.Header().StreamID,
st: t, st: t,
buf: buf, buf: buf,
fc: fc, fc: &inFlow{limit: initialWindowSize},
} }
var state decodeState var state decodeState
@ -307,42 +303,46 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when // Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold. // the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) { func (t *http2Server) updateWindow(s *Stream, n uint32) {
swu, cwu := s.fc.onRead(n) if w := t.fc.onRead(n); w > 0 {
if swu > 0 { t.controlBuf.put(&windowUpdate{0, w})
t.controlBuf.put(&windowUpdate{s.id, swu})
} }
if cwu > 0 { if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&windowUpdate{0, cwu}) t.controlBuf.put(&windowUpdate{s.id, w})
} }
} }
func (t *http2Server) handleData(f *http2.DataFrame) { func (t *http2Server) handleData(f *http2.DataFrame) {
// Select the right stream to dispatch.
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil {
grpclog.Printf("transport: http2Server %v", err)
t.Close()
return
}
// Select the right stream to dispatch.
s, ok := t.getStream(f) s, ok := t.getStream(f)
if !ok { if !ok {
cwu, err := t.fc.adjustConnPendingUpdate(uint32(size)) if w := t.fc.onRead(uint32(size)); w > 0 {
if err != nil { t.controlBuf.put(&windowUpdate{0, w})
grpclog.Printf("transport: http2Server %v", err)
t.Close()
return
}
if cwu > 0 {
t.controlBuf.put(&windowUpdate{0, cwu})
} }
return return
} }
if size > 0 { if size > 0 {
if err := s.fc.onData(uint32(size)); err != nil { s.mu.Lock()
if _, ok := err.(ConnectionError); ok { if s.state == streamDone {
grpclog.Printf("transport: http2Server %v", err) s.mu.Unlock()
t.Close() // The stream has been closed. Release the corresponding quota.
return if w := t.fc.onRead(uint32(size)); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
} }
return
}
if err := s.fc.onData(uint32(size)); err != nil {
s.mu.Unlock()
t.closeStream(s) t.closeStream(s)
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return return
} }
s.mu.Unlock()
// TODO(bradfitz, zhaoq): A copy is required here because there is no // TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
@ -516,6 +516,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
// TODO(zhaoq): Support multi-writers for a single stream. // TODO(zhaoq): Support multi-writers for a single stream.
var writeHeaderFrame bool var writeHeaderFrame bool
s.mu.Lock() s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()
return StreamErrorf(codes.Unknown, "the stream has been done")
}
if !s.headerOk { if !s.headerOk {
writeHeaderFrame = true writeHeaderFrame = true
s.headerOk = true s.headerOk = true
@ -583,6 +587,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
// Got some quota. Try to acquire writing privilege on the // Got some quota. Try to acquire writing privilege on the
// transport. // transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok {
// Return the connection quota back.
t.sendQuotaPool.add(ps)
}
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
// This writer is the last one in this batch and has the // This writer is the last one in this batch and has the
// responsibility to flush the buffered frames. It queues // responsibility to flush the buffered frames. It queues
@ -592,6 +600,16 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
return err return err
} }
select {
case <-s.ctx.Done():
t.sendQuotaPool.add(ps)
if t.framer.adjustNumWriters(-1) == 0 {
t.controlBuf.put(&flushIO{})
}
t.writableChan <- 0
return ContextErr(s.ctx.Err())
default:
}
var forceFlush bool var forceFlush bool
if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last {
forceFlush = true forceFlush = true
@ -689,20 +707,22 @@ func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock() t.mu.Lock()
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
t.mu.Unlock() t.mu.Unlock()
if q := s.fc.restoreConn(); q > 0 { // In case stream sending and receiving are invoked in separate
t.controlBuf.put(&windowUpdate{0, q}) // goroutines (e.g., bi-directional streaming), cancel needs to be
} // called to interrupt the potential blocking on other goroutines.
s.cancel()
s.mu.Lock() s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 {
if w := t.fc.onRead(q); w > 0 {
t.controlBuf.put(&windowUpdate{0, w})
}
}
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
return return
} }
s.state = streamDone s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
s.cancel()
} }
func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) RemoteAddr() net.Addr {