diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go index 485d829a3..39da64b55 100644 --- a/wgengine/netstack/link_endpoint.go +++ b/wgengine/netstack/link_endpoint.go @@ -16,19 +16,27 @@ ) type queue struct { - // TODO(jwhited): evaluate performance with mu as Mutex and/or alternative - // non-channel buffer. - c chan *stack.PacketBuffer - mu sync.RWMutex // mu guards closed + // TODO(jwhited): evaluate performance with a non-channel buffer. + c chan *stack.PacketBuffer + + closeOnce sync.Once + closedCh chan struct{} + + mu sync.RWMutex closed bool } func (q *queue) Close() { + q.closeOnce.Do(func() { + close(q.closedCh) + }) + q.mu.Lock() defer q.mu.Unlock() - if !q.closed { - close(q.c) + if q.closed { + return } + close(q.c) q.closed = true } @@ -51,26 +59,27 @@ func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer { } func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error { - // q holds the PacketBuffer. q.mu.RLock() defer q.mu.RUnlock() if q.closed { return &tcpip.ErrClosedForSend{} } - - wrote := false select { case q.c <- pkt.IncRef(): - wrote = true - default: - // TODO(jwhited): reconsider/count - pkt.DecRef() - } - - if wrote { return nil + case <-q.closedCh: + pkt.DecRef() + return &tcpip.ErrClosedForSend{} } - return &tcpip.ErrNoBufferSpace{} +} + +func (q *queue) Drain() int { + c := 0 + for pkt := range q.c { + pkt.DecRef() + c++ + } + return c } func (q *queue) Num() int { @@ -107,7 +116,8 @@ func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supported le := &linkEndpoint{ supportedGRO: supportedGRO, q: &queue{ - c: make(chan *stack.PacketBuffer, size), + c: make(chan *stack.PacketBuffer, size), + closedCh: make(chan struct{}), }, mtu: mtu, linkAddr: linkAddr, @@ -164,12 +174,7 @@ func (l *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { // Drain removes all outbound packets from the channel and counts them. func (l *linkEndpoint) Drain() int { - c := 0 - for pkt := l.Read(); pkt != nil; pkt = l.Read() { - pkt.DecRef() - c++ - } - return c + return l.q.Drain() } // NumQueued returns the number of packets queued for outbound.