diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 435d3342f..356db72fd 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -13,6 +13,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "flag" "fmt" "io" "io/ioutil" @@ -114,9 +115,10 @@ type Direct struct { tryingNewKey wgcfg.PrivateKey expiry *time.Time // hostinfo is mutated in-place while mu is held. - hostinfo *tailcfg.Hostinfo // always non-nil - endpoints []string - localPort uint16 // or zero to mean auto + hostinfo *tailcfg.Hostinfo // always non-nil + endpoints []string + everEndpoints bool // whether we've ever had non-empty endpoints + localPort uint16 // or zero to mean auto } type Options struct { @@ -476,6 +478,9 @@ func (c *Direct) newEndpoints(localPort uint16, endpoints []string) (changed boo c.logf("client.newEndpoints(%v, %v)", localPort, endpoints) c.localPort = localPort c.endpoints = append(c.endpoints[:0], endpoints...) + if len(endpoints) > 0 { + c.everEndpoints = true + } return true // changed } @@ -488,6 +493,13 @@ func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed boo return c.newEndpoints(localPort, endpoints) } +func inTest() bool { return flag.Lookup("test.v") != nil } + +// PollNetMap makes a /map request to download the network map, calling cb with +// each new netmap. +// +// maxPolls is how many network maps to download; common values are 1 +// or -1 (to keep a long-poll query open to the server). func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error { c.mu.Lock() persist := c.persist @@ -497,6 +509,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM backendLogID := hostinfo.BackendLogID localPort := c.localPort ep := append([]string(nil), c.endpoints...) + everEndpoints := c.everEndpoints c.mu.Unlock() if backendLogID == "" { @@ -504,7 +517,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM } allowStream := maxPolls != 1 - c.logf("PollNetMap: stream=%v :%v %v", maxPolls, localPort, ep) + c.logf("PollNetMap: stream=%v :%v ep=%v", allowStream, localPort, ep) vlogf := logger.Discard if Debug.NetMap { @@ -525,6 +538,17 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM if c.newDecompressor != nil { request.Compress = "zstd" } + // On initial startup before we know our endpoints, set the ReadOnly flag + // to tell the control server not to distribute out our (empty) endpoints to peers. + // Presumably we'll learn our endpoints in a half second and do another post + // with useful results. The first POST just gets us the DERP map which we + // need to do the STUN queries to discover our endpoints. + // TODO(bradfitz): we skip this optimization in tests, though, + // because the e2e tests are currently hyperspecific about the + // ordering of things. The e2e tests need love. + if len(ep) == 0 && !everEndpoints && !inTest() { + request.ReadOnly = true + } bodyData, err := encode(request, &serverKey, &c.machinePrivKey) if err != nil { @@ -532,16 +556,17 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM return err } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public()) t0 := time.Now() u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString()) - req, err := http.NewRequest("POST", u, bytes.NewReader(bodyData)) + + req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewReader(bodyData)) if err != nil { return err } - ctx, cancel := context.WithCancel(ctx) - defer cancel() - req = req.WithContext(ctx) res, err := c.httpc.Do(req) if err != nil {