Compare commits

...

18 Commits

Author SHA1 Message Date
43f7c94ac8 version: bump to v3.0.5 2016-08-19 10:20:37 -07:00
93d13fb5b4 integration: NewClusterV3 should launch cluster before creating clients 2016-08-18 14:54:45 -07:00
6a1e3e73dd vendor: boltdb/bolt v1.3.0 for Go 1.7
In case somebody wants to build this branch with Go 1.7
2016-08-18 14:41:34 -07:00
ec576ee5ac mvcc: fix count 2016-08-16 12:13:33 -07:00
606d79afc4 clientv3: use failfast and retry wrappers for at-most-once rpcs 2016-08-16 12:12:44 -07:00
f4d15a430c integration: treat client TLS connecting to insecure server as timeout 2016-08-16 12:09:42 -07:00
4a841459f1 clientv3: respect up/down notifications from grpc
Fixes #5842
2016-08-16 12:09:38 -07:00
ee8c577fc0 vendor: update grpc 2016-08-16 12:09:16 -07:00
8ae0f94cd7 clientv3: only block on New() when DialTimeout > 0
Fixes #6162
2016-08-12 12:03:33 -07:00
69a97863a9 clientv3: handle watchGrpcStream shutdown if prior to goroutine start
Fixes #6141
2016-08-09 20:59:09 -07:00
12c7e4a9f8 clientv3: close watcher stream once all watchers detach
Fixes #6134
2016-08-09 10:44:21 -07:00
23cced240b transport: add ServerName to TLSConfig and add ValidateSecureEndpoints
ServerName prevents accepting forged SRV records with cross-domain
credentials. ValidateSecureEndpoints prevents downgrade attacks from SRV
records.
2016-08-04 11:00:28 -07:00
e73c928d85 etcdctl: set ServerName for TLS when using --discovery-srv 2016-08-04 11:00:25 -07:00
779ad90f9a Documentation: update clustering guide about PKI SRV record forging 2016-08-04 11:00:22 -07:00
dca1740be5 etcdmain: check TLS on gateway SRV records 2016-08-04 11:00:15 -07:00
487b34d857 embed: use ServerName on TLS DNS discovery w/o CA file 2016-08-04 10:56:11 -07:00
a31283cf51 v2http: use guest access in non-TLS mode
Fix https://github.com/coreos/etcd/issues/6075.
2016-08-04 10:52:42 -07:00
b722bedf8a version: bump to v3.0.4+git 2016-07-27 15:30:31 -07:00
48 changed files with 1824 additions and 436 deletions

View File

@ -357,6 +357,8 @@ To help clients discover the etcd cluster, the following DNS SRV records are loo
If `_etcd-client-ssl._tcp.example.com` is found, clients will attempt to communicate with the etcd cluster over SSL/TLS.
If etcd is using TLS without a custom certificate authority, the discovery domain (e.g., example.com) must match the SRV record domain (e.g., infra1.example.com). This is to mitigate attacks that forge SRV records to point to a different domain; the domain would have a valid certificate under PKI but be controlled by an unknown third party.
#### Create DNS SRV records
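The paragraph added above is the user-facing side of the ServerName/SRV hardening commits in this range. As a purely illustrative sketch (none of these names, the service label, or the dialing details come from the diff), a client doing DNS SRV discovery without a custom CA could apply the domain-matching rule roughly like this:

```go
package main

import (
	"crypto/tls"
	"fmt"
	"net"
	"strings"
)

func main() {
	domain := "example.com" // assumed discovery domain, for illustration only
	_, srvs, err := net.LookupSRV("etcd-client-ssl", "tcp", domain)
	if err != nil {
		fmt.Println("SRV lookup failed:", err)
		return
	}
	for _, srv := range srvs {
		target := strings.TrimSuffix(srv.Target, ".")
		// Reject SRV targets outside the discovery domain; a forged record
		// could otherwise present a valid certificate for some other domain.
		if target != domain && !strings.HasSuffix(target, "."+domain) {
			fmt.Println("skipping out-of-domain target:", target)
			continue
		}
		cfg := &tls.Config{ServerName: target} // verify the cert against the SRV target
		fmt.Printf("would dial %s:%d with ServerName %q\n", target, srv.Port, cfg.ServerName)
	}
}
```

etcd's own checks live in pkg/transport and embed (see the ServerName/ValidateSecureEndpoints commits above); this is only one way to express the rule stated in the documentation change.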

View File

@ -115,32 +115,32 @@ func NewAuth(c *Client) Auth {
} }
func (auth *auth) AuthEnable(ctx context.Context) (*AuthEnableResponse, error) { func (auth *auth) AuthEnable(ctx context.Context) (*AuthEnableResponse, error) {
resp, err := auth.remote.AuthEnable(ctx, &pb.AuthEnableRequest{}, grpc.FailFast(false)) resp, err := auth.remote.AuthEnable(ctx, &pb.AuthEnableRequest{})
return (*AuthEnableResponse)(resp), toErr(ctx, err) return (*AuthEnableResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) AuthDisable(ctx context.Context) (*AuthDisableResponse, error) { func (auth *auth) AuthDisable(ctx context.Context) (*AuthDisableResponse, error) {
resp, err := auth.remote.AuthDisable(ctx, &pb.AuthDisableRequest{}, grpc.FailFast(false)) resp, err := auth.remote.AuthDisable(ctx, &pb.AuthDisableRequest{})
return (*AuthDisableResponse)(resp), toErr(ctx, err) return (*AuthDisableResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) UserAdd(ctx context.Context, name string, password string) (*AuthUserAddResponse, error) { func (auth *auth) UserAdd(ctx context.Context, name string, password string) (*AuthUserAddResponse, error) {
resp, err := auth.remote.UserAdd(ctx, &pb.AuthUserAddRequest{Name: name, Password: password}, grpc.FailFast(false)) resp, err := auth.remote.UserAdd(ctx, &pb.AuthUserAddRequest{Name: name, Password: password})
return (*AuthUserAddResponse)(resp), toErr(ctx, err) return (*AuthUserAddResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) UserDelete(ctx context.Context, name string) (*AuthUserDeleteResponse, error) { func (auth *auth) UserDelete(ctx context.Context, name string) (*AuthUserDeleteResponse, error) {
resp, err := auth.remote.UserDelete(ctx, &pb.AuthUserDeleteRequest{Name: name}, grpc.FailFast(false)) resp, err := auth.remote.UserDelete(ctx, &pb.AuthUserDeleteRequest{Name: name})
return (*AuthUserDeleteResponse)(resp), toErr(ctx, err) return (*AuthUserDeleteResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) UserChangePassword(ctx context.Context, name string, password string) (*AuthUserChangePasswordResponse, error) { func (auth *auth) UserChangePassword(ctx context.Context, name string, password string) (*AuthUserChangePasswordResponse, error) {
resp, err := auth.remote.UserChangePassword(ctx, &pb.AuthUserChangePasswordRequest{Name: name, Password: password}, grpc.FailFast(false)) resp, err := auth.remote.UserChangePassword(ctx, &pb.AuthUserChangePasswordRequest{Name: name, Password: password})
return (*AuthUserChangePasswordResponse)(resp), toErr(ctx, err) return (*AuthUserChangePasswordResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) UserGrantRole(ctx context.Context, user string, role string) (*AuthUserGrantRoleResponse, error) { func (auth *auth) UserGrantRole(ctx context.Context, user string, role string) (*AuthUserGrantRoleResponse, error) {
resp, err := auth.remote.UserGrantRole(ctx, &pb.AuthUserGrantRoleRequest{User: user, Role: role}, grpc.FailFast(false)) resp, err := auth.remote.UserGrantRole(ctx, &pb.AuthUserGrantRoleRequest{User: user, Role: role})
return (*AuthUserGrantRoleResponse)(resp), toErr(ctx, err) return (*AuthUserGrantRoleResponse)(resp), toErr(ctx, err)
} }
@ -155,12 +155,12 @@ func (auth *auth) UserList(ctx context.Context) (*AuthUserListResponse, error) {
} }
func (auth *auth) UserRevokeRole(ctx context.Context, name string, role string) (*AuthUserRevokeRoleResponse, error) { func (auth *auth) UserRevokeRole(ctx context.Context, name string, role string) (*AuthUserRevokeRoleResponse, error) {
resp, err := auth.remote.UserRevokeRole(ctx, &pb.AuthUserRevokeRoleRequest{Name: name, Role: role}, grpc.FailFast(false)) resp, err := auth.remote.UserRevokeRole(ctx, &pb.AuthUserRevokeRoleRequest{Name: name, Role: role})
return (*AuthUserRevokeRoleResponse)(resp), toErr(ctx, err) return (*AuthUserRevokeRoleResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) RoleAdd(ctx context.Context, name string) (*AuthRoleAddResponse, error) { func (auth *auth) RoleAdd(ctx context.Context, name string) (*AuthRoleAddResponse, error) {
resp, err := auth.remote.RoleAdd(ctx, &pb.AuthRoleAddRequest{Name: name}, grpc.FailFast(false)) resp, err := auth.remote.RoleAdd(ctx, &pb.AuthRoleAddRequest{Name: name})
return (*AuthRoleAddResponse)(resp), toErr(ctx, err) return (*AuthRoleAddResponse)(resp), toErr(ctx, err)
} }
@ -170,7 +170,7 @@ func (auth *auth) RoleGrantPermission(ctx context.Context, name string, key, ran
RangeEnd: []byte(rangeEnd), RangeEnd: []byte(rangeEnd),
PermType: authpb.Permission_Type(permType), PermType: authpb.Permission_Type(permType),
} }
resp, err := auth.remote.RoleGrantPermission(ctx, &pb.AuthRoleGrantPermissionRequest{Name: name, Perm: perm}, grpc.FailFast(false)) resp, err := auth.remote.RoleGrantPermission(ctx, &pb.AuthRoleGrantPermissionRequest{Name: name, Perm: perm})
return (*AuthRoleGrantPermissionResponse)(resp), toErr(ctx, err) return (*AuthRoleGrantPermissionResponse)(resp), toErr(ctx, err)
} }
@ -185,12 +185,12 @@ func (auth *auth) RoleList(ctx context.Context) (*AuthRoleListResponse, error) {
} }
func (auth *auth) RoleRevokePermission(ctx context.Context, role string, key, rangeEnd string) (*AuthRoleRevokePermissionResponse, error) { func (auth *auth) RoleRevokePermission(ctx context.Context, role string, key, rangeEnd string) (*AuthRoleRevokePermissionResponse, error) {
resp, err := auth.remote.RoleRevokePermission(ctx, &pb.AuthRoleRevokePermissionRequest{Role: role, Key: key, RangeEnd: rangeEnd}, grpc.FailFast(false)) resp, err := auth.remote.RoleRevokePermission(ctx, &pb.AuthRoleRevokePermissionRequest{Role: role, Key: key, RangeEnd: rangeEnd})
return (*AuthRoleRevokePermissionResponse)(resp), toErr(ctx, err) return (*AuthRoleRevokePermissionResponse)(resp), toErr(ctx, err)
} }
func (auth *auth) RoleDelete(ctx context.Context, role string) (*AuthRoleDeleteResponse, error) { func (auth *auth) RoleDelete(ctx context.Context, role string) (*AuthRoleDeleteResponse, error) {
resp, err := auth.remote.RoleDelete(ctx, &pb.AuthRoleDeleteRequest{Role: role}, grpc.FailFast(false)) resp, err := auth.remote.RoleDelete(ctx, &pb.AuthRoleDeleteRequest{Role: role})
return (*AuthRoleDeleteResponse)(resp), toErr(ctx, err) return (*AuthRoleDeleteResponse)(resp), toErr(ctx, err)
} }

View File

@ -17,7 +17,7 @@ package clientv3
import ( import (
"net/url" "net/url"
"strings" "strings"
"sync/atomic" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -26,32 +26,115 @@ import (
// simpleBalancer does the bare minimum to expose multiple eps // simpleBalancer does the bare minimum to expose multiple eps
// to the grpc reconnection code path // to the grpc reconnection code path
type simpleBalancer struct { type simpleBalancer struct {
// eps are the client's endpoints stripped of any URL scheme // addrs are the client's endpoints for grpc
eps []string addrs []grpc.Address
ch chan []grpc.Address // notifyCh notifies grpc of the set of addresses for connecting
numGets uint32 notifyCh chan []grpc.Address
// readyc closes once the first connection is up
readyc chan struct{}
readyOnce sync.Once
// mu protects upEps, pinAddr, and connectingAddr
mu sync.RWMutex
// upEps holds the current endpoints that have an active connection
upEps map[string]struct{}
// upc closes when upEps transitions from empty to non-zero or the balancer closes.
upc chan struct{}
// pinAddr is the currently pinned address; set to the empty string on
// initialization and shutdown.
pinAddr string
} }
func newSimpleBalancer(eps []string) grpc.Balancer { func newSimpleBalancer(eps []string) *simpleBalancer {
ch := make(chan []grpc.Address, 1) notifyCh := make(chan []grpc.Address, 1)
addrs := make([]grpc.Address, len(eps)) addrs := make([]grpc.Address, len(eps))
for i := range eps { for i := range eps {
addrs[i].Addr = getHost(eps[i]) addrs[i].Addr = getHost(eps[i])
} }
ch <- addrs notifyCh <- addrs
return &simpleBalancer{eps: eps, ch: ch} sb := &simpleBalancer{
addrs: addrs,
notifyCh: notifyCh,
readyc: make(chan struct{}),
upEps: make(map[string]struct{}),
upc: make(chan struct{}),
}
return sb
} }
func (b *simpleBalancer) Start(target string) error { return nil } func (b *simpleBalancer) Start(target string) error { return nil }
func (b *simpleBalancer) Up(addr grpc.Address) func(error) { return func(error) {} }
func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) { func (b *simpleBalancer) ConnectNotify() <-chan struct{} {
v := atomic.AddUint32(&b.numGets, 1) b.mu.Lock()
ep := b.eps[v%uint32(len(b.eps))] defer b.mu.Unlock()
return grpc.Address{Addr: getHost(ep)}, func() {}, nil return b.upc
} }
func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.ch }
func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
b.mu.Lock()
if len(b.upEps) == 0 {
// notify waiting Get()s and pin first connected address
close(b.upc)
b.pinAddr = addr.Addr
}
b.upEps[addr.Addr] = struct{}{}
b.mu.Unlock()
// notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) {
b.mu.Lock()
delete(b.upEps, addr.Addr)
if len(b.upEps) == 0 && b.pinAddr != "" {
b.upc = make(chan struct{})
} else if b.pinAddr == addr.Addr {
// choose new random up endpoint
for k := range b.upEps {
b.pinAddr = k
break
}
}
b.mu.Unlock()
}
}
func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
var addr string
for {
b.mu.RLock()
ch := b.upc
b.mu.RUnlock()
select {
case <-ch:
case <-ctx.Done():
return grpc.Address{Addr: ""}, nil, ctx.Err()
}
b.mu.RLock()
addr = b.pinAddr
upEps := len(b.upEps)
b.mu.RUnlock()
if addr == "" {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if upEps > 0 {
break
}
}
return grpc.Address{Addr: addr}, func() {}, nil
}
func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *simpleBalancer) Close() error { func (b *simpleBalancer) Close() error {
close(b.ch) b.mu.Lock()
close(b.notifyCh)
// terminate all waiting Get()s
b.pinAddr = ""
if len(b.upEps) == 0 {
close(b.upc)
}
b.mu.Unlock()
return nil return nil
} }
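The reworked simpleBalancer pins the first endpoint that reports Up(), unblocks Get() through upc, and signals the first successful connection on readyc. A rough sketch of how it plugs into a dial, assuming package-internal access (newSimpleBalancer is unexported) and a made-up helper name; newClient in client.go below does the real wiring:

```go
package clientv3

import (
	"time"

	"google.golang.org/grpc"
)

// dialWithBalancer is a hypothetical helper: hand the balancer to grpc.Dial
// and wait on readyc, which Up() closes once any endpoint connects.
func dialWithBalancer(eps []string, timeout time.Duration) (*grpc.ClientConn, error) {
	b := newSimpleBalancer(eps)
	conn, err := grpc.Dial(eps[0], grpc.WithInsecure(), grpc.WithBalancer(b))
	if err != nil {
		return nil, err
	}
	select {
	case <-b.readyc: // first endpoint is up
		return conn, nil
	case <-time.After(timeout):
		conn.Close()
		return nil, grpc.ErrClientConnTimeout
	}
}
```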

View File

@ -46,9 +46,11 @@ type Client struct {
Auth Auth
Maintenance Maintenance
conn *grpc.ClientConn conn *grpc.ClientConn
cfg Config cfg Config
creds *credentials.TransportCredentials creds *credentials.TransportCredentials
balancer *simpleBalancer
retryWrapper retryRpcFunc
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@ -138,11 +140,10 @@ func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *
return return
} }
// dialSetupOpts gives the dial opts prioer to any authentication // dialSetupOpts gives the dial opts prior to any authentication
func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) []grpc.DialOption { func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts []grpc.DialOption) {
opts := []grpc.DialOption{ if c.cfg.DialTimeout > 0 {
grpc.WithBlock(), opts = []grpc.DialOption{grpc.WithTimeout(c.cfg.DialTimeout)}
grpc.WithTimeout(c.cfg.DialTimeout),
} }
opts = append(opts, dopts...) opts = append(opts, dopts...)
@ -240,12 +241,30 @@ func newClient(cfg *Config) (*Client, error) {
client.Password = cfg.Password client.Password = cfg.Password
} }
b := newSimpleBalancer(cfg.Endpoints) client.balancer = newSimpleBalancer(cfg.Endpoints)
conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(b)) conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(client.balancer))
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.conn = conn client.conn = conn
client.retryWrapper = client.newRetryWrapper()
// wait for a connection
if cfg.DialTimeout > 0 {
hasConn := false
waitc := time.After(cfg.DialTimeout)
select {
case <-client.balancer.readyc:
hasConn = true
case <-ctx.Done():
case <-waitc:
}
if !hasConn {
client.cancel()
conn.Close()
return nil, grpc.ErrClientConnTimeout
}
}
client.Cluster = NewCluster(client) client.Cluster = NewCluster(client)
client.KV = NewKV(client) client.KV = NewKV(client)
@ -280,8 +299,12 @@ func isHaltErr(ctx context.Context, err error) bool {
return eErr != rpctypes.ErrStopped && eErr != rpctypes.ErrNoLeader return eErr != rpctypes.ErrStopped && eErr != rpctypes.ErrNoLeader
} }
// treat etcdserver errors not recognized by the client as halting // treat etcdserver errors not recognized by the client as halting
return strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()) || return isConnClosing(err) || strings.Contains(err.Error(), "etcdserver:")
strings.Contains(err.Error(), "etcdserver:") }
// isConnClosing returns true if the error matches a grpc client closing error
func isConnClosing(err error) bool {
return strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error())
} }
func toErr(ctx context.Context, err error) error { func toErr(ctx context.Context, err error) error {
@ -289,9 +312,12 @@ func toErr(ctx context.Context, err error) error {
return nil return nil
} }
err = rpctypes.Error(err) err = rpctypes.Error(err)
if ctx.Err() != nil && strings.Contains(err.Error(), "context") { switch {
case ctx.Err() != nil && strings.Contains(err.Error(), "context"):
err = ctx.Err() err = ctx.Err()
} else if strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()) { case strings.Contains(err.Error(), ErrNoAvailableEndpoints.Error()):
err = ErrNoAvailableEndpoints
case strings.Contains(err.Error(), grpc.ErrClientConnClosing.Error()):
err = grpc.ErrClientConnClosing err = grpc.ErrClientConnClosing
} }
return err return err
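With these changes, New() only blocks waiting for a connection when Config.DialTimeout is positive; with a zero DialTimeout it returns immediately and connects in the background. A minimal usage sketch (the endpoint is a placeholder):

```go
package main

import (
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
)

func main() {
	// DialTimeout > 0: New blocks until a connection is up or the timeout
	// expires, returning grpc.ErrClientConnTimeout on failure.
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"}, // placeholder endpoint
		DialTimeout: 2 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	cli.Close()

	// DialTimeout == 0: New returns immediately and the client connects
	// lazily (see TestDialNoTimeout below).
	lazy, err := clientv3.New(clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}})
	if err != nil {
		log.Fatal(err)
	}
	lazy.Close()
}
```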

View File

@ -20,11 +20,14 @@ import (
"time" "time"
"github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/pkg/testutil"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
func TestDialTimeout(t *testing.T) { func TestDialTimeout(t *testing.T) {
defer testutil.AfterTest(t)
donec := make(chan error) donec := make(chan error)
go func() { go func() {
// without timeout, grpc keeps redialing if connection refused // without timeout, grpc keeps redialing if connection refused
@ -56,6 +59,15 @@ func TestDialTimeout(t *testing.T) {
} }
} }
func TestDialNoTimeout(t *testing.T) {
cfg := Config{Endpoints: []string{"127.0.0.1:12345"}}
c, err := New(cfg)
if c == nil || err != nil {
t.Fatalf("new client with DialNoWait should succeed, got %v", err)
}
c.Close()
}
func TestIsHaltErr(t *testing.T) { func TestIsHaltErr(t *testing.T) {
if !isHaltErr(nil, fmt.Errorf("etcdserver: some etcdserver error")) { if !isHaltErr(nil, fmt.Errorf("etcdserver: some etcdserver error")) {
t.Errorf(`error prefixed with "etcdserver: " should be Halted by default`) t.Errorf(`error prefixed with "etcdserver: " should be Halted by default`)

View File

@ -47,12 +47,12 @@ type cluster struct {
} }
func NewCluster(c *Client) Cluster { func NewCluster(c *Client) Cluster {
return &cluster{remote: pb.NewClusterClient(c.conn)} return &cluster{remote: RetryClusterClient(c)}
} }
func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAddResponse, error) { func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAddResponse, error) {
r := &pb.MemberAddRequest{PeerURLs: peerAddrs} r := &pb.MemberAddRequest{PeerURLs: peerAddrs}
resp, err := c.remote.MemberAdd(ctx, r, grpc.FailFast(false)) resp, err := c.remote.MemberAdd(ctx, r)
if err == nil { if err == nil {
return (*MemberAddResponse)(resp), nil return (*MemberAddResponse)(resp), nil
} }
@ -64,7 +64,7 @@ func (c *cluster) MemberAdd(ctx context.Context, peerAddrs []string) (*MemberAdd
func (c *cluster) MemberRemove(ctx context.Context, id uint64) (*MemberRemoveResponse, error) { func (c *cluster) MemberRemove(ctx context.Context, id uint64) (*MemberRemoveResponse, error) {
r := &pb.MemberRemoveRequest{ID: id} r := &pb.MemberRemoveRequest{ID: id}
resp, err := c.remote.MemberRemove(ctx, r, grpc.FailFast(false)) resp, err := c.remote.MemberRemove(ctx, r)
if err == nil { if err == nil {
return (*MemberRemoveResponse)(resp), nil return (*MemberRemoveResponse)(resp), nil
} }
@ -78,7 +78,7 @@ func (c *cluster) MemberUpdate(ctx context.Context, id uint64, peerAddrs []strin
// it is safe to retry on update. // it is safe to retry on update.
for { for {
r := &pb.MemberUpdateRequest{ID: id, PeerURLs: peerAddrs} r := &pb.MemberUpdateRequest{ID: id, PeerURLs: peerAddrs}
resp, err := c.remote.MemberUpdate(ctx, r, grpc.FailFast(false)) resp, err := c.remote.MemberUpdate(ctx, r)
if err == nil { if err == nil {
return (*MemberUpdateResponse)(resp), nil return (*MemberUpdateResponse)(resp), nil
} }

View File

@ -16,6 +16,7 @@ package integration
import ( import (
"bytes" "bytes"
"math/rand"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -662,3 +663,75 @@ func TestKVPutStoppedServerAndClose(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestKVGetOneEndpointDown ensures a client can connect and get if one endpoint is down
func TestKVPutOneEndpointDown(t *testing.T) {
defer testutil.AfterTest(t)
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
defer clus.Terminate(t)
// get endpoint list
eps := make([]string, 3)
for i := range eps {
eps[i] = clus.Members[i].GRPCAddr()
}
// make a dead node
clus.Members[rand.Intn(len(eps))].Stop(t)
// try to connect with dead node in the endpoint list
cfg := clientv3.Config{Endpoints: eps, DialTimeout: 1 * time.Second}
cli, err := clientv3.New(cfg)
if err != nil {
t.Fatal(err)
}
defer cli.Close()
ctx, cancel := context.WithTimeout(context.TODO(), 3*time.Second)
if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
t.Fatal(err)
}
cancel()
}
// TestKVGetResetLoneEndpoint ensures that if an endpoint resets and all other
// endpoints are down, then it will reconnect.
func TestKVGetResetLoneEndpoint(t *testing.T) {
defer testutil.AfterTest(t)
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 2})
defer clus.Terminate(t)
// get endpoint list
eps := make([]string, 2)
for i := range eps {
eps[i] = clus.Members[i].GRPCAddr()
}
cfg := clientv3.Config{Endpoints: eps, DialTimeout: 500 * time.Millisecond}
cli, err := clientv3.New(cfg)
if err != nil {
t.Fatal(err)
}
defer cli.Close()
// disconnect everything
clus.Members[0].Stop(t)
clus.Members[1].Stop(t)
// have Get try to reconnect
donec := make(chan struct{})
go func() {
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
if _, err := cli.Get(ctx, "abc", clientv3.WithSerializable()); err != nil {
t.Fatal(err)
}
cancel()
close(donec)
}()
time.Sleep(500 * time.Millisecond)
clus.Members[0].Restart(t)
select {
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for Get")
case <-donec:
}
}

View File

@ -82,7 +82,7 @@ type kv struct {
} }
func NewKV(c *Client) KV { func NewKV(c *Client) KV {
return &kv{remote: pb.NewKVClient(c.conn)} return &kv{remote: RetryKVClient(c)}
} }
func (kv *kv) Put(ctx context.Context, key, val string, opts ...OpOption) (*PutResponse, error) { func (kv *kv) Put(ctx context.Context, key, val string, opts ...OpOption) (*PutResponse, error) {
@ -158,14 +158,14 @@ func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
case tPut: case tPut:
var resp *pb.PutResponse var resp *pb.PutResponse
r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID)} r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID)}
resp, err = kv.remote.Put(ctx, r, grpc.FailFast(false)) resp, err = kv.remote.Put(ctx, r)
if err == nil { if err == nil {
return OpResponse{put: (*PutResponse)(resp)}, nil return OpResponse{put: (*PutResponse)(resp)}, nil
} }
case tDeleteRange: case tDeleteRange:
var resp *pb.DeleteRangeResponse var resp *pb.DeleteRangeResponse
r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end} r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end}
resp, err = kv.remote.DeleteRange(ctx, r, grpc.FailFast(false)) resp, err = kv.remote.DeleteRange(ctx, r)
if err == nil { if err == nil {
return OpResponse{del: (*DeleteResponse)(resp)}, nil return OpResponse{del: (*DeleteResponse)(resp)}, nil
} }

View File

@ -110,7 +110,7 @@ func NewLease(c *Client) Lease {
l := &lessor{ l := &lessor{
donec: make(chan struct{}), donec: make(chan struct{}),
keepAlives: make(map[LeaseID]*keepAlive), keepAlives: make(map[LeaseID]*keepAlive),
remote: pb.NewLeaseClient(c.conn), remote: RetryLeaseClient(c),
firstKeepAliveTimeout: c.cfg.DialTimeout + time.Second, firstKeepAliveTimeout: c.cfg.DialTimeout + time.Second,
} }
if l.firstKeepAliveTimeout == time.Second { if l.firstKeepAliveTimeout == time.Second {
@ -130,7 +130,7 @@ func (l *lessor) Grant(ctx context.Context, ttl int64) (*LeaseGrantResponse, err
for { for {
r := &pb.LeaseGrantRequest{TTL: ttl} r := &pb.LeaseGrantRequest{TTL: ttl}
resp, err := l.remote.LeaseGrant(cctx, r, grpc.FailFast(false)) resp, err := l.remote.LeaseGrant(cctx, r)
if err == nil { if err == nil {
gresp := &LeaseGrantResponse{ gresp := &LeaseGrantResponse{
ResponseHeader: resp.GetHeader(), ResponseHeader: resp.GetHeader(),
@ -156,7 +156,7 @@ func (l *lessor) Revoke(ctx context.Context, id LeaseID) (*LeaseRevokeResponse,
for { for {
r := &pb.LeaseRevokeRequest{ID: int64(id)} r := &pb.LeaseRevokeRequest{ID: int64(id)}
resp, err := l.remote.LeaseRevoke(cctx, r, grpc.FailFast(false)) resp, err := l.remote.LeaseRevoke(cctx, r)
if err == nil { if err == nil {
return (*LeaseRevokeResponse)(resp), nil return (*LeaseRevokeResponse)(resp), nil
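Lease Grant and Revoke now go through the retrying lease client, so a transient connection loss on these idempotent calls can be retried instead of surfacing as a client-closing error. A small, hedged usage sketch (endpoint and key are placeholders):

```go
package main

import (
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
	"golang.org/x/net/context"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"}, // placeholder endpoint
		DialTimeout: 2 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Grant goes through the retrying lease client.
	resp, err := cli.Grant(ctx, 5) // 5-second TTL
	if err != nil {
		log.Fatal(err)
	}
	// Attach the lease to a key.
	if _, err := cli.Put(ctx, "foo", "bar", clientv3.WithLease(resp.ID)); err != nil {
		log.Fatal(err)
	}
}
```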

clientv3/retry.go (new file, 243 lines)
View File

@ -0,0 +1,243 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
type rpcFunc func(ctx context.Context) error
type retryRpcFunc func(context.Context, rpcFunc)
func (c *Client) newRetryWrapper() retryRpcFunc {
return func(rpcCtx context.Context, f rpcFunc) {
for {
err := f(rpcCtx)
// ignore grpc conn closing on fail-fast calls; they are transient errors
if err == nil || !isConnClosing(err) {
return
}
select {
case <-c.balancer.ConnectNotify():
case <-rpcCtx.Done():
case <-c.ctx.Done():
return
}
}
}
}
type retryKVClient struct {
pb.KVClient
retryf retryRpcFunc
}
// RetryKVClient implements a KVClient that uses the client's FailFast retry policy.
func RetryKVClient(c *Client) pb.KVClient {
return &retryKVClient{pb.NewKVClient(c.conn), c.retryWrapper}
}
func (rkv *retryKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...grpc.CallOption) (resp *pb.PutResponse, err error) {
rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.KVClient.Put(rctx, in, opts...)
return err
})
return resp, err
}
func (rkv *retryKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeRequest, opts ...grpc.CallOption) (resp *pb.DeleteRangeResponse, err error) {
rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.KVClient.DeleteRange(rctx, in, opts...)
return err
})
return resp, err
}
func (rkv *retryKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...grpc.CallOption) (resp *pb.TxnResponse, err error) {
rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.KVClient.Txn(rctx, in, opts...)
return err
})
return resp, err
}
func (rkv *retryKVClient) Compact(ctx context.Context, in *pb.CompactionRequest, opts ...grpc.CallOption) (resp *pb.CompactionResponse, err error) {
rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.KVClient.Compact(rctx, in, opts...)
return err
})
return resp, err
}
type retryLeaseClient struct {
pb.LeaseClient
retryf retryRpcFunc
}
// RetryLeaseClient implements a LeaseClient that uses the client's FailFast retry policy.
func RetryLeaseClient(c *Client) pb.LeaseClient {
return &retryLeaseClient{pb.NewLeaseClient(c.conn), c.retryWrapper}
}
func (rlc *retryLeaseClient) LeaseGrant(ctx context.Context, in *pb.LeaseGrantRequest, opts ...grpc.CallOption) (resp *pb.LeaseGrantResponse, err error) {
rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.LeaseClient.LeaseGrant(rctx, in, opts...)
return err
})
return resp, err
}
func (rlc *retryLeaseClient) LeaseRevoke(ctx context.Context, in *pb.LeaseRevokeRequest, opts ...grpc.CallOption) (resp *pb.LeaseRevokeResponse, err error) {
rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.LeaseClient.LeaseRevoke(rctx, in, opts...)
return err
})
return resp, err
}
type retryClusterClient struct {
pb.ClusterClient
retryf retryRpcFunc
}
// RetryClusterClient implements a ClusterClient that uses the client's FailFast retry policy.
func RetryClusterClient(c *Client) pb.ClusterClient {
return &retryClusterClient{pb.NewClusterClient(c.conn), c.retryWrapper}
}
func (rcc *retryClusterClient) MemberAdd(ctx context.Context, in *pb.MemberAddRequest, opts ...grpc.CallOption) (resp *pb.MemberAddResponse, err error) {
rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.ClusterClient.MemberAdd(rctx, in, opts...)
return err
})
return resp, err
}
func (rcc *retryClusterClient) MemberRemove(ctx context.Context, in *pb.MemberRemoveRequest, opts ...grpc.CallOption) (resp *pb.MemberRemoveResponse, err error) {
rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.ClusterClient.MemberRemove(rctx, in, opts...)
return err
})
return resp, err
}
func (rcc *retryClusterClient) MemberUpdate(ctx context.Context, in *pb.MemberUpdateRequest, opts ...grpc.CallOption) (resp *pb.MemberUpdateResponse, err error) {
rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.ClusterClient.MemberUpdate(rctx, in, opts...)
return err
})
return resp, err
}
type retryAuthClient struct {
pb.AuthClient
retryf retryRpcFunc
}
// RetryAuthClient implements an AuthClient that uses the client's FailFast retry policy.
func RetryAuthClient(c *Client) pb.AuthClient {
return &retryAuthClient{pb.NewAuthClient(c.conn), c.retryWrapper}
}
func (rac *retryAuthClient) AuthEnable(ctx context.Context, in *pb.AuthEnableRequest, opts ...grpc.CallOption) (resp *pb.AuthEnableResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.AuthEnable(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) AuthDisable(ctx context.Context, in *pb.AuthDisableRequest, opts ...grpc.CallOption) (resp *pb.AuthDisableResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.AuthDisable(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) UserAdd(ctx context.Context, in *pb.AuthUserAddRequest, opts ...grpc.CallOption) (resp *pb.AuthUserAddResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.UserAdd(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) UserDelete(ctx context.Context, in *pb.AuthUserDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthUserDeleteResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.UserDelete(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) UserChangePassword(ctx context.Context, in *pb.AuthUserChangePasswordRequest, opts ...grpc.CallOption) (resp *pb.AuthUserChangePasswordResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.UserChangePassword(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) UserGrantRole(ctx context.Context, in *pb.AuthUserGrantRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserGrantRoleResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.UserGrantRole(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) UserRevokeRole(ctx context.Context, in *pb.AuthUserRevokeRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserRevokeRoleResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.UserRevokeRole(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) RoleAdd(ctx context.Context, in *pb.AuthRoleAddRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleAddResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.RoleAdd(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) RoleDelete(ctx context.Context, in *pb.AuthRoleDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleDeleteResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.RoleDelete(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) RoleGrantPermission(ctx context.Context, in *pb.AuthRoleGrantPermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleGrantPermissionResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.RoleGrantPermission(rctx, in, opts...)
return err
})
return resp, err
}
func (rac *retryAuthClient) RoleRevokePermission(ctx context.Context, in *pb.AuthRoleRevokePermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleRevokePermissionResponse, err error) {
rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.AuthClient.RoleRevokePermission(rctx, in, opts...)
return err
})
return resp, err
}
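The wrappers above replace the per-call grpc.FailFast(false) options removed throughout this diff: instead of asking gRPC to block and retry internally, calls now fail fast and the wrapper decides whether to re-issue them once the balancer reports a new connection. A hypothetical before/after, assuming package-internal access (helper names are made up for illustration):

```go
package clientv3

import (
	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
)

// authEnableOld mirrors the old call sites: gRPC itself blocks and retries
// because of grpc.FailFast(false).
func authEnableOld(ctx context.Context, remote pb.AuthClient) error {
	_, err := remote.AuthEnable(ctx, &pb.AuthEnableRequest{}, grpc.FailFast(false))
	return err
}

// authEnableNew mirrors the new pattern: the call fails fast and the retry
// wrapper re-issues it after the balancer signals a connection.
func authEnableNew(ctx context.Context, c *Client) error {
	_, err := RetryAuthClient(c).AuthEnable(ctx, &pb.AuthEnableRequest{})
	return err
}
```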

View File

@ -19,7 +19,6 @@ import (
pb "github.com/coreos/etcd/etcdserver/etcdserverpb" pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc"
) )
// Txn is the interface that wraps mini-transactions. // Txn is the interface that wraps mini-transactions.
@ -153,7 +152,7 @@ func (txn *txn) Commit() (*TxnResponse, error) {
func (txn *txn) commit() (*TxnResponse, error) { func (txn *txn) commit() (*TxnResponse, error) {
r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas} r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas}
resp, err := txn.kv.remote.Txn(txn.ctx, r, grpc.FailFast(false)) resp, err := txn.kv.remote.Txn(txn.ctx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -242,6 +242,7 @@ func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) Watch
case reqc <- wr: case reqc <- wr:
ok = true ok = true
case <-wr.ctx.Done(): case <-wr.ctx.Done():
wgs.stopIfEmpty()
case <-donec: case <-donec:
if wgs.closeErr != nil { if wgs.closeErr != nil {
closeCh <- WatchResponse{closeErr: wgs.closeErr} closeCh <- WatchResponse{closeErr: wgs.closeErr}
@ -285,7 +286,12 @@ func (w *watcher) Close() (err error) {
} }
func (w *watchGrpcStream) Close() (err error) { func (w *watchGrpcStream) Close() (err error) {
close(w.stopc) w.mu.Lock()
if w.stopc != nil {
close(w.stopc)
w.stopc = nil
}
w.mu.Unlock()
<-w.donec <-w.donec
select { select {
case err = <-w.errc: case err = <-w.errc:
@ -348,11 +354,13 @@ func (w *watchGrpcStream) addStream(resp *pb.WatchResponse, pendingReq *watchReq
// closeStream closes the watcher resources and removes it // closeStream closes the watcher resources and removes it
func (w *watchGrpcStream) closeStream(ws *watcherStream) { func (w *watchGrpcStream) closeStream(ws *watcherStream) {
w.mu.Lock()
// cancels request stream; subscriber receives nil channel // cancels request stream; subscriber receives nil channel
close(ws.initReq.retc) close(ws.initReq.retc)
// close subscriber's channel // close subscriber's channel
close(ws.outc) close(ws.outc)
delete(w.streams, ws.id) delete(w.streams, ws.id)
w.mu.Unlock()
} }
// run is the root of the goroutines for managing a watcher client // run is the root of the goroutines for managing a watcher client
@ -371,6 +379,14 @@ func (w *watchGrpcStream) run() {
w.cancel() w.cancel()
}() }()
// already stopped?
w.mu.RLock()
stopc := w.stopc
w.mu.RUnlock()
if stopc == nil {
return
}
// start a stream with the etcd grpc server // start a stream with the etcd grpc server
if wc, closeErr = w.newWatchClient(); closeErr != nil { if wc, closeErr = w.newWatchClient(); closeErr != nil {
return return
@ -446,7 +462,7 @@ func (w *watchGrpcStream) run() {
failedReq = pendingReq failedReq = pendingReq
} }
cancelSet = make(map[int64]struct{}) cancelSet = make(map[int64]struct{})
case <-w.stopc: case <-stopc:
return return
} }
@ -586,12 +602,20 @@ func (w *watchGrpcStream) serveStream(ws *watcherStream) {
} }
} }
w.mu.Lock()
w.closeStream(ws) w.closeStream(ws)
w.mu.Unlock() w.stopIfEmpty()
// lazily send cancel message if events on missing id // lazily send cancel message if events on missing id
} }
func (wgs *watchGrpcStream) stopIfEmpty() {
wgs.mu.Lock()
if len(wgs.streams) == 0 && wgs.stopc != nil {
close(wgs.stopc)
wgs.stopc = nil
}
wgs.mu.Unlock()
}
func (w *watchGrpcStream) newWatchClient() (pb.Watch_WatchClient, error) { func (w *watchGrpcStream) newWatchClient() (pb.Watch_WatchClient, error) {
ws, rerr := w.resume() ws, rerr := w.resume()
if rerr != nil { if rerr != nil {
@ -616,13 +640,14 @@ func (w *watchGrpcStream) resume() (ws pb.Watch_WatchClient, err error) {
// openWatchClient retries opening a watchclient until retryConnection fails // openWatchClient retries opening a watchclient until retryConnection fails
func (w *watchGrpcStream) openWatchClient() (ws pb.Watch_WatchClient, err error) { func (w *watchGrpcStream) openWatchClient() (ws pb.Watch_WatchClient, err error) {
for { for {
select { w.mu.Lock()
case <-w.stopc: stopc := w.stopc
w.mu.Unlock()
if stopc == nil {
if err == nil { if err == nil {
err = context.Canceled err = context.Canceled
} }
return nil, err return nil, err
default:
} }
if ws, err = w.remote.Watch(w.ctx, grpc.FailFast(false)); ws != nil && err == nil { if ws, err = w.remote.Watch(w.ctx, grpc.FailFast(false)); ws != nil && err == nil {
break break
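These watch changes guard stopc with the mutex and close the shared gRPC watch stream once the last watcher detaches, rather than only at client close. A minimal usage sketch of the resulting behavior (endpoint and key are placeholders):

```go
package main

import (
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
	"golang.org/x/net/context"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"}, // placeholder endpoint
		DialTimeout: 2 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	ctx, cancel := context.WithCancel(context.Background())
	wch := cli.Watch(ctx, "foo")

	// Cancelling the context detaches the only watcher; with this change the
	// shared watch stream can shut down once no watchers remain, instead of
	// lingering until the whole client is closed.
	cancel()
	for range wch {
		// drain until the channel closes
	}
}
```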

cmd/Godeps/Godeps.json (generated, 42 changed lines)
View File

@ -1,6 +1,6 @@
{ {
"ImportPath": "github.com/coreos/etcd", "ImportPath": "github.com/coreos/etcd",
"GoVersion": "devel-a6dbfc1", "GoVersion": "go1.7",
"GodepVersion": "v74", "GodepVersion": "v74",
"Packages": [ "Packages": [
"./..." "./..."
@ -25,8 +25,8 @@
}, },
{ {
"ImportPath": "github.com/boltdb/bolt", "ImportPath": "github.com/boltdb/bolt",
"Comment": "v1.2.1", "Comment": "v1.3.0",
"Rev": "dfb21201d9270c1082d5fb0f07f500311ff72f18" "Rev": "583e8937c61f1af6513608ccc75c97b6abdf4ff9"
}, },
{ {
"ImportPath": "github.com/cockroachdb/cmux", "ImportPath": "github.com/cockroachdb/cmux",
@ -237,48 +237,48 @@
}, },
{ {
"ImportPath": "google.golang.org/grpc", "ImportPath": "google.golang.org/grpc",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/codes", "ImportPath": "google.golang.org/grpc/codes",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/credentials", "ImportPath": "google.golang.org/grpc/credentials",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/grpclog", "ImportPath": "google.golang.org/grpc/grpclog",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/internal", "ImportPath": "google.golang.org/grpc/internal",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/metadata", "ImportPath": "google.golang.org/grpc/metadata",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/naming", "ImportPath": "google.golang.org/grpc/naming",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/peer", "ImportPath": "google.golang.org/grpc/peer",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "google.golang.org/grpc/transport", "ImportPath": "google.golang.org/grpc/transport",
"Comment": "v1.0.0-6-g02fca89", "Comment": "v1.0.0-174-gc278196",
"Rev": "02fca896ff5f50c6bbbee0860345a49344b37a03" "Rev": "c2781963b3af261a37e0f14fdcb7c1fa13259e1f"
}, },
{ {
"ImportPath": "gopkg.in/cheggaaa/pb.v1", "ImportPath": "gopkg.in/cheggaaa/pb.v1",

View File

@ -1,4 +1,4 @@
Bolt [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.svg?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.svg)](https://godoc.org/github.com/boltdb/bolt) ![Version](https://img.shields.io/badge/version-1.0-green.svg) Bolt [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.svg?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.svg)](https://godoc.org/github.com/boltdb/bolt) ![Version](https://img.shields.io/badge/version-1.2.1-green.svg)
==== ====
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas]
@ -313,7 +313,7 @@ func (s *Store) CreateUser(u *User) error {
// Generate ID for the user. // Generate ID for the user.
// This returns an error only if the Tx is closed or not writeable. // This returns an error only if the Tx is closed or not writeable.
// That can't happen in an Update() call so I ignore the error check. // That can't happen in an Update() call so I ignore the error check.
id, _ = b.NextSequence() id, _ := b.NextSequence()
u.ID = int(id) u.ID = int(id)
// Marshal user data into bytes. // Marshal user data into bytes.
@ -557,7 +557,7 @@ if err != nil {
Bolt is able to run on mobile devices by leveraging the binding feature of the Bolt is able to run on mobile devices by leveraging the binding feature of the
[gomobile](https://github.com/golang/mobile) tool. Create a struct that will [gomobile](https://github.com/golang/mobile) tool. Create a struct that will
contain your database logic and a reference to a `*bolt.DB` with a initializing contain your database logic and a reference to a `*bolt.DB` with a initializing
contstructor that takes in a filepath where the database file will be stored. constructor that takes in a filepath where the database file will be stored.
Neither Android nor iOS require extra permissions or cleanup from using this method. Neither Android nor iOS require extra permissions or cleanup from using this method.
```go ```go
@ -807,6 +807,7 @@ them via pull request.
Below is a list of public, open source projects that use Bolt: Below is a list of public, open source projects that use Bolt:
* [BoltDbWeb](https://github.com/evnix/boltdbweb) - A web based GUI for BoltDB files.
* [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard. * [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard.
* [Bazil](https://bazil.org/) - A file system that lets your data reside where it is most convenient for it to reside. * [Bazil](https://bazil.org/) - A file system that lets your data reside where it is most convenient for it to reside.
* [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb. * [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb.
@ -825,7 +826,6 @@ Below is a list of public, open source projects that use Bolt:
* [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend. * [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend.
* [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend. * [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend.
* [tentacool](https://github.com/optiflows/tentacool) - REST api server to manage system stuff (IP, DNS, Gateway...) on a linux server. * [tentacool](https://github.com/optiflows/tentacool) - REST api server to manage system stuff (IP, DNS, Gateway...) on a linux server.
* [SkyDB](https://github.com/skydb/sky) - Behavioral analytics database.
* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read. * [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read.
* [InfluxDB](https://influxdata.com) - Scalable datastore for metrics, events, and real-time analytics. * [InfluxDB](https://influxdata.com) - Scalable datastore for metrics, events, and real-time analytics.
* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data. * [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data.
@ -842,9 +842,11 @@ Below is a list of public, open source projects that use Bolt:
* [Go Report Card](https://goreportcard.com/) - Go code quality report cards as a (free and open source) service. * [Go Report Card](https://goreportcard.com/) - Go code quality report cards as a (free and open source) service.
* [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners. * [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners.
* [lru](https://github.com/crowdriff/lru) - Easy to use Bolt-backed Least-Recently-Used (LRU) read-through cache with chainable remote stores. * [lru](https://github.com/crowdriff/lru) - Easy to use Bolt-backed Least-Recently-Used (LRU) read-through cache with chainable remote stores.
* [Storm](https://github.com/asdine/storm) - A simple ORM around BoltDB. * [Storm](https://github.com/asdine/storm) - Simple and powerful ORM for BoltDB.
* [GoWebApp](https://github.com/josephspurrier/gowebapp) - A basic MVC web application in Go using BoltDB. * [GoWebApp](https://github.com/josephspurrier/gowebapp) - A basic MVC web application in Go using BoltDB.
* [SimpleBolt](https://github.com/xyproto/simplebolt) - A simple way to use BoltDB. Deals mainly with strings. * [SimpleBolt](https://github.com/xyproto/simplebolt) - A simple way to use BoltDB. Deals mainly with strings.
* [Algernon](https://github.com/xyproto/algernon) - A HTTP/2 web server with built-in support for Lua. Uses BoltDB as the default database backend. * [Algernon](https://github.com/xyproto/algernon) - A HTTP/2 web server with built-in support for Lua. Uses BoltDB as the default database backend.
* [MuLiFS](https://github.com/dankomiocevic/mulifs) - Music Library Filesystem creates a filesystem to organise your music files.
* [GoShort](https://github.com/pankajkhairnar/goShort) - GoShort is a URL shortener written in Golang and BoltDB for persistent key/value storage and for routing it's using high performent HTTPRouter.
If you are using Bolt in a project please send a pull request to add it to the list. If you are using Bolt in a project please send a pull request to add it to the list.

View File

@ -166,12 +166,16 @@ func (f *freelist) read(p *page) {
} }
// Copy the list of page ids from the freelist. // Copy the list of page ids from the freelist.
ids := ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[idx:count] if count == 0 {
f.ids = make([]pgid, len(ids)) f.ids = nil
copy(f.ids, ids) } else {
ids := ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[idx:count]
f.ids = make([]pgid, len(ids))
copy(f.ids, ids)
// Make sure they're sorted. // Make sure they're sorted.
sort.Sort(pgids(f.ids)) sort.Sort(pgids(f.ids))
}
// Rebuild the page cache. // Rebuild the page cache.
f.reindex() f.reindex()
@ -189,7 +193,9 @@ func (f *freelist) write(p *page) error {
// The page.count can only hold up to 64k elements so if we overflow that // The page.count can only hold up to 64k elements so if we overflow that
// number then we handle it by putting the size in the first element. // number then we handle it by putting the size in the first element.
if len(ids) < 0xFFFF { if len(ids) == 0 {
p.count = uint16(len(ids))
} else if len(ids) < 0xFFFF {
p.count = uint16(len(ids)) p.count = uint16(len(ids))
copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[:], ids) copy(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[:], ids)
} else { } else {

View File

@ -201,6 +201,11 @@ func (n *node) write(p *page) {
} }
p.count = uint16(len(n.inodes)) p.count = uint16(len(n.inodes))
// Stop here if there are no items to write.
if p.count == 0 {
return
}
// Loop over each item and write it to the page. // Loop over each item and write it to the page.
b := (*[maxAllocSize]byte)(unsafe.Pointer(&p.ptr))[n.pageElementSize()*len(n.inodes):] b := (*[maxAllocSize]byte)(unsafe.Pointer(&p.ptr))[n.pageElementSize()*len(n.inodes):]
for i, item := range n.inodes { for i, item := range n.inodes {

View File

@ -62,6 +62,9 @@ func (p *page) leafPageElement(index uint16) *leafPageElement {
// leafPageElements retrieves a list of leaf nodes. // leafPageElements retrieves a list of leaf nodes.
func (p *page) leafPageElements() []leafPageElement { func (p *page) leafPageElements() []leafPageElement {
if p.count == 0 {
return nil
}
return ((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[:] return ((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[:]
} }
@ -72,6 +75,9 @@ func (p *page) branchPageElement(index uint16) *branchPageElement {
// branchPageElements retrieves a list of branch nodes. // branchPageElements retrieves a list of branch nodes.
func (p *page) branchPageElements() []branchPageElement { func (p *page) branchPageElements() []branchPageElement {
if p.count == 0 {
return nil
}
return ((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[:] return ((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[:]
} }

View File

@ -1,17 +1,21 @@
language: go language: go
go: go:
- 1.5.3 - 1.5.4
- 1.6 - 1.6.3
go_import_path: google.golang.org/grpc
before_install: before_install:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/golang/lint/golint
- go get github.com/axw/gocov/gocov - go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script: script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf"'
- make test testrace - make test testrace

View File

@ -36,6 +36,7 @@ package grpc
import ( import (
"bytes" "bytes"
"io" "io"
"math"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -51,13 +52,20 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
var err error var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header() c.headerMD, err = stream.Header()
if err != nil { if err != nil {
return err return err
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
} }
defer func() { defer func() {
if err != nil { if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err) t.CloseStream(stream, err)
} }
@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
} }
err = t.Write(stream, outBuf, opts) err = t.Write(stream, outBuf, opts)
if err != nil { // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err return nil, err
} }
// Sent successfully. // Sent successfully.
@ -176,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { // Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
@ -184,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
continue continue
} }
t.CloseStream(stream, err)
return toRPCErr(err) return toRPCErr(err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {

View File

@ -43,7 +43,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
@ -68,7 +67,7 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials() // errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection. // and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// errNetworkIP indicates that the connection is down due to some network I/O error. // errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error") errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
@ -196,9 +195,14 @@ func WithTimeout(d time.Duration) DialOption {
} }
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
}
} }
} }
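A hedged sketch of how an existing (addr, timeout) dialer keeps working after this change, assuming net, time, and grpc imports; the unix-socket target is illustrative.

func dialUnix(path string) (*grpc.ClientConn, error) {
	// The adapter above converts the per-connection context deadline into the
	// timeout argument that this old-style dialer expects.
	return grpc.Dial(path,
		grpc.WithInsecure(),
		grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
			return net.DialTimeout("unix", addr, timeout)
		}))
}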
@ -211,10 +215,12 @@ func WithUserAgent(s string) DialOption {
// Dial creates a client connection the given target. // Dial creates a client connection the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
ctx := context.Background()
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(ctx)
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
@ -226,31 +232,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
if cc.dopts.bs == nil { if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig cc.dopts.bs = DefaultBackoffConfig
} }
if cc.dopts.balancer == nil {
cc.dopts.balancer = RoundRobin(nil)
}
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
var ( var (
ok bool ok bool
addrs []Address addrs []Address
) )
ch := cc.dopts.balancer.Notify() if cc.dopts.balancer == nil {
if ch == nil { // Connect to target directly if balancer is nil.
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target}) addrs = append(addrs, Address{Addr: target})
} else { } else {
addrs, ok = <-ch if err := cc.dopts.balancer.Start(target); err != nil {
if !ok || len(addrs) == 0 { return nil, err
return nil, errNoAddr }
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
}
} }
} }
waitC := make(chan error, 1) waitC := make(chan error, 1)
go func() { go func() {
for _, a := range addrs { for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil { if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err waitC <- err
return return
} }
@ -267,10 +275,15 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.Close() cc.Close()
return nil, err return nil, err
} }
case <-cc.ctx.Done():
cc.Close()
return nil, cc.ctx.Err()
case <-timeoutCh: case <-timeoutCh:
cc.Close() cc.Close()
return nil, ErrClientConnTimeout return nil, ErrClientConnTimeout
} }
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok { if ok {
go cc.lbWatcher() go cc.lbWatcher()
} }
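A short sketch of the simplest path after this change, assuming time and grpc imports; the endpoint is illustrative.

func dialDirect() (*grpc.ClientConn, error) {
	// With no WithBalancer option the connection goes straight to the target
	// and the lbWatcher goroutine is never started.
	return grpc.Dial("127.0.0.1:2379",
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithTimeout(2*time.Second))
}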
@ -317,6 +330,9 @@ func (s ConnectivityState) String() string {
// ClientConn represents a client connection to an RPC server. // ClientConn represents a client connection to an RPC server.
type ClientConn struct { type ClientConn struct {
ctx context.Context
cancel context.CancelFunc
target string target string
authority string authority string
dopts dialOptions dopts dialOptions
@ -347,11 +363,12 @@ func (cc *ClientConn) lbWatcher() {
} }
if !keep { if !keep {
del = append(del, c) del = append(del, c)
delete(cc.conns, c.addr)
} }
} }
cc.mu.Unlock() cc.mu.Unlock()
for _, a := range add { for _, a := range add {
cc.newAddrConn(a, true) cc.resetAddrConn(a, true, nil)
} }
for _, c := range del { for _, c := range del {
c.tearDown(errConnDrain) c.tearDown(errConnDrain)
@ -359,13 +376,17 @@ func (cc *ClientConn) lbWatcher() {
} }
} }
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { // resetAddrConn creates an addrConn for addr and adds it to cc.conns.
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
// If tearDownErr is nil, errConnDrain will be used instead.
func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{ ac := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
shutdownChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing { if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
@ -383,26 +404,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
} }
} }
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called. // Track ac in cc. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock() cc.mu.Lock()
if ac.cc.conns == nil { if cc.conns == nil {
ac.cc.mu.Unlock() cc.mu.Unlock()
return ErrClientConnClosing return ErrClientConnClosing
} }
stale := ac.cc.conns[ac.addr] stale := cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac cc.conns[ac.addr] = ac
ac.cc.mu.Unlock() cc.mu.Unlock()
if stale != nil { if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to // There is an addrConn alive on ac.addr already. This could be due to
// i) stale's Close is undergoing; // 1) a buggy Balancer notifies duplicated Addresses;
// ii) a buggy Balancer notifies duplicated Addresses. // 2) goaway was received, a new ac will replace the old ac.
stale.tearDown(errConnDrain) // The old ac should be deleted from cc.conns, but the
// underlying transport should drain rather than close.
if tearDownErr == nil {
// tearDownErr is nil if resetAddrConn is called by
// 1) Dial
// 2) lbWatcher
// In both cases, the stale ac should drain, not close.
stale.tearDown(errConnDrain)
} else {
stale.tearDown(tearDownErr)
}
} }
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block. // skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait { if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
ac.tearDown(err) if err != errConnClosing {
// Tear down ac and delete it from cc.conns.
cc.mu.Lock()
delete(cc.conns, ac.addr)
cc.mu.Unlock()
ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err return err
} }
// Start to monitor the error status of transport. // Start to monitor the error status of transport.
@ -412,7 +451,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
go func() { go func() {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.tearDown(err) if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return return
} }
ac.transportMonitor() ac.transportMonitor()
@ -422,24 +464,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
addr, put, err := cc.dopts.balancer.Get(ctx, opts) var (
if err != nil { ac *addrConn
return nil, nil, toRPCErr(err) ok bool
} put func()
cc.mu.RLock() )
if cc.conns == nil { if cc.dopts.balancer == nil {
// If balancer is nil, there should be only one addrConn available.
cc.mu.RLock()
for _, ac = range cc.conns {
// Break after the first iteration to get the first addrConn.
ok = true
break
}
cc.mu.RUnlock()
} else {
var (
addr Address
err error
)
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok = cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
} }
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok { if !ok {
if put != nil { if put != nil {
put() put()
} }
return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, errConnClosing
} }
t, err := ac.wait(ctx, !opts.BlockingWait) // ac.wait should block on transient failure only if balancer is nil and RPC is non-failfast.
// - If RPC is failfast, ac.wait should not block.
// - If balancer is not nil, ac.wait should return errConnClosing on transient failure
// so that non-failfast RPCs will try to get a new transport instead of waiting on ac.
t, err := ac.wait(ctx, cc.dopts.balancer == nil && opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
@ -451,6 +517,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
// Close tears down the ClientConn and all underlying connections. // Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
cc.cancel()
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
@ -459,7 +527,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns conns := cc.conns
cc.conns = nil cc.conns = nil
cc.mu.Unlock() cc.mu.Unlock()
cc.dopts.balancer.Close() if cc.dopts.balancer != nil {
cc.dopts.balancer.Close()
}
for _, ac := range conns { for _, ac := range conns {
ac.tearDown(ErrClientConnClosing) ac.tearDown(ErrClientConnClosing)
} }
@ -468,11 +538,13 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address. // addrConn is a network connection to a given address.
type addrConn struct { type addrConn struct {
cc *ClientConn ctx context.Context
addr Address cancel context.CancelFunc
dopts dialOptions
shutdownChan chan struct{} cc *ClientConn
events trace.EventLog addr Address
dopts dialOptions
events trace.EventLog
mu sync.Mutex mu sync.Mutex
state ConnectivityState state ConnectivityState
@ -482,6 +554,9 @@ type addrConn struct {
// due to timeout. // due to timeout.
ready chan struct{} ready chan struct{}
transport transport.ClientTransport transport transport.ClientTransport
// The reason this addrConn is torn down.
tearDownErr error
} }
// printf records an event in ac's event log, unless ac has been closed. // printf records an event in ac's event log, unless ac has been closed.
@ -537,8 +612,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
} }
func (ac *addrConn) resetTransport(closeTransport bool) error { func (ac *addrConn) resetTransport(closeTransport bool) error {
var retries int for retries := 0; ; retries++ {
for {
ac.mu.Lock() ac.mu.Lock()
ac.printf("connecting") ac.printf("connecting")
if ac.state == Shutdown { if ac.state == Shutdown {
@ -558,13 +632,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close() t.Close()
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime timeout := minConnectTimeout
if sleepTime < minConnectTimeout { if timeout < sleepTime {
ac.dopts.copts.Timeout = minConnectTimeout timeout = sleepTime
} }
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now() connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
if err != nil { if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err
}
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
@ -579,17 +660,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.ready = nil ac.ready = nil
} }
ac.mu.Unlock() ac.mu.Unlock()
sleepTime -= time.Since(connectTime)
if sleepTime < 0 {
sleepTime = 0
}
closeTransport = false closeTransport = false
select { select {
case <-time.After(sleepTime): case <-time.After(sleepTime - time.Since(connectTime)):
case <-ac.shutdownChan: case <-ac.ctx.Done():
return ac.ctx.Err()
} }
retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue continue
} }
ac.mu.Lock() ac.mu.Lock()
@ -607,7 +683,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
ac.down = ac.cc.dopts.balancer.Up(ac.addr) if ac.cc.dopts.balancer != nil {
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
}
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
@ -621,14 +699,42 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.ctx.Done():
select {
case <-t.Error():
t.Close()
default:
}
return
case <-t.GoAway():
// If GoAway happens without any network I/O error, ac is closed without shutting down the
// underlying transport (the transport will be closed when all the pending RPCs finished or
// failed.).
// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
// are closed.
// In both cases, a new ac is created.
select {
case <-t.Error():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
default:
ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
}
return return
case <-t.Error(): case <-t.Error():
select {
case <-ac.ctx.Done():
t.Close()
return
case <-t.GoAway():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
return
default:
}
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac has been shutdown.
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -640,6 +746,10 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err) ac.printf("transport exiting: %v", err)
ac.mu.Unlock() ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return return
} }
} }
@ -647,21 +757,22 @@ func (ac *addrConn) transportMonitor() {
} }
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
// iv) transport is in TransientFailure and the RPC is fail-fast. // iv) transport is in TransientFailure and blocking is false.
func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) { func (ac *addrConn) wait(ctx context.Context, blocking bool) (transport.ClientTransport, error) {
for { for {
ac.mu.Lock() ac.mu.Lock()
switch { switch {
case ac.state == Shutdown: case ac.state == Shutdown:
err := ac.tearDownErr
ac.mu.Unlock() ac.mu.Unlock()
return nil, errConnClosing return nil, err
case ac.state == Ready: case ac.state == Ready:
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
case ac.state == TransientFailure && failFast: case ac.state == TransientFailure && !blocking:
ac.mu.Unlock() ac.mu.Unlock()
return nil, Errorf(codes.Unavailable, "grpc: RPC failed fast due to transport failure") return nil, errConnClosing
default: default:
ready := ac.ready ready := ac.ready
if ready == nil { if ready == nil {
@ -683,24 +794,28 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a // some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop. // tight loop.
// tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) { func (ac *addrConn) tearDown(err error) {
ac.cancel()
ac.mu.Lock() ac.mu.Lock()
defer func() { defer ac.mu.Unlock()
ac.mu.Unlock()
ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
}
ac.state = Shutdown
if ac.down != nil { if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err)) ac.down(downErrorf(false, false, "%v", err))
ac.down = nil ac.down = nil
} }
if err == errConnDrain && ac.transport != nil {
// GracefulClose(...) may be executed multiple times when
// i) receiving multiple GoAway frames from the server; or
// ii) there are concurrent name resolver/Balancer triggered
// address removal and GoAway.
ac.transport.GracefulClose()
}
if ac.state == Shutdown {
return
}
ac.state = Shutdown
ac.tearDownErr = err
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
if ac.events != nil { if ac.events != nil {
ac.events.Finish() ac.events.Finish()
@ -710,15 +825,8 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil && err != errConnDrain {
if err == errConnDrain { ac.transport.Close()
ac.transport.GracefulClose()
} else {
ac.transport.Close()
}
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)
} }
return return
} }

View File

@ -44,7 +44,6 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -93,11 +92,12 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated // authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection. // connection and the corresponding auth information about the connection.
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) // Implementations must use the provided context to implement timely cancellation.
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns // ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about // the authenticated connection and the corresponding auth information about
// the connection. // the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials. // Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo Info() ProtocolInfo
} }
@ -136,42 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true return true
} }
type timeoutError struct{} func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
// use local cfg to avoid clobbering ServerName if using multiple endpoints // use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := *c.config cfg := cloneTLSConfig(c.config)
if c.config.ServerName == "" { if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":") colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 { if colonPos == -1 {
colonPos = len(addr) colonPos = len(addr)
} }
cfg.ServerName = addr[:colonPos] cfg.ServerName = addr[:colonPos]
} }
conn := tls.Client(rawConn, &cfg) conn := tls.Client(rawConn, cfg)
if timeout == 0 { errChannel := make(chan error, 1)
err = conn.Handshake() go func() {
} else { errChannel <- conn.Handshake()
go func() { }()
errChannel <- conn.Handshake() select {
}() case err := <-errChannel:
err = <-errChannel if err != nil {
} return nil, nil, err
if err != nil { }
rawConn.Close() case <-ctx.Done():
return nil, nil, err return nil, nil, ctx.Err()
} }
// TODO(zhaoq): Omit the auth info for client now. It is more for // TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else. // information than anything else.
@ -181,7 +167,6 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, c.config) conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, nil, err return nil, nil, err
} }
return conn, TLSInfo{conn.ConnectionState()}, nil return conn, TLSInfo{conn.ConnectionState()}, nil
@ -189,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{c} tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr tc.config.NextProtos = alpnProtoStr
return tc return tc
} }
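A hedged client-side sketch of what the config cloning enables, assuming crypto/tls, crypto/x509, grpc, and grpc/credentials imports; the names and certificate pool are placeholders.

func dialWithServerName(endpoint, serverName string, pool *x509.CertPool) (*grpc.ClientConn, error) {
	// Reusing one TransportCredentials across endpoints is safe because
	// ClientHandshake clones the config instead of mutating it in place.
	creds := credentials.NewTLS(&tls.Config{ServerName: serverName, RootCAs: pool})
	return grpc.Dial(endpoint, grpc.WithTransportCredentials(creds))
}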

View File

@ -0,0 +1,76 @@
// +build go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}
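An assumed in-package sketch of how the clone is meant to be used; the helper name is hypothetical.

func withServerName(base *tls.Config, name string) *tls.Config {
	cfg := cloneTLSConfig(base) // copies exported fields only, never the unexported sync.Once
	cfg.ServerName = name
	return cfg
}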

View File

@ -0,0 +1,74 @@
// +build !go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the // DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v. // encoded data in k, v.
// If k is a binary header and v contains a comma, v is split on commas before decoding, // If k is a binary header and v contains a comma, v is split on commas before decoding,
// and the decoded values are joined with commas before being returned. // and the decoded values are joined with commas before being returned.
func DecodeKeyValue(k, v string) (string, string, error) { func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) { if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil return k, v, nil
} }
val, err := base64.StdEncoding.DecodeString(v) vvs := strings.Split(v, ",")
if err != nil { for i, vv := range vvs {
return "", "", err val, err := base64.StdEncoding.DecodeString(vv)
if err != nil {
return "", "", err
}
vvs[i] = string(val)
} }
return k, string(val), nil return k, strings.Join(vvs, ","), nil
} }
// MD is a mapping from metadata keys to values. Users should use the following // MD is a mapping from metadata keys to values. Users should use the following
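A hedged example of the new comma handling for "-bin" keys, assuming encoding/base64 and grpc/metadata imports; the key name is a placeholder.

v := base64.StdEncoding.EncodeToString([]byte("foo")) + "," +
	base64.StdEncoding.EncodeToString([]byte("bar"))
k, dv, err := metadata.DecodeKeyValue("token-bin", v)
// k == "token-bin", dv == "foo,bar", err == nil: each comma-separated part is
// base64-decoded and the parts are re-joined with commas.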

View File

@ -227,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying io.Reader must not return an incompatible
// error. // error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil { if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err return 0, nil, err
} }
@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 { if length == 0 {
return pf, nil, nil return pf, nil, nil
} }
if length > uint32(maxMsgSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message: // of making it for each message:
msg = make([]byte, int(length)) msg = make([]byte, int(length))
@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg(maxMsgSize)
if err != nil { if err != nil {
return err return err
} }
@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade { if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d)) d, err = dc.Do(bytes.NewReader(d))
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
if len(d) > maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
} }
return nil return nil
} }

View File

@ -89,9 +89,13 @@ type service struct {
type Server struct { type Server struct {
opts options opts options
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
conns map[io.Closer]bool conns map[io.Closer]bool
drain bool
// A CondVar to let GracefulStop() block until all the pending RPCs are finished
// and all the transports go away.
cv *sync.Cond
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
} }
@ -101,12 +105,15 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
maxMsgSize int
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound message. // RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
} }
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound messages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.maxMsgSize = m
}
}
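A brief usage sketch; the limit shown is illustrative, not taken from this diff.

func newLargeMessageServer() *grpc.Server {
	// Raise the inbound message cap above the 4MB default (defaultMaxMsgSize).
	return grpc.NewServer(grpc.MaxMsgSize(16 * 1024 * 1024))
}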
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options var opts options
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool), conns: make(map[io.Closer]bool),
m: make(map[string]*service), m: make(map[string]*service),
} }
s.cv = sync.NewCond(&s.mu)
if EnableTracing { if EnableTracing {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -264,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo. // GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>. // Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo { func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]*ServiceInfo) ret := make(map[string]ServiceInfo)
for n, srv := range s.m { for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md { for m := range srv.md {
@ -283,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
}) })
} }
ret[n] = &ServiceInfo{ ret[n] = ServiceInfo{
Methods: methods, Methods: methods,
Metadata: srv.mdata, Metadata: srv.mdata,
} }
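A hedged sketch of the caller-side effect, assuming fmt is imported and s is a *grpc.Server.

for name, info := range s.GetServiceInfo() {
	// info is now a ServiceInfo value, not a *ServiceInfo.
	fmt.Printf("%s: %d methods\n", name, len(info.Methods))
}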
@ -468,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool { func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns == nil { if s.conns == nil || s.drain {
return false return false
} }
s.conns[c] = true s.conns[c] = true
@ -480,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns != nil { if s.conns != nil {
delete(s.conns, c) delete(s.conns, c)
s.cv.Signal()
} }
} }
@ -520,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
pf, req, err := p.recvMsg() pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -530,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case *rpcError:
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
case transport.ConnectionError: case transport.ConnectionError:
// Nothing to do here. // Nothing to do here.
case transport.StreamError: case transport.StreamError:
@ -569,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
} }
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err return err
} }
@ -628,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type()) stream.SetSendCompress(s.opts.cp.Type())
} }
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream}, p: &parser{r: stream},
codec: s.opts.codec, codec: s.opts.codec,
cp: s.opts.cp, cp: s.opts.cp,
dc: s.opts.dc, dc: s.opts.dc,
trInfo: trInfo, maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo,
} }
if ss.cp != nil { if ss.cp != nil {
ss.cbuf = new(bytes.Buffer) ss.cbuf = new(bytes.Buffer)
@ -766,14 +795,16 @@ func (s *Server) Stop() {
s.mu.Lock() s.mu.Lock()
listeners := s.lis listeners := s.lis
s.lis = nil s.lis = nil
cs := s.conns st := s.conns
s.conns = nil s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Signal()
s.mu.Unlock() s.mu.Unlock()
for lis := range listeners { for lis := range listeners {
lis.Close() lis.Close()
} }
for c := range cs { for c := range st {
c.Close() c.Close()
} }
@ -785,6 +816,32 @@ func (s *Server) Stop() {
s.mu.Unlock() s.mu.Unlock()
} }
// GracefulStop stops the gRPC server gracefully. It stops the server from accepting new
// connections and RPCs and blocks until all the pending RPCs are finished.
func (s *Server) GracefulStop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.drain == true || s.conns == nil {
return
}
s.drain = true
for lis := range s.lis {
lis.Close()
}
s.lis = nil
for c := range s.conns {
c.(transport.ServerTransport).Drain()
}
for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
}
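A hedged shutdown sketch, assuming net, os, os/signal, and syscall imports; the signal choice is illustrative.

func serveUntilTerm(s *grpc.Server, lis net.Listener) error {
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, syscall.SIGTERM)
	go func() {
		<-sigc
		// Stop accepting new connections and RPCs, then wait for in-flight RPCs.
		s.GracefulStop()
	}()
	return s.Serve(lis)
}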
func init() { func init() {
internal.TestingCloseConns = func(arg interface{}) { internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns() arg.(*Server).testingCloseConns()

View File

@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"math"
"sync" "sync"
"time" "time"
@ -84,12 +85,9 @@ type ClientStream interface {
// Header returns the header metadata received from the server if there // Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read. // is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error) Header() (metadata.MD, error)
// Trailer returns the trailer metadata from the server. It must be called // Trailer returns the trailer metadata from the server, if there is any.
// after stream.Recv() returns non-nil error (including io.EOF) for // It must only be called after stream.CloseAndRecv has returned, or
// bi-directional streaming and server streaming or stream.CloseAndRecv() // stream.Recv has returned a non-nil error (including io.EOF).
// returns for client streaming in order to receive trailer metadata if
// present. Otherwise, it could returns an empty MD even though trailer
// is present.
Trailer() metadata.MD Trailer() metadata.MD
// CloseSend closes the send direction of the stream. It closes the stream // CloseSend closes the send direction of the stream. It closes the stream
// when non-nil error is met. // when non-nil error is met.
@ -99,11 +97,10 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called // NewClientStream creates a new Stream for the client side. This is called
// by generated code. // by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var ( var (
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
err error
put func() put func()
) )
c := defaultCallInfo c := defaultCallInfo
@ -120,27 +117,24 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
cs := &clientStream{ var trInfo traceInfo
opts: opts, if EnableTracing {
c: c, trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
desc: desc, trInfo.firstLine.client = true
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
tracing: EnableTracing,
}
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer)
}
if cs.tracing {
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
cs.trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
cs.trInfo.firstLine.deadline = deadline.Sub(time.Now()) trInfo.firstLine.deadline = deadline.Sub(time.Now())
} }
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) trInfo.tr.LazyLog(&trInfo.firstLine, false)
ctx = trace.NewContext(ctx, cs.trInfo.tr) ctx = trace.NewContext(ctx, trInfo.tr)
defer func() {
if err != nil {
// Need to call tr.finish() if error is returned.
// Because tr will not be returned to caller.
trInfo.tr.LazyPrintf("RPC: [%v]", err)
trInfo.tr.SetError()
trInfo.tr.Finish()
}
}()
} }
gopts := BalancerGetOptions{ gopts := BalancerGetOptions{
BlockingWait: !c.failFast, BlockingWait: !c.failFast,
@ -168,9 +162,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
cs.finish(err)
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
continue continue
@ -179,16 +172,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
break break
} }
cs.put = put cs := &clientStream{
cs.t = t opts: opts,
cs.s = s c: c,
cs.p = &parser{r: s} desc: desc,
// Listen on ctx.Done() to detect cancellation when there is no pending codec: cc.dopts.codec,
// I/O operations on this stream. cp: cc.dopts.cp,
dc: cc.dopts.dc,
put: put,
t: t,
s: s,
p: &parser{r: s},
tracing: EnableTracing,
trInfo: trInfo,
}
if cc.dopts.cp != nil {
cs.cbuf = new(bytes.Buffer)
}
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
// when there is no pending I/O operations on this stream.
go func() { go func() {
select { select {
case <-t.Error(): case <-t.Error():
// Incur transport error, simply exit. // Incur transport error, simply exit.
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
if s.StatusCode() == codes.OK {
cs.finish(nil)
} else {
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done(): case <-s.Context().Done():
err := s.Context().Err() err := s.Context().Err()
cs.finish(err) cs.finish(err)
@ -251,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil { if err != nil {
cs.finish(err) cs.finish(err)
} }
if err == nil || err == io.EOF { if err == nil {
return
}
if err == io.EOF {
// Specialize the process for server streaming. SendMsg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc finishes early (before the stream object is created).
// TODO: It is probably better to move this into the generated code.
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
err = nil
}
return return
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
@ -272,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -291,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -326,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
} }
}() }()
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
return return nil
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err) cs.closeTransportStream(err)
@ -392,6 +422,7 @@ type serverStream struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
trInfo *traceInfo trInfo *traceInfo
@ -458,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dc, m) return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
} }

View File

@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {} func (*resetStream) item() {}
type goAway struct {
}
func (*goAway) item() {}
type flushIO struct { type flushIO struct {
} }

cmd/vendor/google.golang.org/grpc/transport/go16.go
View File

@ -0,0 +1,46 @@
// +build go1.6,!go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}

cmd/vendor/google.golang.org/grpc/transport/go17.go
View File

@ -0,0 +1,46 @@
// +build go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}
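An assumed in-package sketch showing how a caller's deadline now bounds the raw dial on either Go version; the address is illustrative.

func dialWithDeadline(addr string) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	// Go 1.7 uses net.Dialer.DialContext; Go 1.6 uses the Dialer.Cancel channel.
	return dialContext(ctx, "tcp", addr)
}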

View File

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v) to, err := decodeTimeout(v)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
} }
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header() h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" { if statusDesc != "" {
h.Set("Grpc-Message", statusDesc) h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {
@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// mapRecvMsgError returns the non-nil err into the appropriate // mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg. // error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be: // In particular, in can only be:

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"math" "math"
"net" "net"
@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{} shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller. // errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{} errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding hBuf *bytes.Buffer // the buffer for HPACK encoding
@ -97,41 +101,44 @@ type http2Client struct {
maxStreams int maxStreams int
// the per-stream outbound flow control window size set by the peer. // the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32 streamSendQuota uint32
// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
}
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
if fn != nil {
return fn(ctx, addr)
}
return dialContext(ctx, "tcp", addr)
} }
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
scheme := "http" scheme := "http"
startT := time.Now() conn, connErr := dial(opts.Dialer, ctx, addr)
timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout)
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
} }
var authInfo credentials.AuthInfo // Any further errors will close the underlying connection
if opts.TransportCredentials != nil { defer func(conn net.Conn) {
scheme = "https"
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
defer func() {
if err != nil { if err != nil {
conn.Close() conn.Close()
} }
}() }(conn)
var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
}
if connErr != nil {
// Credentials handshake error is not a temporary error (unless the error
// was the connection closing).
return nil, ConnectionErrorf(connErr == io.EOF, connErr, "transport: %v", connErr)
}
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua ua = opts.UserAgent + " " + ua
@ -147,6 +154,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
goAway: make(chan struct{}),
framer: newFramer(conn), framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
@ -168,11 +176,11 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface) n, err := t.conn.Write(clientPreface)
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
if n != len(clientPreface) { if n != len(clientPreface) {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
} }
if initialWindowSize != defaultWindowSize { if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{ err = t.framer.writeSettings(true, http2.Setting{
@ -184,13 +192,13 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
go t.controller() go t.controller()
@ -202,6 +210,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(), buf: newRecvBuffer(),
@ -216,8 +226,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself. // Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx) s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, goAway: s.goAway,
recv: s.buf,
} }
return s return s
} }
@ -271,6 +282,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
} }
if t.state == draining {
t.mu.Unlock()
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -278,7 +293,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock() t.mu.Unlock()
if checkStreamsQuota { if checkStreamsQuota {
sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -287,7 +302,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1) t.streamsQuota.add(sq - 1)
} }
} }
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller. // Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota { if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1) t.streamsQuota.add(1)
@ -295,6 +310,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err return nil, err
} }
t.mu.Lock() t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
if checkStreamsQuota {
t.streamsQuota.add(1)
}
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -329,7 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
} }
if timeout > 0 { if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
} }
for k, v := range authData { for k, v := range authData {
// Capital header names are illegal in HTTP/2. // Capital header names are illegal in HTTP/2.
@ -384,7 +408,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
if err != nil { if err != nil {
t.notifyError(err) t.notifyError(err)
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
t.writableChan <- 0 t.writableChan <- 0
@ -403,22 +427,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil { if t.streamsQuota != nil {
updateStreams = true updateStreams = true
} }
if t.state == draining && len(t.activeStreams) == 1 { delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t. // The transport is draining and s is the last live stream on t.
t.mu.Unlock() t.mu.Unlock()
t.Close() t.Close()
return return
} }
delete(t.activeStreams, s.id)
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {
t.streamsQuota.add(1) t.streamsQuota.add(1)
} }
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), the caller needs
// to call cancel on the stream to interrupt the blocking on
// other goroutines.
s.cancel()
s.mu.Lock() s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 { if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 { if n := t.fc.onRead(q); n > 0 {
@ -445,13 +464,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)
@ -475,10 +494,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { switch t.state {
case unreachable:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
// Notify the streams which were initiated after the server sent GOAWAY.
select {
case <-t.goAway:
n := t.prevGoAwayID
if n == 0 && t.nextID > 1 {
n = t.nextID - 2
}
m := t.goAwayID + 2
if m == 2 {
m = 1
}
for i := m; i <= n; i += 2 {
if s, ok := t.activeStreams[i]; ok {
close(s.goAway)
}
}
default:
}
if t.state == draining { if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
@ -504,15 +548,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
} }
return err return err
@ -544,8 +588,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame. // Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport. // Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(len(p)) t.sendQuotaPool.add(len(p))
} }
@ -578,7 +622,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked. // invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err) t.notifyError(err)
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -593,11 +637,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamReadDone { s.state = streamWriteDone
s.state = streamDone
} else {
s.state = streamWriteDone
}
} }
s.mu.Unlock() s.mu.Unlock()
return nil return nil
@ -630,7 +670,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf("%v", err)) t.notifyError(ConnectionErrorf(true, err, "%v", err))
return return
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -655,6 +695,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = err.Error() s.statusDesc = err.Error()
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@ -672,13 +713,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately. // the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock() s.mu.Lock()
if s.state == streamWriteDone { if s.state == streamDone {
s.state = streamDone s.mu.Unlock()
} else { return
s.state = streamReadDone
} }
s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers" s.statusDesc = "server closed the stream without sending trailers"
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -704,6 +746,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown s.statusCode = codes.Unknown
} }
s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -728,7 +772,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
} }
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// TODO(zhaoq): GoAwayFrame handler to be implemented t.mu.Lock()
if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
return
}
select {
case <-t.goAway:
id := t.goAwayID
// t.goAway has been closed (i.e., multiple GOAWAY frames were received).
if id < f.LastStreamID {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStreamID %d, currently recv %d", id, f.LastStreamID))
return
}
t.prevGoAwayID = id
t.goAwayID = f.LastStreamID
t.mu.Unlock()
return
default:
}
t.goAwayID = f.LastStreamID
close(t.goAway)
}
t.mu.Unlock()
} }
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@ -780,11 +849,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.trailer = state.mdata s.trailer = state.mdata
} }
s.state = streamDone
s.statusCode = state.statusCode s.statusCode = state.statusCode
s.statusDesc = state.statusDesc s.statusDesc = state.statusDesc
close(s.done)
s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -937,13 +1006,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.errorChan
} }
func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway
}
func (t *http2Client) notifyError(err error) { func (t *http2Client) notifyError(err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
// make sure t.errorChan is closed only once. // make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable { if t.state == reachable {
t.state = unreachable t.state = unreachable
close(t.errorChan) close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
} }
t.mu.Unlock()
} }

View File

@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
Val: uint32(initialWindowSize)}) Val: uint32(initialWindowSize)})
} }
if err := framer.writeSettings(true, settings...); err != nil { if err := framer.writeSettings(true, settings...); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil { if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
} }
// operateHeader takes action on the decoded headers. // operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: frame.Header().StreamID, id: frame.Header().StreamID,
@ -205,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return return
} }
if s.id%2 != 1 || s.id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
return true
}
t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
@ -212,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
} }
handle(s) handle(s)
return
} }
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
@ -231,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
if err != nil { if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
@ -257,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code}) t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue continue
} }
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
id := frame.Header().StreamID if t.operateHeaders(frame, handle) {
if id%2 != 1 || id <= t.maxStreamID {
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
t.Close() t.Close()
break break
} }
t.maxStreamID = id
t.operateHeaders(frame, handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
@ -282,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame) t.handleWindowUpdate(frame)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
break // TODO: Handle GoAway from the client appropriately.
default: default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
} }
@ -364,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Received the end of stream from the client. // Received the end of stream from the client.
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamWriteDone { s.state = streamReadDone
s.state = streamDone
} else {
s.state = streamReadDone
}
} }
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
@ -440,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
return nil return nil
@ -455,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
} }
s.headerOk = true s.headerOk = true
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -495,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
headersSent = true headersSent = true
} }
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -508,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@ -544,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Unlock() s.mu.Unlock()
if writeHeaderFrame { if writeHeaderFrame {
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -560,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeHeaders(false, p); err != nil { if err := t.framer.writeHeaders(false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
t.writableChan <- 0 t.writableChan <- 0
} }
@ -572,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
@ -604,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the // Got some quota. Try to acquire writing privilege on the
// transport. // transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
@ -634,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -679,6 +687,17 @@ func (t *http2Server) controller() {
} }
case *resetStream: case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code) t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
// The transport is closing.
return
}
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
case *ping: case *ping:
@ -724,6 +743,9 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock() t.mu.Lock()
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
defer t.Close()
}
t.mu.Unlock() t.mu.Unlock()
// In case stream sending and receiving are invoked in separate // In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be // goroutines (e.g., bi-directional streaming), cancel needs to be
@ -746,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr() return t.conn.RemoteAddr()
} }
func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{})
}

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.statusDesc = f.Value d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error
d.timeout, err = timeoutDecode(f.Value) d.timeout, err = decodeTimeout(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
} }
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string { func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue { if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n" return strconv.FormatInt(d, 10) + "n"
} }
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H" return strconv.FormatInt(div(t, time.Hour), 10) + "H"
} }
func timeoutDecode(s string) (time.Duration, error) { func decodeTimeout(s string) (time.Duration, error) {
size := len(s) size := len(s)
if size < 2 { if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s) return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil return d * time.Duration(t), nil
} }
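
As an aside, the renamed encodeTimeout/decodeTimeout pair keeps the compact unit-suffixed wire form for the grpc-timeout header. A minimal in-package round-trip sketch (mine, not from the vendored test suite); time.Hour divides evenly into every unit the encoder might pick, so the round trip should be exact:

	// Hypothetical test that would sit next to these helpers in the
	// transport package; it only exercises the two functions shown above.
	func TestTimeoutRoundTrip(t *testing.T) {
		in := time.Hour
		got, err := decodeTimeout(encodeTimeout(in))
		if err != nil {
			t.Fatal(err)
		}
		if got != in {
			t.Errorf("decodeTimeout(encodeTimeout(%v)) = %v, want %v", in, got, in)
		}
	}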
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode the status description carried in the
// "grpc-message" header field.
// It checks each byte of msg: allowable bytes pass through unchanged, while
// any other byte is percent-encoded, i.e. converted into hexadecimal
// notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
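
A small illustration of the scheme above, written as a hypothetical in-package test (expected values follow from the code shown here, not from the real grpc-go tests): printable ASCII other than '%' passes through untouched, every other byte becomes '%' plus two hex digits, and decoding reverses it.

	func TestGrpcMessageRoundTrip(t *testing.T) {
		for _, tt := range []struct{ raw, encoded string }{
			{"all printable ASCII", "all printable ASCII"},          // passes through
			{"contains a NUL \x00 byte", "contains a NUL %00 byte"}, // NUL is escaped
		} {
			if got := encodeGrpcMessage(tt.raw); got != tt.encoded {
				t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.raw, got, tt.encoded)
			}
			if back := decodeGrpcMessage(tt.encoded); back != tt.raw {
				t.Errorf("decodeGrpcMessage(%q) = %q, want %q", tt.encoded, back, tt.raw)
			}
		}
	}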
type framer struct { type framer struct {
numWriters int32 numWriters int32
reader io.Reader reader io.Reader

View File

@ -0,0 +1,51 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
var dialer net.Dialer
if deadline, ok := ctx.Deadline(); ok {
dialer.Timeout = deadline.Sub(time.Now())
}
return dialer.Dial(network, address)
}

View File

@ -44,7 +44,6 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
// recvBufferReader implements io.Reader interface to read the data from // recvBufferReader implements io.Reader interface to read the data from
// recvBuffer. // recvBuffer.
type recvBufferReader struct { type recvBufferReader struct {
ctx context.Context ctx context.Context
recv *recvBuffer goAway chan struct{}
last *bytes.Reader // Stores the remaining data in the previous calls. recv *recvBuffer
err error last *bytes.Reader // Stores the remaining data in the previous calls.
err error
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to // Read reads the next len(p) bytes from last. If last is drained, it tries to
@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case <-r.goAway:
return 0, ErrStreamDrain
case i := <-r.recv.get(): case i := <-r.recv.get():
r.recv.load() r.recv.load()
m := i.(*recvMsg) m := i.(*recvMsg)
@ -158,7 +160,7 @@ const (
streamActive streamState = iota streamActive streamState = iota
streamWriteDone // EndStream sent streamWriteDone // EndStream sent
streamReadDone // EndStream received streamReadDone // EndStream received
streamDone // sendDone and recvDone or RSTStreamFrame is sent or received. streamDone // the entire stream is finished.
) )
// Stream represents an RPC in the transport layer. // Stream represents an RPC in the transport layer.
@ -169,6 +171,10 @@ type Stream struct {
// ctx is the associated context of the stream. // ctx is the associated context of the stream.
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// done is closed when the final status arrives.
done chan struct{}
// goAway is closed when the server sent a GOAWAY signal before this stream was initiated.
goAway chan struct{}
// method records the associated RPC method of the stream. // method records the associated RPC method of the stream.
method string method string
recvCompress string recvCompress string
@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str s.sendCompress = str
} }
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// GoAway returns a channel which is closed when the server sent a GOAWAY signal
// before this stream was initiated.
func (s *Stream) GoAway() <-chan struct{} {
return s.goAway
}
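
For callers holding a Stream, the two channels separate "the RPC finished" from "the server is draining before this RPC ever started". A rough caller-side sketch (the helper name and the retry interpretation are mine, not part of the package):

	// awaitStream blocks until the stream reaches a terminal condition.
	// ErrStreamDrain means the RPC never ran on the server and can safely
	// be re-issued on a different connection.
	func awaitStream(ctx context.Context, s *transport.Stream) error {
		select {
		case <-s.Done():
			return nil // final status has arrived from the server
		case <-s.GoAway():
			return transport.ErrStreamDrain
		case <-ctx.Done():
			return transport.ContextErr(ctx.Err())
		}
	}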
// Header acquires the key-value pairs of header metadata once it // Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no // is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired. // header metadata or iii) the stream is cancelled/expired.
@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err()) return nil, ContextErr(s.ctx.Err())
case <-s.goAway:
return nil, ErrStreamDrain
case <-s.headerChan: case <-s.headerChan:
return s.header.Copy(), nil return s.header.Copy(), nil
} }
@ -335,19 +355,17 @@ type ConnectOptions struct {
// UserAgent is the application user agent. // UserAgent is the application user agent.
UserAgent string UserAgent string
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(context.Context, string) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials TransportCredentials credentials.TransportCredentials
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration
} }
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts) return newHTTP2Client(ctx, target, opts)
} }
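
With the Timeout field gone, the dial deadline now rides on the context passed to NewClientTransport, and a custom Dialer receives that same context. A hedged sketch of wiring this up from calling code (grpc normally does this internally; the deadline-to-Timeout translation mirrors the pre-go1.6 dialContext above):

	// dialTransport is illustrative only: it bounds the dial and handshake
	// with a 5-second deadline and threads that context into the Dialer.
	func dialTransport(addr string) (transport.ClientTransport, error) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel() // releases the deadline once dialing is done
		opts := transport.ConnectOptions{
			Dialer: func(ctx context.Context, addr string) (net.Conn, error) {
				d := net.Dialer{}
				if deadline, ok := ctx.Deadline(); ok {
					d.Timeout = deadline.Sub(time.Now())
				}
				return d.Dial("tcp", addr)
			},
		}
		return transport.NewClientTransport(ctx, addr, opts)
	}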
// Options provides additional hints and information for message // Options provides additional hints and information for message
@ -417,6 +435,11 @@ type ClientTransport interface {
// and create a new one) in error case. It should not return nil // and create a new one) in error case. It should not return nil
// once the transport is initiated. // once the transport is initiated.
Error() <-chan struct{} Error() <-chan struct{}
// GoAway returns a channel that is closed when the ClientTransport
// receives the draining signal from the server (e.g., GOAWAY frame in
// HTTP/2).
GoAway() <-chan struct{}
} }
// ServerTransport is the common interface for all gRPC server-side transport // ServerTransport is the common interface for all gRPC server-side transport
@ -448,6 +471,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Drain notifies the client that this ServerTransport stops accepting new RPCs.
Drain()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // StreamErrorf creates an StreamError with the specified error code and description.
@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
} }
// ConnectionErrorf creates an ConnectionError with the specified error description. // ConnectionErrorf creates an ConnectionError with the specified error description.
func ConnectionErrorf(format string, a ...interface{}) ConnectionError { func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{ return ConnectionError{
Desc: fmt.Sprintf(format, a...), Desc: fmt.Sprintf(format, a...),
temp: temp,
err: e,
} }
} }
@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
// entire connection and the retry of all the active streams. // entire connection and the retry of all the active streams.
type ConnectionError struct { type ConnectionError struct {
Desc string Desc string
temp bool
err error
} }
func (e ConnectionError) Error() string { func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc) return fmt.Sprintf("connection error: desc = %q", e.Desc)
} }
// ErrConnClosing indicates that the transport is closing. // Temporary indicates if this connection error is temporary or fatal.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"} func (e ConnectionError) Temporary() bool {
return e.temp
}
// Origin returns the original error of this connection error.
func (e ConnectionError) Origin() error {
// Never return nil error here.
// If the original error is nil, return itself.
if e.err == nil {
return e
}
return e.err
}
var (
// ErrConnClosing indicates that the transport is closing.
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
// ErrStreamDrain indicates that the stream is rejected by the server because
// the server stops accepting new RPCs.
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
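
The temporary/fatal split, together with ErrStreamDrain, is what lets higher layers decide whether an RPC is worth re-issuing. A rough sketch of that classification (the helper is mine, not part of the package):

	// shouldRetry reports whether an error surfaced by the transport layer
	// indicates the RPC can be re-issued on another connection.
	func shouldRetry(err error) bool {
		switch e := err.(type) {
		case transport.ConnectionError:
			// e.g. ErrConnClosing is temporary; a failed credentials
			// handshake is not.
			return e.Temporary()
		case transport.StreamError:
			// A drained stream never reached the server, so retrying is safe.
			return e == transport.ErrStreamDrain
		}
		return false
	}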
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.
type StreamError struct { type StreamError struct {
@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
// wait blocks until it can receive from ctx.Done, closing, or proceed. // wait blocks until it can receive from ctx.Done, closing, or proceed.
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it returns the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing. // If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil. // If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ContextErr(ctx.Err()) return 0, ContextErr(ctx.Err())
case <-done:
// User cancellation has precedence.
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
default:
}
return 0, io.EOF
case <-goAway:
return 0, ErrStreamDrain
case <-closing: case <-closing:
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed:

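One subtlety in wait is that user cancellation wins even when done is also ready; a hypothetical in-package check of that precedence (not from the real tests):

	func TestWaitPrefersContextError(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		done := make(chan struct{})
		close(done)
		// Both ctx.Done() and done are ready; wait must not report io.EOF.
		if _, err := wait(ctx, done, nil, nil, nil); err == io.EOF {
			t.Fatalf("wait returned io.EOF, want the StreamError for ctx.Err()")
		}
	}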
View File

@ -85,13 +85,7 @@ func getPeersFlagValue(c *cli.Context) []string {
} }
func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) { func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
domainstr := c.GlobalString("discovery-srv") domainstr, insecure := getDiscoveryDomain(c)
// Use an environment variable if nothing was supplied on the
// command line
if domainstr == "" {
domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
}
// If we still don't have domain discovery, return nothing // If we still don't have domain discovery, return nothing
if domainstr == "" { if domainstr == "" {
@ -103,8 +97,30 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if insecure {
return eps, err
}
// strip insecure connections
ret := []string{}
for _, ep := range eps {
if strings.HasPrefix(ep, "http://") {
fmt.Fprintf(os.Stderr, "ignoring discovered insecure endpoint %q\n", ep)
continue
}
ret = append(ret, ep)
}
return ret, err
}
return eps, err func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
domainstr = c.GlobalString("discovery-srv")
// Use an environment variable if nothing was supplied on the
// command line
if domainstr == "" {
domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
}
insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "")
return domainstr, insecure
} }
func getEndpoints(c *cli.Context) ([]string, error) { func getEndpoints(c *cli.Context) ([]string, error) {
@ -151,10 +167,15 @@ func getTransport(c *cli.Context) (*http.Transport, error) {
keyfile = os.Getenv("ETCDCTL_KEY_FILE") keyfile = os.Getenv("ETCDCTL_KEY_FILE")
} }
discoveryDomain, insecure := getDiscoveryDomain(c)
if insecure {
discoveryDomain = ""
}
tls := transport.TLSInfo{ tls := transport.TLSInfo{
CAFile: cafile, CAFile: cafile,
CertFile: certfile, CertFile: certfile,
KeyFile: keyfile, KeyFile: keyfile,
ServerName: discoveryDomain,
} }
dialTimeout := defaultDialTimeout dialTimeout := defaultDialTimeout

View File

@ -39,6 +39,7 @@ func Start() {
cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"}, cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"},
cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple`, `extended` or `json`)"}, cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple`, `extended` or `json`)"},
cli.StringFlag{Name: "discovery-srv, D", Usage: "domain name to query for SRV records describing cluster endpoints"}, cli.StringFlag{Name: "discovery-srv, D", Usage: "domain name to query for SRV records describing cluster endpoints"},
cli.BoolFlag{Name: "insecure-discovery", Usage: "accept insecure SRV records describing cluster endpoints"},
cli.StringFlag{Name: "peers, C", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"}, cli.StringFlag{Name: "peers, C", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"},
cli.StringFlag{Name: "endpoint", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"}, cli.StringFlag{Name: "endpoint", Value: "", Usage: "DEPRECATED - \"--endpoints\" should be used instead"},
cli.StringFlag{Name: "endpoints", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"http://127.0.0.1:2379,http://127.0.0.1:4001\")"}, cli.StringFlag{Name: "endpoints", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"http://127.0.0.1:2379,http://127.0.0.1:4001\")"},

View File

@ -561,6 +561,9 @@ func getPeerURLsMapAndToken(cfg *config, which string) (urlsmap types.URLsMap, t
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
if strings.Contains(clusterStr, "https://") && cfg.peerTLSInfo.CAFile == "" {
cfg.peerTLSInfo.ServerName = cfg.DnsCluster
}
urlsmap, err = types.NewURLsMap(clusterStr) urlsmap, err = types.NewURLsMap(clusterStr)
// only etcd member must belong to the discovered cluster. // only etcd member must belong to the discovered cluster.
// proxy does not need to belong to the discovered cluster. // proxy does not need to belong to the discovered cluster.

View File

@ -20,14 +20,19 @@ import (
"os" "os"
"time" "time"
"github.com/coreos/etcd/client"
"github.com/coreos/etcd/pkg/transport"
"github.com/coreos/etcd/proxy/tcpproxy" "github.com/coreos/etcd/proxy/tcpproxy"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
gatewayListenAddr string gatewayListenAddr string
gatewayEndpoints []string gatewayEndpoints []string
getewayRetryDelay time.Duration gatewayDNSCluster string
gatewayInsecureDiscovery bool
getewayRetryDelay time.Duration
gatewayCA string
) )
var ( var (
@ -61,6 +66,10 @@ func newGatewayStartCommand() *cobra.Command {
} }
cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.")
cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints") cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints")
cmd.Flags().DurationVar(&getewayRetryDelay, "retry-delay", time.Minute, "duration of delay before retrying failed endpoints") cmd.Flags().DurationVar(&getewayRetryDelay, "retry-delay", time.Minute, "duration of delay before retrying failed endpoints")
@ -68,6 +77,33 @@ func newGatewayStartCommand() *cobra.Command {
} }
func startGateway(cmd *cobra.Command, args []string) { func startGateway(cmd *cobra.Command, args []string) {
endpoints := gatewayEndpoints
if gatewayDNSCluster != "" {
eps, err := client.NewSRVDiscover().Discover(gatewayDNSCluster)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
plog.Infof("discovered the cluster %s from %s", eps, gatewayDNSCluster)
// confirm TLS connections are good
if !gatewayInsecureDiscovery {
tlsInfo := transport.TLSInfo{
TrustedCAFile: gatewayCA,
ServerName: gatewayDNSCluster,
}
plog.Infof("validating discovered endpoints %v", eps)
endpoints, err = transport.ValidateSecureEndpoints(tlsInfo, eps)
if err != nil {
plog.Warningf("%v", err)
}
plog.Infof("using discovered endpoints %v", endpoints)
}
}
if len(endpoints) == 0 {
plog.Fatalf("no endpoints found")
}
l, err := net.Listen("tcp", gatewayListenAddr) l, err := net.Listen("tcp", gatewayListenAddr)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)

View File

@ -116,10 +116,11 @@ func hasKeyPrefixAccess(sec auth.Store, r *http.Request, key string, recursive,
} }
var user *auth.User var user *auth.User
if r.Header.Get("Authorization") == "" && clientCertAuthEnabled { if r.Header.Get("Authorization") == "" {
user = userFromClientCertificate(sec, r) if clientCertAuthEnabled {
user = userFromClientCertificate(sec, r)
}
if user == nil { if user == nil {
plog.Warningf("auth: no authorization provided, checking guest access")
return hasGuestAccess(sec, r, key) return hasGuestAccess(sec, r, key)
} }
} else { } else {

View File

@ -717,6 +717,36 @@ func TestPrefixAccess(t *testing.T) {
hasKeyPrefixAccess: false, hasKeyPrefixAccess: false,
hasRecursiveAccess: false, hasRecursiveAccess: false,
}, },
{ // guest access in non-TLS mode
key: "/foo",
req: (func() *http.Request {
return mustJSONRequest(t, "GET", "somepath", "")
})(),
store: &mockAuthStore{
enabled: true,
users: map[string]*auth.User{
"root": {
User: "root",
Password: goodPassword,
Roles: []string{"root"},
},
},
roles: map[string]*auth.Role{
"guest": {
Role: "guest",
Permissions: auth.Permissions{
KV: auth.RWPermission{
Read: []string{"/foo*"},
Write: []string{"/foo*"},
},
},
},
},
},
hasRoot: false,
hasKeyPrefixAccess: true,
hasRecursiveAccess: true,
},
} }
for i, tt := range table { for i, tt := range table {

View File

@ -752,6 +752,7 @@ func NewClusterV3(t *testing.T, cfg *ClusterConfig) *ClusterV3 {
clus := &ClusterV3{ clus := &ClusterV3{
cluster: NewClusterByConfig(t, cfg), cluster: NewClusterByConfig(t, cfg),
} }
clus.Launch(t)
for _, m := range clus.Members { for _, m := range clus.Members {
client, err := NewClientV3(m) client, err := NewClientV3(m)
if err != nil { if err != nil {
@ -759,7 +760,6 @@ func NewClusterV3(t *testing.T, cfg *ClusterConfig) *ClusterV3 {
} }
clus.clients = append(clus.clients, client) clus.clients = append(clus.clients, client)
} }
clus.Launch(t)
return clus return clus
} }

View File

@ -245,6 +245,9 @@ func testKVRangeLimit(t *testing.T, f rangeFunc) {
if r.Rev != wrev { if r.Rev != wrev {
t.Errorf("#%d: rev = %d, want %d", i, r.Rev, wrev) t.Errorf("#%d: rev = %d, want %d", i, r.Rev, wrev)
} }
if r.Count != len(kvs) {
t.Errorf("#%d: count = %d, want %d", i, r.Count, len(kvs))
}
} }
} }

View File

@ -497,7 +497,7 @@ func (s *store) rangeKeys(key, end []byte, limit, rangeRev int64, countOnly bool
break break
} }
} }
return kvs, len(kvs), curRev, nil return kvs, len(revpairs), curRev, nil
} }
func (s *store) put(key, value []byte, leaseID lease.LeaseID) { func (s *store) put(key, value []byte, leaseID lease.LeaseID) {

View File

@ -64,6 +64,9 @@ type TLSInfo struct {
TrustedCAFile string TrustedCAFile string
ClientCertAuth bool ClientCertAuth bool
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
ServerName string
selfCert bool selfCert bool
// parseFunc exists to simplify testing. Typically, parseFunc // parseFunc exists to simplify testing. Typically, parseFunc
@ -164,6 +167,7 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
cfg := &tls.Config{ cfg := &tls.Config{
Certificates: []tls.Certificate{*tlsCert}, Certificates: []tls.Certificate{*tlsCert},
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
ServerName: info.ServerName,
} }
return cfg, nil return cfg, nil
} }
@ -215,7 +219,7 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
return nil, err return nil, err
} }
} else { } else {
cfg = &tls.Config{} cfg = &tls.Config{ServerName: info.ServerName}
} }
CAFiles := info.cafiles() CAFiles := info.cafiles()
@ -224,6 +228,8 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// if given a CA, trust any host with a cert signed by the CA
cfg.ServerName = ""
} }
if info.selfCert { if info.selfCert {

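The net effect for SRV discovery: with no CA file the client pins its certificate check to the discovery domain, and with a CA file it falls back to ordinary CA validation of whichever host it reaches. A minimal sketch of the no-CA case ("example.com" is a placeholder):

	// discoveryClientTLS is illustrative only: the expected server name is
	// the discovery domain rather than the individual member's hostname.
	func discoveryClientTLS() (*tls.Config, error) {
		tlsInfo := transport.TLSInfo{ServerName: "example.com"}
		// Had a CA file been given, ClientConfig would clear ServerName
		// again and trust any host presenting a cert signed by that CA.
		return tlsInfo.ClientConfig()
	}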
49  pkg/transport/tls.go  Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package transport
import (
"fmt"
"strings"
"time"
)
// ValidateSecureEndpoints scans the given endpoints against tls info, returning only those
// endpoints that could be validated as secure.
func ValidateSecureEndpoints(tlsInfo TLSInfo, eps []string) ([]string, error) {
t, err := NewTransport(tlsInfo, 5*time.Second)
if err != nil {
return nil, err
}
var errs []string
var endpoints []string
for _, ep := range eps {
if !strings.HasPrefix(ep, "https://") {
errs = append(errs, fmt.Sprintf("%q is insecure", ep))
continue
}
conn, cerr := t.Dial("tcp", ep[len("https://"):])
if cerr != nil {
errs = append(errs, fmt.Sprintf("%q failed to dial (%v)", ep, cerr))
continue
}
conn.Close()
endpoints = append(endpoints, ep)
}
if len(errs) != 0 {
err = fmt.Errorf("%s", strings.Join(errs, ","))
}
return endpoints, err
}

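Note how the gateway code above only warns on the returned error and keeps going: ValidateSecureEndpoints still returns whichever endpoints did validate, and the error merely summarizes the ones that were dropped. A short illustrative call (paths and hosts are placeholders):

	// secureEndpoints keeps only the endpoints that pass TLS validation
	// against the discovery domain's CA.
	func secureEndpoints(eps []string) []string {
		tlsInfo := transport.TLSInfo{
			TrustedCAFile: "ca.pem",      // placeholder path
			ServerName:    "example.com", // discovery domain
		}
		secure, err := transport.ValidateSecureEndpoints(tlsInfo, eps)
		if err != nil {
			log.Printf("skipping insecure or unreachable endpoints: %v", err)
		}
		return secure
	}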
View File

@ -29,7 +29,7 @@ import (
var ( var (
// MinClusterVersion is the min cluster version this etcd binary is compatible with. // MinClusterVersion is the min cluster version this etcd binary is compatible with.
MinClusterVersion = "2.3.0" MinClusterVersion = "2.3.0"
Version = "3.0.4" Version = "3.0.5"
// Git SHA Value will be set during build // Git SHA Value will be set during build
GitSHA = "Not provided (use ./build instead of go build)" GitSHA = "Not provided (use ./build instead of go build)"