derp: use new node key type.

Update #3206

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-10-28 15:42:50 -07:00
parent 15376f975b
commit 37c150aee1
16 changed files with 199 additions and 269 deletions

View File

@ -8,7 +8,6 @@ import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"crypto/x509"
"encoding/json"
"errors"
@ -23,20 +22,13 @@ import (
"testing"
"time"
"go4.org/mem"
"golang.org/x/time/rate"
"tailscale.com/net/nettest"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
func newPrivateKey(tb testing.TB) (k key.Private) {
tb.Helper()
if _, err := crand.Read(k[:]); err != nil {
tb.Fatal(err)
}
return
}
func TestClientInfoUnmarshal(t *testing.T) {
for i, in := range []string{
`{"Version":5,"MeshKey":"abc"}`,
@ -54,15 +46,15 @@ func TestClientInfoUnmarshal(t *testing.T) {
}
func TestSendRecv(t *testing.T) {
serverPrivateKey := newPrivateKey(t)
serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, t.Logf)
defer s.Close()
const numClients = 3
var clientPrivateKeys []key.Private
var clientKeys []key.Public
var clientPrivateKeys []key.NodePrivate
var clientKeys []key.NodePublic
for i := 0; i < numClients; i++ {
priv := newPrivateKey(t)
priv := key.NewNode()
clientPrivateKeys = append(clientPrivateKeys, priv)
clientKeys = append(clientKeys, priv.Public())
}
@ -225,7 +217,7 @@ func TestSendRecv(t *testing.T) {
}
func TestSendFreeze(t *testing.T) {
serverPrivateKey := newPrivateKey(t)
serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, t.Logf)
defer s.Close()
s.WriteTimeout = 100 * time.Millisecond
@ -238,7 +230,7 @@ func TestSendFreeze(t *testing.T) {
// Then cathy stops processing messsages.
// That should not interfere with alice talking to bob.
newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) {
newClient := func(name string, k key.NodePrivate) (c *Client, clientConn nettest.Conn) {
t.Helper()
c1, c2 := nettest.NewConn(name, 1024)
go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name)
@ -252,13 +244,13 @@ func TestSendFreeze(t *testing.T) {
return c, c2
}
aliceKey := newPrivateKey(t)
aliceKey := key.NewNode()
aliceClient, aliceConn := newClient("alice", aliceKey)
bobKey := newPrivateKey(t)
bobKey := key.NewNode()
bobClient, bobConn := newClient("bob", bobKey)
cathyKey := newPrivateKey(t)
cathyKey := key.NewNode()
cathyClient, cathyConn := newClient("cathy", cathyKey)
var (
@ -427,7 +419,7 @@ type testServer struct {
logf logger.Logf
mu sync.Mutex
pubName map[key.Public]string
pubName map[key.NodePublic]string
clients map[*testClient]bool
}
@ -437,14 +429,14 @@ func (ts *testServer) addTestClient(c *testClient) {
ts.clients[c] = true
}
func (ts *testServer) addKeyName(k key.Public, name string) {
func (ts *testServer) addKeyName(k key.NodePublic, name string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.pubName[k] = name
ts.logf("test adding named key %q for %x", name, k)
}
func (ts *testServer) keyName(k key.Public) string {
func (ts *testServer) keyName(k key.NodePublic) string {
ts.mu.Lock()
defer ts.mu.Unlock()
if name, ok := ts.pubName[k]; ok {
@ -465,7 +457,7 @@ func (ts *testServer) close(t *testing.T) error {
func newTestServer(t *testing.T) *testServer {
t.Helper()
logf := logger.WithPrefix(t.Logf, "derp-server: ")
s := NewServer(newPrivateKey(t), logf)
s := NewServer(key.NewNode(), logf)
s.SetMeshKey("mesh-key")
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -491,7 +483,7 @@ func newTestServer(t *testing.T) *testServer {
ln: ln,
logf: logf,
clients: map[*testClient]bool{},
pubName: map[key.Public]string{},
pubName: map[key.NodePublic]string{},
}
}
@ -499,18 +491,18 @@ type testClient struct {
name string
c *Client
nc net.Conn
pub key.Public
pub key.NodePublic
ts *testServer
closed bool
}
func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.Private, logger.Logf) (*Client, error)) *testClient {
func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.NodePrivate, logger.Logf) (*Client, error)) *testClient {
t.Helper()
nc, err := net.Dial("tcp", ts.ln.Addr().String())
if err != nil {
t.Fatal(err)
}
key := newPrivateKey(t)
key := key.NewNode()
ts.addKeyName(key.Public(), name)
c, err := newClient(nc, key, logger.WithPrefix(t.Logf, "client-"+name+": "))
if err != nil {
@ -528,7 +520,7 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net
}
func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) {
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
c, err := NewClient(priv, nc, brw, logf)
if err != nil {
@ -541,7 +533,7 @@ func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
}
func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) {
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key"))
if err != nil {
@ -555,9 +547,9 @@ func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient {
})
}
func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) {
t.Helper()
want := map[key.Public]bool{}
want := map[key.NodePublic]bool{}
for _, k := range peers {
want[k] = true
}
@ -569,7 +561,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
}
switch m := m.(type) {
case PeerPresentMessage:
got := key.Public(m)
got := key.NodePublic(m)
if !want[got] {
t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) {
for _, pub := range peers {
@ -587,7 +579,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
}
}
func (tc *testClient) wantGone(t *testing.T, peer key.Public) {
func (tc *testClient) wantGone(t *testing.T, peer key.NodePublic) {
t.Helper()
m, err := tc.c.recvTimeout(time.Second)
if err != nil {
@ -595,7 +587,7 @@ func (tc *testClient) wantGone(t *testing.T, peer key.Public) {
}
switch m := m.(type) {
case PeerGoneMessage:
got := key.Public(m)
got := key.NodePublic(m)
if peer != got {
t.Errorf("got gone message for %v; want gone for %v", tc.ts.keyName(got), tc.ts.keyName(peer))
}
@ -654,21 +646,24 @@ func TestWatch(t *testing.T) {
type testFwd int
func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") }
func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error {
panic("not called in tests")
}
func pubAll(b byte) (ret key.Public) {
for i := range ret {
ret[i] = b
func pubAll(b byte) (ret key.NodePublic) {
var bs [32]byte
for i := range bs {
bs[i] = b
}
return
return key.NodePublicFromRaw32(mem.B(bs[:]))
}
func TestForwarderRegistration(t *testing.T) {
s := &Server{
clients: make(map[key.Public]clientSet),
clientsMesh: map[key.Public]PacketForwarder{},
clients: make(map[key.NodePublic]clientSet),
clientsMesh: map[key.NodePublic]PacketForwarder{},
}
want := func(want map[key.Public]PacketForwarder) {
want := func(want map[key.NodePublic]PacketForwarder) {
t.Helper()
if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
@ -687,28 +682,28 @@ func TestForwarderRegistration(t *testing.T) {
s.AddPacketForwarder(u1, testFwd(1))
s.AddPacketForwarder(u2, testFwd(2))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered forwarder is no-op.
s.RemovePacketForwarder(u2, testFwd(999))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered user is no-op.
s.RemovePacketForwarder(u3, testFwd(1))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Actual removal.
s.RemovePacketForwarder(u2, testFwd(2))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(1),
})
@ -716,7 +711,7 @@ func TestForwarderRegistration(t *testing.T) {
wantCounter(&s.multiForwarderCreated, 0)
s.AddPacketForwarder(u1, testFwd(100))
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
@ -726,7 +721,7 @@ func TestForwarderRegistration(t *testing.T) {
// Removing a forwarder in a multi set that doesn't exist; does nothing.
s.RemovePacketForwarder(u1, testFwd(55))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
@ -737,7 +732,7 @@ func TestForwarderRegistration(t *testing.T) {
// from being a multiForwarder.
wantCounter(&s.multiForwarderDeleted, 0)
s.RemovePacketForwarder(u1, testFwd(1))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(100),
})
wantCounter(&s.multiForwarderDeleted, 1)
@ -750,18 +745,18 @@ func TestForwarderRegistration(t *testing.T) {
}
s.clients[u1] = singleClient{u1c}
s.RemovePacketForwarder(u1, testFwd(100))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: nil,
})
// But once that client disconnects, it should go away.
s.unregisterClient(u1c)
want(map[key.Public]PacketForwarder{})
want(map[key.NodePublic]PacketForwarder{})
// But if it already has a forwarder, it's not removed.
s.AddPacketForwarder(u1, testFwd(2))
s.unregisterClient(u1c)
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(2),
})
@ -770,17 +765,17 @@ func TestForwarderRegistration(t *testing.T) {
// from nil to the new one, not a multiForwarder.
s.clients[u1] = singleClient{u1c}
s.clientsMesh[u1] = nil
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: nil,
})
s.AddPacketForwarder(u1, testFwd(3))
want(map[key.Public]PacketForwarder{
want(map[key.NodePublic]PacketForwarder{
u1: testFwd(3),
})
}
func TestMetaCert(t *testing.T) {
priv := newPrivateKey(t)
priv := key.NewNode()
pub := priv.Public()
s := NewServer(priv, t.Logf)
@ -792,7 +787,7 @@ func TestMetaCert(t *testing.T) {
if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(ProtocolVersion) {
t.Errorf("serial = %v; want %v", cert.SerialNumber, ProtocolVersion)
}
if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%x", pub[:]); g != w {
if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%s", pub.UntypedHexString()); g != w {
t.Errorf("CommonName = %q; want %q", g, w)
}
}
@ -882,10 +877,10 @@ func TestClientSendPong(t *testing.T) {
}
func TestServerDupClients(t *testing.T) {
serverPriv := newPrivateKey(t)
serverPriv := key.NewNode()
var s *Server
clientPriv := newPrivateKey(t)
clientPriv := key.NewNode()
clientPub := clientPriv.Public()
var c1, c2, c3 *sclient
@ -1141,11 +1136,11 @@ func BenchmarkSendRecv(b *testing.B) {
}
func benchmarkSendRecvSize(b *testing.B, packetSize int) {
serverPrivateKey := newPrivateKey(b)
serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, logger.Discard)
defer s.Close()
key := newPrivateKey(b)
key := key.NewNode()
clientKey := key.Public()
ln, err := net.Listen("tcp", "127.0.0.1:0")
@ -1279,7 +1274,7 @@ func TestClientSendRateLimiting(t *testing.T) {
c.setSendRateLimiter(ServerInfoMessage{})
pkt := make([]byte, 1000)
if err := c.send(key.Public{}, pkt); err != nil {
if err := c.send(key.NodePublic{}, pkt); err != nil {
t.Fatal(err)
}
writes1, bytes1 := cw.Stats()
@ -1290,7 +1285,7 @@ func TestClientSendRateLimiting(t *testing.T) {
// Flood should all succeed.
cw.ResetStats()
for i := 0; i < 1000; i++ {
if err := c.send(key.Public{}, pkt); err != nil {
if err := c.send(key.NodePublic{}, pkt); err != nil {
t.Fatal(err)
}
}
@ -1309,7 +1304,7 @@ func TestClientSendRateLimiting(t *testing.T) {
TokenBucketBytesBurst: int(bytes1 * 2),
})
for i := 0; i < 1000; i++ {
if err := c.send(key.Public{}, pkt); err != nil {
if err := c.send(key.NodePublic{}, pkt); err != nil {
t.Fatal(err)
}
}