net/nettest: new package with net-like testing primitives
This is a lot like wiring up a local UDP socket, read and write deadlines work. The big difference is the Block feature, which lets you stop the packet flow without breaking the connection. This lets you emulate broken sockets and test timeouts actually work. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:

committed by
David Crawshaw

parent
52c0cb12fb
commit
41ac4a79d6
261
net/nettest/pipe.go
Normal file
261
net/nettest/pipe.go
Normal file
@ -0,0 +1,261 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const debugPipe = false
|
||||
|
||||
// Pipe implements an in-memory FIFO with timeouts.
|
||||
type Pipe struct {
|
||||
name string
|
||||
maxBuf int
|
||||
rCh chan struct{}
|
||||
wCh chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
blocked bool
|
||||
buf []byte
|
||||
readTimeout time.Time
|
||||
writeTimeout time.Time
|
||||
cancelReadTimer func()
|
||||
cancelWriteTimer func()
|
||||
}
|
||||
|
||||
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||
func NewPipe(name string, maxBuf int) *Pipe {
|
||||
return &Pipe{
|
||||
name: name,
|
||||
maxBuf: maxBuf,
|
||||
rCh: make(chan struct{}, 1),
|
||||
wCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTimeout = errors.New("timeout")
|
||||
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
||||
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
||||
)
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||
}()
|
||||
}
|
||||
for {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
|
||||
blocked := p.blocked
|
||||
if !closed && !timedout && len(p.buf) > 0 {
|
||||
n2 := copy(b, p.buf)
|
||||
p.buf = p.buf[n2:]
|
||||
b = b[n2:]
|
||||
n += n2
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||
}
|
||||
if timedout {
|
||||
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
||||
}
|
||||
if blocked {
|
||||
<-p.rCh
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.signalWrite()
|
||||
return n, nil
|
||||
}
|
||||
<-p.rCh
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||
}()
|
||||
}
|
||||
for {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
|
||||
blocked := p.blocked
|
||||
if !closed && !timedout {
|
||||
n2 := len(b)
|
||||
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
||||
n2 = limit
|
||||
}
|
||||
p.buf = append(p.buf, b[:n2]...)
|
||||
b = b[n2:]
|
||||
n += n2
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||
}
|
||||
if timedout {
|
||||
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
||||
}
|
||||
if blocked {
|
||||
<-p.wCh
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.signalRead()
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
<-p.wCh
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (p *Pipe) Close() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
p.closed = true
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
||||
}
|
||||
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
p.readTimeout = t
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.cancelReadTimer = cancel
|
||||
go func() {
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
p.signalRead()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.signalRead()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
p.writeTimeout = t
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.cancelWriteTimer = cancel
|
||||
go func() {
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
p.signalWrite()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) Block() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = true
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if blocked {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||||
}
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) Unblock() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = false
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if !blocked {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||||
}
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) signalRead() {
|
||||
select {
|
||||
case p.rCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pipe) signalWrite() {
|
||||
select {
|
||||
case p.wCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user