control/controlbase: rename from control/noise.

Updates #3488

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-12-02 18:04:48 -08:00
committed by Dave Anderson
parent 02461ea459
commit 6cd180746f
8 changed files with 8 additions and 8 deletions

359
control/controlbase/conn.go Normal file
View File

@ -0,0 +1,359 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package noise implements the base transport of the Tailscale 2021
// control protocol.
//
// The base transport implements Noise IK, instantiated with
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
package controlbase
import (
"crypto/cipher"
"encoding/binary"
"fmt"
"net"
"sync"
"time"
"golang.org/x/crypto/blake2s"
chp "golang.org/x/crypto/chacha20poly1305"
"tailscale.com/types/key"
)
const (
// maxMessageSize is the maximum size of a protocol frame on the
// wire, including header and payload.
maxMessageSize = 4096
// maxCiphertextSize is the maximum amount of ciphertext bytes
// that one protocol frame can carry, after framing.
maxCiphertextSize = maxMessageSize - 3
// maxPlaintextSize is the maximum amount of plaintext bytes that
// one protocol frame can carry, after encryption and framing.
maxPlaintextSize = maxCiphertextSize - chp.Overhead
)
// A Conn is a secured Noise connection. It implements the net.Conn
// interface, with the unusual trait that any write error (including a
// SetWriteDeadline induced i/o timeout) causes all future writes to
// fail.
type Conn struct {
conn net.Conn
version uint16
peer key.MachinePublic
handshakeHash [blake2s.Size]byte
rx rxState
tx txState
}
// rxState is all the Conn state that Read uses.
type rxState struct {
sync.Mutex
cipher cipher.AEAD
nonce nonce
buf [maxMessageSize]byte
n int // number of valid bytes in buf
next int // offset of next undecrypted packet
plaintext []byte // slice into buf of decrypted bytes
}
// txState is all the Conn state that Write uses.
type txState struct {
sync.Mutex
cipher cipher.AEAD
nonce nonce
buf [maxMessageSize]byte
err error // records the first partial write error for all future calls
}
// ProtocolVersion returns the protocol version that was used to
// establish this Conn.
func (c *Conn) ProtocolVersion() int {
return int(c.version)
}
// HandshakeHash returns the Noise handshake hash for the connection,
// which can be used to bind other messages to this connection
// (i.e. to ensure that the message wasn't replayed from a different
// connection).
func (c *Conn) HandshakeHash() [blake2s.Size]byte {
return c.handshakeHash
}
// Peer returns the peer's long-term public key.
func (c *Conn) Peer() key.MachinePublic {
return c.peer
}
// readNLocked reads into c.rx.buf until buf contains at least total
// bytes. Returns a slice of the total bytes in rxBuf, or an
// error if fewer than total bytes are available.
func (c *Conn) readNLocked(total int) ([]byte, error) {
if total > maxMessageSize {
return nil, errReadTooBig{total}
}
for {
if total <= c.rx.n {
return c.rx.buf[:total], nil
}
n, err := c.conn.Read(c.rx.buf[c.rx.n:])
c.rx.n += n
if err != nil {
return nil, err
}
}
}
// decryptLocked decrypts msg (which is header+ciphertext) in-place
// and sets c.rx.plaintext to the decrypted bytes.
func (c *Conn) decryptLocked(msg []byte) (err error) {
if msgType := msg[0]; msgType != msgTypeRecord {
return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord)
}
// We don't check the length field here, because the caller
// already did in order to figure out how big the msg slice should
// be.
ciphertext := msg[headerLen:]
if !c.rx.nonce.Valid() {
return errCipherExhausted{}
}
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
c.rx.nonce.Increment()
if err != nil {
// Once a decryption has failed, our Conn is no longer
// synchronized with our peer. Nuke the cipher state to be
// safe, so that no further decryptions are attempted. Future
// read attempts will return net.ErrClosed.
c.rx.cipher = nil
}
return err
}
// encryptLocked encrypts plaintext into c.tx.buf (including the
// packet header) and returns a slice of the ciphertext, or an error
// if the cipher is exhausted (i.e. can no longer be used safely).
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
if !c.tx.nonce.Valid() {
// Received 2^64-1 messages on this cipher state. Connection
// is no longer usable.
return nil, errCipherExhausted{}
}
c.tx.buf[0] = msgTypeRecord
binary.BigEndian.PutUint16(c.tx.buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
ret := c.tx.cipher.Seal(c.tx.buf[:headerLen], c.tx.nonce[:], plaintext, nil)
c.tx.nonce.Increment()
return ret, nil
}
// wholeMessageLocked returns a slice of one whole Noise transport
// message from c.rx.buf, if one whole message is available, and
// advances the read state to the next Noise message in the
// buffer. Returns nil without advancing read state if there isn't one
// whole message in c.rx.buf.
func (c *Conn) wholeMessageLocked() []byte {
available := c.rx.n - c.rx.next
if available < headerLen {
return nil
}
bs := c.rx.buf[c.rx.next:c.rx.n]
totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
if len(bs) < totalSize {
return nil
}
c.rx.next += totalSize
return bs[:totalSize]
}
// decryptOneLocked decrypts one Noise transport message, reading from
// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
// bytes. c.rx.plaintext is only valid if err == nil.
func (c *Conn) decryptOneLocked() error {
c.rx.plaintext = nil
// Fast path: do we have one whole ciphertext frame buffered
// already?
if bs := c.wholeMessageLocked(); bs != nil {
return c.decryptLocked(bs)
}
if c.rx.next != 0 {
// To simplify the read logic, move the remainder of the
// buffered bytes back to the head of the buffer, so we can
// grow it without worrying about wraparound.
c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
c.rx.next = 0
}
bs, err := c.readNLocked(headerLen)
if err != nil {
return err
}
// The rest of the header (besides the length field) gets verified
// in decryptLocked, not here.
messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3]))
bs, err = c.readNLocked(messageLen)
if err != nil {
return err
}
c.rx.next = len(bs)
return c.decryptLocked(bs)
}
// Read implements io.Reader.
func (c *Conn) Read(bs []byte) (int, error) {
c.rx.Lock()
defer c.rx.Unlock()
if c.rx.cipher == nil {
return 0, net.ErrClosed
}
// If no plaintext is buffered, decrypt incoming frames until we
// have some plaintext. Zero-byte Noise frames are allowed in this
// protocol, which is why we have to loop here rather than decrypt
// a single additional frame.
for len(c.rx.plaintext) == 0 {
if err := c.decryptOneLocked(); err != nil {
return 0, err
}
}
n := copy(bs, c.rx.plaintext)
c.rx.plaintext = c.rx.plaintext[n:]
return n, nil
}
// Write implements io.Writer.
func (c *Conn) Write(bs []byte) (n int, err error) {
c.tx.Lock()
defer c.tx.Unlock()
if c.tx.err != nil {
return 0, c.tx.err
}
defer func() {
if err != nil {
// All write errors are fatal for this conn, so clear the
// cipher state whenever an error happens.
c.tx.cipher = nil
}
if c.tx.err == nil {
// Only set c.tx.err if not nil so that we can return one
// error on the first failure, and a different one for
// subsequent calls. See the error handling around Write
// below for why.
c.tx.err = err
}
}()
if c.tx.cipher == nil {
return 0, net.ErrClosed
}
var sent int
for len(bs) > 0 {
toSend := bs
if len(toSend) > maxPlaintextSize {
toSend = bs[:maxPlaintextSize]
}
bs = bs[len(toSend):]
ciphertext, err := c.encryptLocked(toSend)
if err != nil {
return 0, err
}
n, err := c.conn.Write(ciphertext)
sent += n
if err != nil {
// Return the raw error on the Write that actually
// failed. For future writes, return that error wrapped in
// a desync error.
c.tx.err = errPartialWrite{err}
return sent, err
}
}
return sent, nil
}
// Close implements io.Closer.
func (c *Conn) Close() error {
closeErr := c.conn.Close() // unblocks any waiting reads or writes
// Remove references to live cipher state. Strictly speaking this
// is unnecessary, but we want to try and hand the active cipher
// state to the garbage collector promptly, to preserve perfect
// forward secrecy as much as we can.
c.rx.Lock()
c.rx.cipher = nil
c.rx.Unlock()
c.tx.Lock()
c.tx.cipher = nil
c.tx.Unlock()
return closeErr
}
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
// errCipherExhausted is the error returned when we run out of nonces
// on a cipher.
type errCipherExhausted struct{}
func (errCipherExhausted) Error() string {
return "cipher exhausted, no more nonces available for current key"
}
func (errCipherExhausted) Timeout() bool { return false }
func (errCipherExhausted) Temporary() bool { return false }
// errPartialWrite is the error returned when the cipher state has
// become unusable due to a past partial write.
type errPartialWrite struct {
err error
}
func (e errPartialWrite) Error() string {
return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
}
func (e errPartialWrite) Unwrap() error { return e.err }
func (e errPartialWrite) Temporary() bool { return false }
func (e errPartialWrite) Timeout() bool { return false }
// errReadTooBig is the error returned when the peer sent an
// unacceptably large Noise frame.
type errReadTooBig struct {
requested int
}
func (e errReadTooBig) Error() string {
return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
}
func (e errReadTooBig) Temporary() bool {
// permanent error because this error only occurs when our peer
// sends us a frame so large we're unwilling to ever decode it.
return false
}
func (e errReadTooBig) Timeout() bool { return false }
type nonce [chp.NonceSize]byte
func (n *nonce) Valid() bool {
return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
}
func (n *nonce) Increment() {
if !n.Valid() {
panic("increment of invalid nonce")
}
binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
}

View File

@ -0,0 +1,339 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package controlbase
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"testing/iotest"
chp "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/nettest"
tsnettest "tailscale.com/net/nettest"
"tailscale.com/types/key"
)
func TestMessageSize(t *testing.T) {
// This test is a regression guard against someone looking at
// maxCiphertextSize, going "huh, we could be more efficient if it
// were larger, and accidentally violating the Noise spec. Do not
// change this max value, it's a deliberate limitation of the
// cryptographic protocol we use (see Section 3 "Message Format"
// of the Noise spec).
const max = 65535
if maxCiphertextSize > max {
t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
}
}
func TestConnBasic(t *testing.T) {
client, server := pair(t)
sb := sinkReads(server)
want := "test"
if _, err := io.WriteString(client, want); err != nil {
t.Fatalf("client write failed: %v", err)
}
client.Close()
if got := sb.String(4); got != want {
t.Fatalf("wrong content received: got %q, want %q", got, want)
}
if err := sb.Error(); err != io.EOF {
t.Fatal("client close wasn't seen by server")
}
if sb.Total() != 4 {
t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total())
}
}
// bufferedWriteConn wraps a net.Conn and gives control over how
// Writes get batched out.
type bufferedWriteConn struct {
net.Conn
w *bufio.Writer
manualFlush bool
}
func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
n, err := c.w.Write(bs)
if err == nil && !c.manualFlush {
err = c.w.Flush()
}
return n, err
}
// TestFastPath exercises the Read codepath that can receive multiple
// Noise frames at once and decode each in turn without making another
// syscall.
func TestFastPath(t *testing.T) {
s1, s2 := tsnettest.NewConn("noise", 128000)
b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
client, server := pairWithConns(t, b, s2)
b.manualFlush = true
sb := sinkReads(server)
const packets = 10
s := "test"
for i := 0; i < packets; i++ {
// Many separate writes, to force separate Noise frames that
// all get buffered up and then all sent as a single slice to
// the server.
if _, err := io.WriteString(client, s); err != nil {
t.Fatalf("client write1 failed: %v", err)
}
}
if err := b.w.Flush(); err != nil {
t.Fatalf("client flush failed: %v", err)
}
client.Close()
want := strings.Repeat(s, packets)
if got := sb.String(len(want)); got != want {
t.Fatalf("wrong content received: got %q, want %q", got, want)
}
if err := sb.Error(); err != io.EOF {
t.Fatalf("client close wasn't seen by server")
}
}
// Writes things larger than a single Noise frame, to check the
// chunking on the encoder and decoder.
func TestBigData(t *testing.T) {
client, server := pair(t)
serverReads := sinkReads(server)
clientReads := sinkReads(client)
const sz = 15 * 1024 // 15KiB
clientStr := strings.Repeat("abcde", sz/5)
serverStr := strings.Repeat("fghij", sz/5*2)
if _, err := io.WriteString(client, clientStr); err != nil {
t.Fatalf("writing client>server: %v", err)
}
if _, err := io.WriteString(server, serverStr); err != nil {
t.Fatalf("writing server>client: %v", err)
}
if serverGot := serverReads.String(sz); serverGot != clientStr {
t.Error("server didn't receive what client sent")
}
if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
t.Error("client didn't receive what server sent")
}
getNonce := func(n [chp.NonceSize]byte) uint64 {
if binary.BigEndian.Uint32(n[:4]) != 0 {
panic("unexpected nonce")
}
return binary.BigEndian.Uint64(n[4:])
}
// Reach into the Conns and verify the cipher nonces advanced as
// expected.
if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
t.Error("desynchronized client tx nonce")
}
if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
t.Error("desynchronized server tx nonce")
}
if n := getNonce(client.tx.nonce); n != 4 {
t.Errorf("wrong client tx nonce, got %d want 4", n)
}
if n := getNonce(server.tx.nonce); n != 8 {
t.Errorf("wrong client tx nonce, got %d want 8", n)
}
}
// readerConn wraps a net.Conn and routes its Reads through a separate
// io.Reader.
type readerConn struct {
net.Conn
r io.Reader
}
func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }
// Check that the receiver can handle not being able to read an entire
// frame in a single syscall.
func TestDataTrickle(t *testing.T) {
s1, s2 := tsnettest.NewConn("noise", 128000)
client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
serverReads := sinkReads(server)
const sz = 10000
clientStr := strings.Repeat("abcde", sz/5)
if _, err := io.WriteString(client, clientStr); err != nil {
t.Fatalf("writing client>server: %v", err)
}
serverGot := serverReads.String(sz)
if serverGot != clientStr {
t.Error("server didn't receive what client sent")
}
}
func TestConnStd(t *testing.T) {
// You can run this test manually, and noise.Conn should pass all
// of them except for TestConn/PastTimeout,
// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
// those tests assume that write errors are recoverable, and
// they're not on our Conn due to cipher security.
t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
s1, s2 := tsnettest.NewConn("noise", 4096)
controlKey := key.NewMachine()
machineKey := key.NewMachine()
serverErr := make(chan error, 1)
go func() {
var err error
c2, err = Server(context.Background(), s2, controlKey)
serverErr <- err
}()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
if err != nil {
s1.Close()
s2.Close()
return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
}
if err := <-serverErr; err != nil {
c1.Close()
s1.Close()
s2.Close()
return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
}
return c1, c2, func() {
c1.Close()
c2.Close()
}, nil
})
}
// mkConns creates synthetic Noise Conns wrapping the given net.Conns.
// This function is for testing just the Conn transport logic without
// having to muck about with Noise handshakes.
func mkConns(s1, s2 net.Conn) (*Conn, *Conn) {
var k1, k2 [chp.KeySize]byte
if _, err := rand.Read(k1[:]); err != nil {
panic(err)
}
if _, err := rand.Read(k2[:]); err != nil {
panic(err)
}
ret1 := &Conn{
conn: s1,
tx: txState{cipher: newCHP(k1)},
rx: rxState{cipher: newCHP(k2)},
}
ret2 := &Conn{
conn: s2,
tx: txState{cipher: newCHP(k2)},
rx: rxState{cipher: newCHP(k1)},
}
return ret1, ret2
}
type readSink struct {
r io.Reader
cond *sync.Cond
sync.Mutex
bs bytes.Buffer
err error
}
func sinkReads(r io.Reader) *readSink {
ret := &readSink{
r: r,
}
ret.cond = sync.NewCond(&ret.Mutex)
go func() {
var buf [4096]byte
for {
n, err := r.Read(buf[:])
ret.Lock()
ret.bs.Write(buf[:n])
if err != nil {
ret.err = err
}
ret.cond.Broadcast()
ret.Unlock()
if err != nil {
return
}
}
}()
return ret
}
func (s *readSink) String(total int) string {
s.Lock()
defer s.Unlock()
for s.bs.Len() < total && s.err == nil {
s.cond.Wait()
}
if s.err != nil {
total = s.bs.Len()
}
return string(s.bs.Bytes()[:total])
}
func (s *readSink) Error() error {
s.Lock()
defer s.Unlock()
for s.err == nil {
s.cond.Wait()
}
return s.err
}
func (s *readSink) Total() int {
s.Lock()
defer s.Unlock()
return s.bs.Len()
}
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
var (
controlKey = key.NewMachine()
machineKey = key.NewMachine()
server *Conn
serverErr = make(chan error, 1)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, controlKey)
serverErr <- err
}()
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
if err := <-serverErr; err != nil {
t.Fatalf("server connection failed: %v", err)
}
return client, server
}
func pair(t *testing.T) (*Conn, *Conn) {
s1, s2 := tsnettest.NewConn("noise", 128000)
return pairWithConns(t, s1, s2)
}

View File

@ -0,0 +1,443 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package controlbase
import (
"context"
"crypto/cipher"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"net"
"strconv"
"time"
"go4.org/mem"
"golang.org/x/crypto/blake2s"
chp "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
"tailscale.com/types/key"
)
const (
// protocolName is the name of the specific instantiation of Noise
// that the control protocol uses. This string's value is fixed by
// the Noise spec, and shouldn't be changed unless we're updating
// the control protocol to use a different Noise instance.
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
// protocolVersion is the version of the control protocol that
// Client will use when initiating a handshake.
protocolVersion uint16 = 1
// protocolVersionPrefix is the name portion of the protocol
// name+version string that gets mixed into the handshake as a
// prologue.
//
// This mixing verifies that both clients agree that they're
// executing the control protocol at a specific version that
// matches the advertised version in the cleartext packet header.
protocolVersionPrefix = "Tailscale Control Protocol v"
invalidNonce = ^uint64(0)
)
func protocolVersionPrologue(version uint16) []byte {
ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
ret = append(ret, protocolVersionPrefix...)
return strconv.AppendUint(ret, uint64(version), 10)
}
// Client initiates a control client handshake, returning the resulting
// control connection.
//
// The context deadline, if any, covers the entire handshaking
// process. Any preexisting Conn deadline is removed.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer func() {
conn.SetDeadline(time.Time{})
}()
}
var s symmetricState
s.Initialize()
// prologue
s.MixHash(protocolVersionPrologue(protocolVersion))
// <- s
// ...
s.MixHash(controlKey.UntypedBytes())
// -> e, es, s, ss
init := mkInitiationMessage()
machineEphemeral := key.NewMachine()
machineEphemeralPub := machineEphemeral.Public()
copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
s.MixHash(machineEphemeralPub.UntypedBytes())
cipher, err := s.MixDH(machineEphemeral, controlKey)
if err != nil {
return nil, fmt.Errorf("computing es: %w", err)
}
machineKeyPub := machineKey.Public()
s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
cipher, err = s.MixDH(machineKey, controlKey)
if err != nil {
return nil, fmt.Errorf("computing ss: %w", err)
}
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
if _, err := conn.Write(init[:]); err != nil {
return nil, fmt.Errorf("writing initiation: %w", err)
}
// Read in the payload and look for errors/protocol violations from the server.
var resp responseMessage
if _, err := io.ReadFull(conn, resp.Header()); err != nil {
return nil, fmt.Errorf("reading response header: %w", err)
}
if resp.Type() != msgTypeResponse {
if resp.Type() != msgTypeError {
return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
}
msg := make([]byte, resp.Length())
if _, err := io.ReadFull(conn, msg); err != nil {
return nil, err
}
return nil, fmt.Errorf("server error: %q", msg)
}
if resp.Length() != len(resp.Payload()) {
return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
}
if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
return nil, err
}
// <- e, ee, se
controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
s.MixHash(controlEphemeralPub.UntypedBytes())
if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err)
}
cipher, err = s.MixDH(machineKey, controlEphemeralPub)
if err != nil {
return nil, fmt.Errorf("computing se: %w", err)
}
if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
return nil, fmt.Errorf("decrypting payload: %w", err)
}
c1, c2, err := s.Split()
if err != nil {
return nil, fmt.Errorf("finalizing handshake: %w", err)
}
c := &Conn{
conn: conn,
version: protocolVersion,
peer: controlKey,
handshakeHash: s.h,
tx: txState{
cipher: c1,
},
rx: rxState{
cipher: c2,
},
}
return c, nil
}
// Server initiates a control server handshake, returning the resulting
// control connection.
//
// The context deadline, if any, covers the entire handshaking
// process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer func() {
conn.SetDeadline(time.Time{})
}()
}
// Deliberately does not support formatting, so that we don't echo
// attacker-controlled input back to them.
sendErr := func(msg string) error {
if len(msg) >= 1<<16 {
msg = msg[:1<<16]
}
var hdr [headerLen]byte
hdr[0] = msgTypeError
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
if _, err := conn.Write(hdr[:]); err != nil {
return fmt.Errorf("sending %q error to client: %w", msg, err)
}
if _, err := io.WriteString(conn, msg); err != nil {
return fmt.Errorf("sending %q error to client: %w", msg, err)
}
return fmt.Errorf("refused client handshake: %q", msg)
}
var s symmetricState
s.Initialize()
var init initiationMessage
if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err
}
if init.Version() != protocolVersion {
return nil, sendErr("unsupported protocol version")
}
if init.Type() != msgTypeInitiation {
return nil, sendErr("unexpected handshake message type")
}
if init.Length() != len(init.Payload()) {
return nil, sendErr("wrong handshake initiation length")
}
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
return nil, err
}
// prologue. Can only do this once we at least think the client is
// handshaking using a supported version.
s.MixHash(protocolVersionPrologue(protocolVersion))
// <- s
// ...
controlKeyPub := controlKey.Public()
s.MixHash(controlKeyPub.UntypedBytes())
// -> e, es, s, ss
machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
s.MixHash(machineEphemeralPub.UntypedBytes())
cipher, err := s.MixDH(controlKey, machineEphemeralPub)
if err != nil {
return nil, fmt.Errorf("computing es: %w", err)
}
var machineKeyBytes [32]byte
if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
return nil, fmt.Errorf("decrypting machine key: %w", err)
}
machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
cipher, err = s.MixDH(controlKey, machineKey)
if err != nil {
return nil, fmt.Errorf("computing ss: %w", err)
}
if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
return nil, fmt.Errorf("decrypting initiation tag: %w", err)
}
// <- e, ee, se
resp := mkResponseMessage()
controlEphemeral := key.NewMachine()
controlEphemeralPub := controlEphemeral.Public()
copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
s.MixHash(controlEphemeralPub.UntypedBytes())
if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err)
}
cipher, err = s.MixDH(controlEphemeral, machineKey)
if err != nil {
return nil, fmt.Errorf("computing se: %w", err)
}
s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
c1, c2, err := s.Split()
if err != nil {
return nil, fmt.Errorf("finalizing handshake: %w", err)
}
if _, err := conn.Write(resp[:]); err != nil {
return nil, err
}
c := &Conn{
conn: conn,
version: protocolVersion,
peer: machineKey,
handshakeHash: s.h,
tx: txState{
cipher: c2,
},
rx: rxState{
cipher: c1,
},
}
return c, nil
}
// symmetricState contains the state of an in-flight handshake.
type symmetricState struct {
finished bool
h [blake2s.Size]byte // hash of currently-processed handshake state
ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
}
func (s *symmetricState) checkFinished() {
if s.finished {
panic("attempted to use symmetricState after Split was called")
}
}
// Initialize sets s to the initial handshake state, prior to
// processing any handshake messages.
func (s *symmetricState) Initialize() {
s.checkFinished()
s.h = blake2s.Sum256([]byte(protocolName))
s.ck = s.h
}
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
// concatenation.
func (s *symmetricState) MixHash(data []byte) {
s.checkFinished()
h := newBLAKE2s()
h.Write(s.h[:])
h.Write(data)
h.Sum(s.h[:0])
}
// MixDH updates s.ck with the result of X25519(priv, pub) and returns
// a singleUseCHP that can be used to encrypt or decrypt handshake
// data.
//
// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
// it as a single function allows for strongly-typed arguments that
// reduce the risk of error in the caller (e.g. invoking X25519 with
// two private keys, or two public keys), and thus producing the wrong
// calculation.
func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
s.checkFinished()
keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
if err != nil {
return nil, fmt.Errorf("computing X25519: %w", err)
}
r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
if _, err := io.ReadFull(r, s.ck[:]); err != nil {
return nil, fmt.Errorf("extracting ck: %w", err)
}
var k [chp.KeySize]byte
if _, err := io.ReadFull(r, k[:]); err != nil {
return nil, fmt.Errorf("extracting k: %w", err)
}
return newSingleUseCHP(k), nil
}
// EncryptAndHash encrypts plaintext into ciphertext (which must be
// the correct size to hold the encrypted plaintext) using cipher,
// mixes the ciphertext into s.h, and returns the ciphertext.
func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
s.checkFinished()
if len(ciphertext) != len(plaintext)+chp.Overhead {
panic("ciphertext is wrong size for given plaintext")
}
ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
s.MixHash(ret)
}
// DecryptAndHash decrypts the given ciphertext into plaintext (which
// must be the correct size to hold the decrypted ciphertext) using
// cipher. If decryption is successful, it mixes the ciphertext into
// s.h.
func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
s.checkFinished()
if len(ciphertext) != len(plaintext)+chp.Overhead {
return errors.New("plaintext is wrong size for given ciphertext")
}
if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
return err
}
s.MixHash(ciphertext)
return nil
}
// Split returns two ChaCha20Poly1305 ciphers with keys derived from
// the current handshake state. Methods on s cannot be used again
// after calling Split.
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
s.finished = true
var k1, k2 [chp.KeySize]byte
r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
if _, err := io.ReadFull(r, k1[:]); err != nil {
return nil, nil, fmt.Errorf("extracting k1: %w", err)
}
if _, err := io.ReadFull(r, k2[:]); err != nil {
return nil, nil, fmt.Errorf("extracting k2: %w", err)
}
c1, err = chp.New(k1[:])
if err != nil {
return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
}
c2, err = chp.New(k2[:])
if err != nil {
return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
}
return c1, c2, nil
}
// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
// error.
func newBLAKE2s() hash.Hash {
h, err := blake2s.New256(nil)
if err != nil {
// Should never happen, errors only happen when using BLAKE2s
// in MAC mode with a key.
panic(err)
}
return h
}
// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
// panics on error.
func newCHP(key [chp.KeySize]byte) cipher.AEAD {
aead, err := chp.New(key[:])
if err != nil {
// Can only happen if we passed a key of the wrong length. The
// function signature prevents that.
panic(err)
}
return aead
}
// singleUseCHP is an instance of ChaCha20Poly1305 that can be used
// only once, either for encrypting or decrypting, but not both. The
// chosen operation is always executed with an all-zeros
// nonce. Subsequent calls to either Seal or Open panic.
type singleUseCHP struct {
c cipher.AEAD
}
func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
return &singleUseCHP{newCHP(key)}
}
func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
if c.c == nil {
panic("Attempted reuse of singleUseAEAD")
}
cipher := c.c
c.c = nil
var nonce [chp.NonceSize]byte
return cipher.Seal(dst, nonce[:], plaintext, additionalData)
}
func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
if c.c == nil {
panic("Attempted reuse of singleUseAEAD")
}
cipher := c.c
c.c = nil
var nonce [chp.NonceSize]byte
return cipher.Open(dst, nonce[:], ciphertext, additionalData)
}

View File

@ -0,0 +1,299 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package controlbase
import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"
tsnettest "tailscale.com/net/nettest"
"tailscale.com/types/key"
)
func TestHandshake(t *testing.T) {
var (
clientConn, serverConn = tsnettest.NewConn("noise", 128000)
serverKey = key.NewMachine()
clientKey = key.NewMachine()
server *Conn
serverErr = make(chan error, 1)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey)
serverErr <- err
}()
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
if err := <-serverErr; err != nil {
t.Fatalf("server connection failed: %v", err)
}
if client.HandshakeHash() != server.HandshakeHash() {
t.Fatal("client and server disagree on handshake hash")
}
if client.ProtocolVersion() != int(protocolVersion) {
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion)
}
if client.ProtocolVersion() != server.ProtocolVersion() {
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
}
if client.Peer() != serverKey.Public() {
t.Fatal("client peer key isn't serverKey")
}
if server.Peer() != clientKey.Public() {
t.Fatal("client peer key isn't serverKey")
}
}
// Check that handshaking repeatedly with the same long-term keys
// result in different handshake hashes and wire traffic.
func TestNoReuse(t *testing.T) {
var (
hashes = map[[32]byte]bool{}
clientHandshakes = map[[96]byte]bool{}
serverHandshakes = map[[48]byte]bool{}
packets = map[[32]byte]bool{}
)
for i := 0; i < 10; i++ {
var (
clientRaw, serverRaw = tsnettest.NewConn("noise", 128000)
clientBuf, serverBuf bytes.Buffer
clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)}
serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)}
serverKey = key.NewMachine()
clientKey = key.NewMachine()
server *Conn
serverErr = make(chan error, 1)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey)
serverErr <- err
}()
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
if err := <-serverErr; err != nil {
t.Fatalf("server connection failed: %v", err)
}
var clientHS [96]byte
copy(clientHS[:], serverBuf.Bytes())
if clientHandshakes[clientHS] {
t.Fatal("client handshake seen twice")
}
clientHandshakes[clientHS] = true
var serverHS [48]byte
copy(serverHS[:], clientBuf.Bytes())
if serverHandshakes[serverHS] {
t.Fatal("server handshake seen twice")
}
serverHandshakes[serverHS] = true
clientBuf.Reset()
serverBuf.Reset()
cb := sinkReads(client)
sb := sinkReads(server)
if hashes[client.HandshakeHash()] {
t.Fatalf("handshake hash %v seen twice", client.HandshakeHash())
}
hashes[client.HandshakeHash()] = true
// Sending 14 bytes turns into 32 bytes on the wire (+16 for
// the chacha20poly1305 overhead, +2 length header)
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
t.Fatalf("client>server write failed: %v", err)
}
if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil {
t.Fatalf("server>client write failed: %v", err)
}
// Wait for the bytes to be read, so we know they've traveled end to end
cb.String(14)
sb.String(14)
var clientWire, serverWire [32]byte
copy(clientWire[:], clientBuf.Bytes())
copy(serverWire[:], serverBuf.Bytes())
if packets[clientWire] {
t.Fatalf("client wire traffic seen twice")
}
packets[clientWire] = true
if packets[serverWire] {
t.Fatalf("server wire traffic seen twice")
}
packets[serverWire] = true
server.Close()
client.Close()
}
}
// tamperReader wraps a reader and mutates the Nth byte.
type tamperReader struct {
r io.Reader
n int
total int
}
func (r *tamperReader) Read(bs []byte) (int, error) {
n, err := r.r.Read(bs)
if off := r.n - r.total; off >= 0 && off < n {
bs[off] += 1
}
r.total += n
return n, err
}
func TestTampering(t *testing.T) {
// Tamper with every byte of the client initiation message.
for i := 0; i < 101; i++ {
var (
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
serverKey = key.NewMachine()
clientKey = key.NewMachine()
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey)
// If the server failed, we have to close the Conn to
// unblock the client.
if err != nil {
serverConn.Close()
}
serverErr <- err
}()
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err == nil {
t.Fatal("client connection succeeded despite tampering")
}
if err := <-serverErr; err == nil {
t.Fatalf("server connection succeeded despite tampering")
}
}
// Tamper with every byte of the server response message.
for i := 0; i < 51; i++ {
var (
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
serverKey = key.NewMachine()
clientKey = key.NewMachine()
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey)
serverErr <- err
}()
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err == nil {
t.Fatal("client connection succeeded despite tampering")
}
// The server shouldn't fail, because the tampering took place
// in its response.
if err := <-serverErr; err != nil {
t.Fatalf("server connection failed despite no tampering: %v", err)
}
}
// Tamper with every byte of the first server>client transport message.
for i := 0; i < 30; i++ {
var (
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 51 + i, 0}}
serverKey = key.NewMachine()
clientKey = key.NewMachine()
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey)
serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err
}()
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err != nil {
t.Fatalf("client handshake failed: %v", err)
}
// The server shouldn't fail, because the tampering took place
// in its response.
if err := <-serverErr; err != nil {
t.Fatalf("server handshake failed: %v", err)
}
// The client needs a timeout if the tampering is hitting the length header.
if i == 1 || i == 2 {
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
}
var bs [100]byte
n, err := client.Read(bs[:])
if err == nil {
t.Fatal("read succeeded despite tampering")
}
if n != 0 {
t.Fatal("conn yielded some bytes despite tampering")
}
}
// Tamper with every byte of the first client>server transport message.
for i := 0; i < 30; i++ {
var (
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}}
serverKey = key.NewMachine()
clientKey = key.NewMachine()
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey)
serverErr <- err
var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header.
if i == 1 || i == 2 {
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
}
n, err := server.Read(bs[:])
if n != 0 {
panic("server got bytes despite tampering")
} else {
serverErr <- err
}
}()
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
if err != nil {
t.Fatalf("client handshake failed: %v", err)
}
if err := <-serverErr; err != nil {
t.Fatalf("server handshake failed: %v", err)
}
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
t.Fatalf("client>server write failed: %v", err)
}
if err := <-serverErr; err == nil {
t.Fatal("server successfully received bytes despite tampering")
}
}
}

View File

@ -0,0 +1,257 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package controlbase
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"testing"
tsnettest "tailscale.com/net/nettest"
"tailscale.com/types/key"
)
// Can a reference Noise IK client talk to our server?
func TestInteropClient(t *testing.T) {
var (
s1, s2 = tsnettest.NewConn("noise", 128000)
controlKey = key.NewMachine()
machineKey = key.NewMachine()
serverErr = make(chan error, 2)
serverBytes = make(chan []byte, 1)
c2s = "client>server"
s2c = "server>client"
)
go func() {
server, err := Server(context.Background(), s2, controlKey)
serverErr <- err
if err != nil {
return
}
var buf [1024]byte
_, err = io.ReadFull(server, buf[:len(c2s)])
serverBytes <- buf[:len(c2s)]
if err != nil {
serverErr <- err
return
}
_, err = server.Write([]byte(s2c))
serverErr <- err
}()
gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
if err != nil {
t.Fatalf("failed client interop: %v", err)
}
if string(gotS2C) != s2c {
t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
}
if err := <-serverErr; err != nil {
t.Fatalf("server handshake failed: %v", err)
}
if err := <-serverErr; err != nil {
t.Fatalf("server read/write failed: %v", err)
}
if got := string(<-serverBytes); got != c2s {
t.Fatalf("server received %q, want %q", got, c2s)
}
}
// Can our client talk to a reference Noise IK server?
func TestInteropServer(t *testing.T) {
var (
s1, s2 = tsnettest.NewConn("noise", 128000)
controlKey = key.NewMachine()
machineKey = key.NewMachine()
clientErr = make(chan error, 2)
clientBytes = make(chan []byte, 1)
c2s = "client>server"
s2c = "server>client"
)
go func() {
client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
clientErr <- err
if err != nil {
return
}
_, err = client.Write([]byte(c2s))
if err != nil {
clientErr <- err
return
}
var buf [1024]byte
_, err = io.ReadFull(client, buf[:len(s2c)])
clientBytes <- buf[:len(s2c)]
clientErr <- err
}()
gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
if err != nil {
t.Fatalf("failed server interop: %v", err)
}
if string(gotC2S) != c2s {
t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
}
if err := <-clientErr; err != nil {
t.Fatalf("client handshake failed: %v", err)
}
if err := <-clientErr; err != nil {
t.Fatalf("client read/write failed: %v", err)
}
if got := string(<-clientBytes); got != s2c {
t.Fatalf("client received %q, want %q", got, s2c)
}
}
// noiseExplorerClient uses the Noise Explorer implementation of Noise
// IK to handshake as a Noise client on conn, transmit payload, and
// read+return a payload from the peer.
func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) {
var mk keypair
copy(mk.private_key[:], machineKey.UntypedBytes())
copy(mk.public_key[:], machineKey.Public().UntypedBytes())
var peerKey [32]byte
copy(peerKey[:], controlKey.UntypedBytes())
session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, peerKey)
_, msg1 := SendMessage(&session, nil)
var hdr [initiationHeaderLen]byte
binary.BigEndian.PutUint16(hdr[:2], protocolVersion)
hdr[2] = msgTypeInitiation
binary.BigEndian.PutUint16(hdr[3:5], 96)
if _, err := conn.Write(hdr[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg1.ne[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg1.ns); err != nil {
return nil, err
}
if _, err := conn.Write(msg1.ciphertext); err != nil {
return nil, err
}
var buf [1024]byte
if _, err := io.ReadFull(conn, buf[:51]); err != nil {
return nil, err
}
// ignore the header for this test, we're only checking the noise
// implementation.
msg2 := messagebuffer{
ciphertext: buf[35:51],
}
copy(msg2.ne[:], buf[3:35])
_, p, valid := RecvMessage(&session, &msg2)
if !valid {
return nil, errors.New("handshake failed")
}
if len(p) != 0 {
return nil, errors.New("non-empty payload")
}
_, msg3 := SendMessage(&session, payload)
hdr[0] = msgTypeRecord
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext)))
if _, err := conn.Write(hdr[:3]); err != nil {
return nil, err
}
if _, err := conn.Write(msg3.ciphertext); err != nil {
return nil, err
}
if _, err := io.ReadFull(conn, buf[:3]); err != nil {
return nil, err
}
// Ignore all of the header except the payload length
plen := int(binary.BigEndian.Uint16(buf[1:3]))
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return nil, err
}
msg4 := messagebuffer{
ciphertext: buf[:plen],
}
_, p, valid = RecvMessage(&session, &msg4)
if !valid {
return nil, errors.New("transport message decryption failed")
}
return p, nil
}
func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) {
var mk keypair
copy(mk.private_key[:], controlKey.UntypedBytes())
copy(mk.public_key[:], controlKey.Public().UntypedBytes())
session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{})
var buf [1024]byte
if _, err := io.ReadFull(conn, buf[:101]); err != nil {
return nil, err
}
// Ignore the header, we're just checking the noise implementation.
msg1 := messagebuffer{
ns: buf[37:85],
ciphertext: buf[85:101],
}
copy(msg1.ne[:], buf[5:37])
_, p, valid := RecvMessage(&session, &msg1)
if !valid {
return nil, errors.New("handshake failed")
}
if len(p) != 0 {
return nil, errors.New("non-empty payload")
}
_, msg2 := SendMessage(&session, nil)
var hdr [headerLen]byte
hdr[0] = msgTypeResponse
binary.BigEndian.PutUint16(hdr[1:3], 48)
if _, err := conn.Write(hdr[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg2.ne[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
return nil, err
}
if _, err := io.ReadFull(conn, buf[:3]); err != nil {
return nil, err
}
plen := int(binary.BigEndian.Uint16(buf[1:3]))
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return nil, err
}
msg3 := messagebuffer{
ciphertext: buf[:plen],
}
_, p, valid = RecvMessage(&session, &msg3)
if !valid {
return nil, errors.New("transport message decryption failed")
}
_, msg4 := SendMessage(&session, payload)
hdr[0] = msgTypeRecord
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext)))
if _, err := conn.Write(hdr[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg4.ciphertext); err != nil {
return nil, err
}
return p, nil
}

View File

@ -0,0 +1,88 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package controlbase
import "encoding/binary"
const (
// msgTypeInitiation frames carry a Noise IK handshake initiation message.
msgTypeInitiation = 1
// msgTypeResponse frames carry a Noise IK handshake response message.
msgTypeResponse = 2
// msgTypeError frames carry an unauthenticated human-readable
// error message.
//
// Errors reported in this message type must be treated as public
// hints only. They are not encrypted or authenticated, and so can
// be seen and tampered with on the wire.
msgTypeError = 3
// msgTypeRecord frames carry session data bytes.
msgTypeRecord = 4
// headerLen is the size of the header on all messages except msgTypeInitiation.
headerLen = 3
// initiationHeaderLen is the size of the header on all msgTypeInitiation messages.
initiationHeaderLen = 5
)
// initiationMessage is the protocol message sent from a client
// machine to a control server.
//
// 2b: protocol version
// 1b: message type (0x01)
// 2b: payload length (96)
// 5b: header (see headerLen for fields)
// 32b: client ephemeral public key (cleartext)
// 48b: client machine public key (encrypted)
// 16b: message tag (authenticates the whole message)
type initiationMessage [101]byte
func mkInitiationMessage() initiationMessage {
var ret initiationMessage
binary.BigEndian.PutUint16(ret[:2], uint16(protocolVersion))
ret[2] = msgTypeInitiation
binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
return ret
}
func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] }
func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] }
func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) }
func (m *initiationMessage) Type() byte { return m[2] }
func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) }
func (m *initiationMessage) EphemeralPub() []byte {
return m[initiationHeaderLen : initiationHeaderLen+32]
}
func (m *initiationMessage) MachinePub() []byte {
return m[initiationHeaderLen+32 : initiationHeaderLen+32+48]
}
func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] }
// responseMessage is the protocol message sent from a control server
// to a client machine.
//
// 1b: message type (0x02)
// 2b: payload length (48)
// 32b: control ephemeral public key (cleartext)
// 16b: message tag (authenticates the whole message)
type responseMessage [51]byte
func mkResponseMessage() responseMessage {
var ret responseMessage
ret[0] = msgTypeResponse
binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload())))
return ret
}
func (m *responseMessage) Header() []byte { return m[:headerLen] }
func (m *responseMessage) Payload() []byte { return m[headerLen:] }
func (m *responseMessage) Type() byte { return m[0] }
func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) }
func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] }
func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }

View File

@ -0,0 +1,475 @@
// This file contains the implementation of Noise IK from
// https://noiseexplorer.com/ . Unlike the rest of this repository,
// this file is licensed under the terms of the GNU GPL v3. See
// https://source.symbolic.software/noiseexplorer/noiseexplorer for
// more information.
//
// This file is used here to verify that Tailscale's implementation of
// Noise IK is interoperable with another implementation.
//lint:file-ignore SA4006 not our code.
/*
IK:
<- s
...
-> e, es, s, ss
<- e, ee, se
->
<-
*/
// Implementation Version: 1.0.2
/* ---------------------------------------------------------------- *
* PARAMETERS *
* ---------------------------------------------------------------- */
package controlbase
import (
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"hash"
"io"
"math"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
)
/* ---------------------------------------------------------------- *
* TYPES *
* ---------------------------------------------------------------- */
type keypair struct {
public_key [32]byte
private_key [32]byte
}
type messagebuffer struct {
ne [32]byte
ns []byte
ciphertext []byte
}
type cipherstate struct {
k [32]byte
n uint32
}
type symmetricstate struct {
cs cipherstate
ck [32]byte
h [32]byte
}
type handshakestate struct {
ss symmetricstate
s keypair
e keypair
rs [32]byte
re [32]byte
psk [32]byte
}
type noisesession struct {
hs handshakestate
h [32]byte
cs1 cipherstate
cs2 cipherstate
mc uint64
i bool
}
/* ---------------------------------------------------------------- *
* CONSTANTS *
* ---------------------------------------------------------------- */
var emptyKey = [32]byte{
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
}
var minNonce = uint32(0)
/* ---------------------------------------------------------------- *
* UTILITY FUNCTIONS *
* ---------------------------------------------------------------- */
func getPublicKey(kp *keypair) [32]byte {
return kp.public_key
}
func isEmptyKey(k [32]byte) bool {
return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1
}
func validatePublicKey(k []byte) bool {
forbiddenCurveValues := [12][]byte{
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0},
{95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87},
{236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
{237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
{238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
{205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128},
{76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215},
{217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
{218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
{219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25},
}
for _, testValue := range forbiddenCurveValues {
if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 {
panic("Invalid public key")
}
}
return true
}
/* ---------------------------------------------------------------- *
* PRIMITIVES *
* ---------------------------------------------------------------- */
func incrementNonce(n uint32) uint32 {
return n + 1
}
func dh(private_key [32]byte, public_key [32]byte) [32]byte {
var ss [32]byte
curve25519.ScalarMult(&ss, &private_key, &public_key)
return ss
}
func generateKeypair() keypair {
var public_key [32]byte
var private_key [32]byte
_, _ = rand.Read(private_key[:])
curve25519.ScalarBaseMult(&public_key, &private_key)
if validatePublicKey(public_key[:]) {
return keypair{public_key, private_key}
}
return generateKeypair()
}
func generatePublicKey(private_key [32]byte) [32]byte {
var public_key [32]byte
curve25519.ScalarBaseMult(&public_key, &private_key)
return public_key
}
func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte {
var nonce [12]byte
var ciphertext []byte
enc, _ := chacha20poly1305.New(k[:])
binary.LittleEndian.PutUint32(nonce[4:], n)
ciphertext = enc.Seal(nil, nonce[:], plaintext, ad)
return ciphertext
}
func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) {
var nonce [12]byte
var plaintext []byte
enc, err := chacha20poly1305.New(k[:])
binary.LittleEndian.PutUint32(nonce[4:], n)
plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad)
return (err == nil), ad, plaintext
}
func getHash(a []byte, b []byte) [32]byte {
return blake2s.Sum256(append(a, b...))
}
func hashProtocolName(protocolName []byte) [32]byte {
var h [32]byte
if len(protocolName) <= 32 {
copy(h[:], protocolName)
} else {
h = getHash(protocolName, []byte{})
}
return h
}
func blake2HkdfInterface() hash.Hash {
h, _ := blake2s.New256([]byte{})
return h
}
func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) {
var k1 [32]byte
var k2 [32]byte
var k3 [32]byte
output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{})
io.ReadFull(output, k1[:])
io.ReadFull(output, k2[:])
io.ReadFull(output, k3[:])
return k1, k2, k3
}
/* ---------------------------------------------------------------- *
* STATE MANAGEMENT *
* ---------------------------------------------------------------- */
/* CipherState */
func initializeKey(k [32]byte) cipherstate {
return cipherstate{k, minNonce}
}
func hasKey(cs *cipherstate) bool {
return !isEmptyKey(cs.k)
}
func setNonce(cs *cipherstate, newNonce uint32) *cipherstate {
cs.n = newNonce
return cs
}
func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) {
e := encrypt(cs.k, cs.n, ad, plaintext)
cs = setNonce(cs, incrementNonce(cs.n))
return cs, e
}
func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) {
valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext)
cs = setNonce(cs, incrementNonce(cs.n))
return cs, plaintext, valid
}
func reKey(cs *cipherstate) *cipherstate {
e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:])
copy(cs.k[:], e)
return cs
}
/* SymmetricState */
func initializeSymmetric(protocolName []byte) symmetricstate {
h := hashProtocolName(protocolName)
ck := h
cs := initializeKey(emptyKey)
return symmetricstate{cs, ck, h}
}
func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate {
ck, tempK, _ := getHkdf(ss.ck, ikm[:])
ss.cs = initializeKey(tempK)
ss.ck = ck
return ss
}
func mixHash(ss *symmetricstate, data []byte) *symmetricstate {
ss.h = getHash(ss.h[:], data)
return ss
}
func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate {
var tempH [32]byte
var tempK [32]byte
ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:])
ss = mixHash(ss, tempH[:])
ss.cs = initializeKey(tempK)
return ss
}
func getHandshakeHash(ss *symmetricstate) [32]byte {
return ss.h
}
func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) {
var ciphertext []byte
if hasKey(&ss.cs) {
_, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext)
} else {
ciphertext = plaintext
}
ss = mixHash(ss, ciphertext)
return ss, ciphertext
}
func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) {
var plaintext []byte
var valid bool
if hasKey(&ss.cs) {
_, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext)
} else {
plaintext, valid = ciphertext, true
}
ss = mixHash(ss, ciphertext)
return ss, plaintext, valid
}
func split(ss *symmetricstate) (cipherstate, cipherstate) {
tempK1, tempK2, _ := getHkdf(ss.ck, []byte{})
cs1 := initializeKey(tempK1)
cs2 := initializeKey(tempK2)
return cs1, cs2
}
/* HandshakeState */
func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
var ss symmetricstate
var e keypair
var re [32]byte
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
ss = initializeSymmetric(name)
mixHash(&ss, prologue)
mixHash(&ss, rs[:])
return handshakestate{ss, s, e, rs, re, psk}
}
func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
var ss symmetricstate
var e keypair
var re [32]byte
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
ss = initializeSymmetric(name)
mixHash(&ss, prologue)
mixHash(&ss, s.public_key[:])
return handshakestate{ss, s, e, rs, re, psk}
}
func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) {
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
hs.e = generateKeypair()
ne = hs.e.public_key
mixHash(&hs.ss, ne[:])
/* No PSK, so skipping mixKey */
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
spk := make([]byte, len(hs.s.public_key))
copy(spk[:], hs.s.public_key[:])
_, ns = encryptAndHash(&hs.ss, spk)
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
_, ciphertext = encryptAndHash(&hs.ss, payload)
messageBuffer := messagebuffer{ne, ns, ciphertext}
return hs, messageBuffer
}
func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) {
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
hs.e = generateKeypair()
ne = hs.e.public_key
mixHash(&hs.ss, ne[:])
/* No PSK, so skipping mixKey */
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
_, ciphertext = encryptAndHash(&hs.ss, payload)
messageBuffer := messagebuffer{ne, ns, ciphertext}
cs1, cs2 := split(&hs.ss)
return hs.ss.h, messageBuffer, cs1, cs2
}
func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) {
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
cs, ciphertext = encryptWithAd(cs, []byte{}, payload)
messageBuffer := messagebuffer{ne, ns, ciphertext}
return cs, messageBuffer
}
func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) {
valid1 := true
if validatePublicKey(message.ne[:]) {
hs.re = message.ne
}
mixHash(&hs.ss, hs.re[:])
/* No PSK, so skipping mixKey */
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
_, ns, valid1 := decryptAndHash(&hs.ss, message.ns)
if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) {
copy(hs.rs[:], ns)
}
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
return hs, plaintext, (valid1 && valid2)
}
func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) {
valid1 := true
if validatePublicKey(message.ne[:]) {
hs.re = message.ne
}
mixHash(&hs.ss, hs.re[:])
/* No PSK, so skipping mixKey */
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
cs1, cs2 := split(&hs.ss)
return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2
}
func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) {
/* No encrypted keys */
_, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext)
return cs, plaintext, valid2
}
/* ---------------------------------------------------------------- *
* PROCESSES *
* ---------------------------------------------------------------- */
func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession {
var session noisesession
psk := emptyKey
if initiator {
session.hs = initializeInitiator(prologue, s, rs, psk)
} else {
session.hs = initializeResponder(prologue, s, rs, psk)
}
session.i = initiator
session.mc = 0
return session
}
func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) {
var messageBuffer messagebuffer
if session.mc == 0 {
_, messageBuffer = writeMessageA(&session.hs, message)
}
if session.mc == 1 {
session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message)
session.hs = handshakestate{}
}
if session.mc > 1 {
if session.i {
_, messageBuffer = writeMessageRegular(&session.cs1, message)
} else {
_, messageBuffer = writeMessageRegular(&session.cs2, message)
}
}
session.mc = session.mc + 1
return session, messageBuffer
}
func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) {
var plaintext []byte
var valid bool
if session.mc == 0 {
_, plaintext, valid = readMessageA(&session.hs, message)
}
if session.mc == 1 {
session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message)
session.hs = handshakestate{}
}
if session.mc > 1 {
if session.i {
_, plaintext, valid = readMessageRegular(&session.cs2, message)
} else {
_, plaintext, valid = readMessageRegular(&session.cs1, message)
}
}
session.mc = session.mc + 1
return session, plaintext, valid
}
func main() {}