clientv3: use Endpoints(), fix context creation

If overwritten, the previous context should be canceled first.

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
This commit is contained in:
Gyuho Lee
2019-08-12 18:29:47 -07:00
parent 49c6e87f74
commit 2a8d09b83b

View File

@ -274,8 +274,8 @@ func (c *Client) getToken(ctx context.Context) error {
var err error // return last error in a case of fail var err error // return last error in a case of fail
var auth *authenticator var auth *authenticator
for i := 0; i < len(c.cfg.Endpoints); i++ { eps := c.Endpoints()
ep := c.cfg.Endpoints[i] for _, ep := range eps {
// use dial options without dopts to avoid reusing the client balancer // use dial options without dopts to avoid reusing the client balancer
var dOpts []grpc.DialOption var dOpts []grpc.DialOption
_, host, _ := endpoint.ParseEndpoint(ep) _, host, _ := endpoint.ParseEndpoint(ep)
@ -519,13 +519,17 @@ func (c *Client) roundRobinQuorumBackoff(waitBetween time.Duration, jitterFracti
func (c *Client) checkVersion() (err error) { func (c *Client) checkVersion() (err error) {
var wg sync.WaitGroup var wg sync.WaitGroup
errc := make(chan error, len(c.cfg.Endpoints))
eps := c.Endpoints()
errc := make(chan error, len(eps))
ctx, cancel := context.WithCancel(c.ctx) ctx, cancel := context.WithCancel(c.ctx)
if c.cfg.DialTimeout > 0 { if c.cfg.DialTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.cfg.DialTimeout) cancel()
ctx, cancel = context.WithTimeout(c.ctx, c.cfg.DialTimeout)
} }
wg.Add(len(c.cfg.Endpoints))
for _, ep := range c.cfg.Endpoints { wg.Add(len(eps))
for _, ep := range eps {
// if cluster is current, any endpoint gives a recent version // if cluster is current, any endpoint gives a recent version
go func(e string) { go func(e string) {
defer wg.Done() defer wg.Done()
@ -537,8 +541,15 @@ func (c *Client) checkVersion() (err error) {
vs := strings.Split(resp.Version, ".") vs := strings.Split(resp.Version, ".")
maj, min := 0, 0 maj, min := 0, 0
if len(vs) >= 2 { if len(vs) >= 2 {
maj, _ = strconv.Atoi(vs[0]) var serr error
min, rerr = strconv.Atoi(vs[1]) if maj, serr = strconv.Atoi(vs[0]); serr != nil {
errc <- serr
return
}
if min, serr = strconv.Atoi(vs[1]); serr != nil {
errc <- serr
return
}
} }
if maj < 3 || (maj == 3 && min < 2) { if maj < 3 || (maj == 3 && min < 2) {
rerr = ErrOldCluster rerr = ErrOldCluster
@ -547,7 +558,7 @@ func (c *Client) checkVersion() (err error) {
}(ep) }(ep)
} }
// wait for success // wait for success
for i := 0; i < len(c.cfg.Endpoints); i++ { for range eps {
if err = <-errc; err == nil { if err = <-errc; err == nil {
break break
} }