diff --git a/derp/derp_server.go b/derp/derp_server.go index 35d452773..f9350ab1c 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -24,6 +24,7 @@ import ( "time" "golang.org/x/crypto/nacl/box" + "golang.org/x/sync/errgroup" "tailscale.com/metrics" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -247,16 +248,35 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error } defer s.unregisterClient(c) - return c.run() + return c.run(ctx) } -func (c *sclient) run() error { - go c.sender() +// run serves the client until there's an error. +// If the client hangs up or the server is closed, run returns nil, otherwise run returns an error. +func (c *sclient) run(ctx context.Context) error { + // Launch sender, but don't return from run until sender goroutine is done. + var grp errgroup.Group + sendCtx, cancelSender := context.WithCancel(ctx) + grp.Go(func() error { return c.sendLoop(sendCtx) }) + defer func() { + cancelSender() + if err := grp.Wait(); err != nil && !c.s.isClosed() { + c.logf("sender failed: %v", err) + } + }() for { ft, fl, err := readFrameHeader(c.br) if err != nil { - return fmt.Errorf("client %x: readFrameHeader: %v", c.key, err) + if errors.Is(err, io.EOF) { + c.logf("read EOF") + return nil + } + if c.s.isClosed() { + c.logf("closing; server closed") + return nil + } + return fmt.Errorf("client %x: readFrameHeader: %w", c.key, err) } switch ft { case frameNotePreferred: @@ -518,17 +538,12 @@ func (c *sclient) setPreferred(v bool) { } } -func (c *sclient) sender() { - // If the sender shuts down unilaterally due to an error, close so - // that the receive loop unblocks and cleans up the rest. - defer c.nc.Close() - if err := c.sendLoop(); err != nil { - c.logf("sender failed: %v", err) - } -} - -func (c *sclient) sendLoop() error { +func (c *sclient) sendLoop(ctx context.Context) error { defer func() { + // If the sender shuts down unilaterally due to an error, close so + // that the receive loop unblocks and cleans up the rest. + c.nc.Close() + // Drain the send queue to count dropped packets for { select { @@ -560,7 +575,7 @@ func (c *sclient) sendLoop() error { // First, a non-blocking select (with a default) that // does as many non-flushing writes as possible. select { - case <-c.done: + case <-ctx.Done(): return nil case msg := <-c.sendQueue: werr = c.sendPacket(msg.src, msg.bs) @@ -578,7 +593,7 @@ func (c *sclient) sendLoop() error { // Then a blocking select with same: select { - case <-c.done: + case <-ctx.Done(): return nil case msg := <-c.sendQueue: werr = c.sendPacket(msg.src, msg.bs)