net/tstun: merge in wgengine/tstun.
Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:

committed by
Dave Anderson

parent
018200aeba
commit
588b70f468
387
net/tstun/wrap_test.go
Normal file
387
net/tstun/wrap_test.go
Normal file
@ -0,0 +1,387 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tstun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
func udp4(src, dst string, sport, dport uint16) []byte {
|
||||
sip, err := netaddr.ParseIP(src)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
dip, err := netaddr.ParseIP(dst)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
header := &packet.UDP4Header{
|
||||
IP4Header: packet.IP4Header{
|
||||
Src: sip,
|
||||
Dst: dip,
|
||||
IPID: 0,
|
||||
},
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
}
|
||||
return packet.Generate(header, []byte("udp_payload"))
|
||||
}
|
||||
|
||||
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range nets {
|
||||
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bits := uint8(32)
|
||||
if ip.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||
} else {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func ports(s string) filter.PortRange {
|
||||
if s == "*" {
|
||||
return filter.PortRange{First: 0, Last: 65535}
|
||||
}
|
||||
|
||||
var fs, ls string
|
||||
i := strings.IndexByte(s, '-')
|
||||
if i == -1 {
|
||||
fs = s
|
||||
ls = fs
|
||||
} else {
|
||||
fs = s[:i]
|
||||
ls = s[i+1:]
|
||||
}
|
||||
first, err := strconv.ParseInt(fs, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
last, err := strconv.ParseInt(ls, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
return filter.PortRange{First: uint16(first), Last: uint16(last)}
|
||||
}
|
||||
|
||||
func netports(netPorts ...string) (ret []filter.NetPortRange) {
|
||||
for _, s := range netPorts {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
|
||||
npr := filter.NetPortRange{
|
||||
Net: nets(s[:i])[0],
|
||||
Ports: ports(s[i+1:]),
|
||||
}
|
||||
ret = append(ret, npr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func setfilter(logf logger.Logf, tun *TUN) {
|
||||
protos := []ipproto.Proto{
|
||||
ipproto.TCP,
|
||||
ipproto.UDP,
|
||||
}
|
||||
matches := []filter.Match{
|
||||
{IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
|
||||
{IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
|
||||
}
|
||||
var sb netaddr.IPSetBuilder
|
||||
sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16"))
|
||||
tun.SetFilter(filter.New(matches, sb.IPSet(), sb.IPSet(), nil, logf))
|
||||
}
|
||||
|
||||
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
|
||||
chtun := tuntest.NewChannelTUN()
|
||||
tun := WrapTUN(logf, chtun.TUN())
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return chtun, tun
|
||||
}
|
||||
|
||||
func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
|
||||
ftun := NewFakeTUN()
|
||||
tun := WrapTUN(logf, ftun)
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return ftun.(*fakeTUN), tun
|
||||
}
|
||||
|
||||
func TestReadAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
const size = 2 // all payloads have this size
|
||||
written := []string{"w0", "w1"}
|
||||
injected := []string{"i0", "i1"}
|
||||
|
||||
go func() {
|
||||
for _, packet := range written {
|
||||
payload := []byte(packet)
|
||||
chtun.Outbound <- payload
|
||||
}
|
||||
}()
|
||||
|
||||
for _, packet := range injected {
|
||||
go func(packet string) {
|
||||
payload := []byte(packet)
|
||||
err := tun.InjectOutbound(payload)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
}(packet)
|
||||
}
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
var seen = make(map[string]bool)
|
||||
// We expect the same packets back, in no particular order.
|
||||
for i := 0; i < len(written)+len(injected); i++ {
|
||||
n, err := tun.Read(buf[:], 0)
|
||||
if err != nil {
|
||||
t.Errorf("read %d: error: %v", i, err)
|
||||
}
|
||||
if n != size {
|
||||
t.Errorf("read %d: got size %d; want %d", i, n, size)
|
||||
}
|
||||
got := string(buf[:n])
|
||||
t.Logf("read %d: got %s", i, got)
|
||||
seen[got] = true
|
||||
}
|
||||
|
||||
for _, packet := range written {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
for _, packet := range injected {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
const size = 2 // all payloads have this size
|
||||
written := []string{"w0", "w1"}
|
||||
injected := []string{"i0", "i1"}
|
||||
|
||||
go func() {
|
||||
for _, packet := range written {
|
||||
payload := []byte(packet)
|
||||
n, err := tun.Write(payload, 0)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
if n != size {
|
||||
t.Errorf("%s: got size %d; want %d", packet, n, size)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, packet := range injected {
|
||||
go func(packet string) {
|
||||
payload := []byte(packet)
|
||||
err := tun.InjectInboundCopy(payload)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
}(packet)
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
// We expect the same packets back, in no particular order.
|
||||
for i := 0; i < len(written)+len(injected); i++ {
|
||||
packet := <-chtun.Inbound
|
||||
got := string(packet)
|
||||
t.Logf("read %d: got %s", i, got)
|
||||
seen[got] = true
|
||||
}
|
||||
|
||||
for _, packet := range written {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
for _, packet := range injected {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, true)
|
||||
defer tun.Close()
|
||||
|
||||
type direction int
|
||||
|
||||
const (
|
||||
in direction = iota
|
||||
out
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dir direction
|
||||
drop bool
|
||||
data []byte
|
||||
}{
|
||||
{"junk_in", in, true, []byte("\x45not a valid IPv4 packet")},
|
||||
{"junk_out", out, true, []byte("\x45not a valid IPv4 packet")},
|
||||
{"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)},
|
||||
{"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)},
|
||||
{"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)},
|
||||
{"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)},
|
||||
{"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)},
|
||||
{"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)},
|
||||
}
|
||||
|
||||
// A reader on the other end of the TUN.
|
||||
go func() {
|
||||
var recvbuf []byte
|
||||
for {
|
||||
select {
|
||||
case <-tun.closed:
|
||||
return
|
||||
case recvbuf = <-chtun.Inbound:
|
||||
// continue
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.drop && bytes.Equal(recvbuf, tt.data) {
|
||||
t.Errorf("did not drop %s", tt.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n int
|
||||
var err error
|
||||
var filtered bool
|
||||
|
||||
if tt.dir == in {
|
||||
_, err = tun.Write(tt.data, 0)
|
||||
if err == ErrFiltered {
|
||||
filtered = true
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
chtun.Outbound <- tt.data
|
||||
n, err = tun.Read(buf[:], 0)
|
||||
// In the read direction, errors are fatal, so we return n = 0 instead.
|
||||
filtered = (n == 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("got err %v; want nil", err)
|
||||
}
|
||||
|
||||
if filtered {
|
||||
if !tt.drop {
|
||||
t.Errorf("got drop; want accept")
|
||||
}
|
||||
} else {
|
||||
if tt.drop {
|
||||
t.Errorf("got accept; want drop")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocs(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
buf := []byte{0x00}
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
_, err := ftun.Write(buf, 0)
|
||||
if err != nil {
|
||||
t.Errorf("write: error: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if allocs > 0 {
|
||||
t.Errorf("read allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
|
||||
data := udp4("1.2.3.4", "5.6.7.8", 98, 98)
|
||||
_, err := ftun.Write(data, 0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
tun.Close()
|
||||
_, err = ftun.Write(data, 0)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ftun.Write() after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
ftun, tun := newFakeTUN(b.Logf, true)
|
||||
defer tun.Close()
|
||||
|
||||
packet := udp4("5.6.7.8", "1.2.3.4", 89, 89)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := ftun.Write(packet, 0)
|
||||
if err != nil {
|
||||
b.Errorf("err = %v; want nil", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomic64Alignment(t *testing.T) {
|
||||
off := unsafe.Offsetof(TUN{}.lastActivityAtomic)
|
||||
if off%8 != 0 {
|
||||
t.Errorf("offset %v not 8-byte aligned", off)
|
||||
}
|
||||
|
||||
c := new(TUN)
|
||||
atomic.StoreInt64(&c.lastActivityAtomic, 123)
|
||||
}
|
Reference in New Issue
Block a user