control/noise: adjust implementation to match revised spec.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-07-29 11:59:40 -07:00
committed by Dave Anderson
parent 89a68a4c22
commit 0b392dbaf7
6 changed files with 300 additions and 90 deletions

View File

@ -24,9 +24,9 @@ import (
)
const (
maxPlaintextSize = 4096
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
maxMessageSize = 4096
maxCiphertextSize = maxMessageSize - headerLen
maxPlaintextSize = maxCiphertextSize - poly1305.TagSize
)
// A Conn is a secured Noise connection. It implements the net.Conn
@ -35,6 +35,7 @@ const (
// fail.
type Conn struct {
conn net.Conn
version int
peer key.Public
handshakeHash [blake2s.Size]byte
rx rxState
@ -46,7 +47,7 @@ type rxState struct {
sync.Mutex
cipher cipher.AEAD
nonce [chp.NonceSize]byte
buf [maxPacketSize]byte
buf [maxMessageSize]byte
n int // number of valid bytes in buf
next int // offset of next undecrypted packet
plaintext []byte // slice into buf of decrypted bytes
@ -57,10 +58,14 @@ type txState struct {
sync.Mutex
cipher cipher.AEAD
nonce [chp.NonceSize]byte
buf [maxPacketSize]byte
buf [maxMessageSize]byte
err error // records the first partial write error for all future calls
}
func (c *Conn) ProtocolVersion() int {
return c.version
}
// HandshakeHash returns the Noise handshake hash for the connection,
// which can be used to bind other messages to this connection
// (i.e. to ensure that the message wasn't replayed from a different
@ -84,7 +89,7 @@ func validNonce(nonce []byte) bool {
// bytes. Returns a slice of the available bytes in rxBuf, or an
// error if fewer than total bytes are available.
func (c *Conn) readNLocked(total int) ([]byte, error) {
if total > maxPacketSize {
if total > maxMessageSize {
return nil, errReadTooBig{total}
}
for {
@ -100,10 +105,20 @@ func (c *Conn) readNLocked(total int) ([]byte, error) {
}
}
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
// to the decrypted bytes. Returns an error if the cipher is exhausted
// (i.e. can no longer be used safely) or decryption fails.
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
// decryptLocked decrypts message (which is header+ciphertext)
// in-place and sets c.rx.plaintext to the decrypted bytes. Returns an
// error if the cipher is exhausted (i.e. can no longer be used
// safely) or decryption fails.
func (c *Conn) decryptLocked(msg []byte) (err error) {
if hdrVersion(msg) != c.version {
return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version)
}
if hdrType(msg) != msgTypeRecord {
return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord)
}
// length was already handled in caller to size msg.
ciphertext := msg[headerLen:]
if !validNonce(c.rx.nonce[:]) {
return errCipherExhausted{}
}
@ -124,8 +139,8 @@ func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
}
// encryptLocked encrypts plaintext into c.tx.buf (including the
// 2-byte length header) and returns a slice of the ciphertext, or an
// error if the cipher is exhausted (i.e. can no longer be used safely).
// packet header) and returns a slice of the ciphertext, or an error
// if the cipher is exhausted (i.e. can no longer be used safely).
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
if !validNonce(c.tx.nonce[:]) {
// Received 2^64-1 messages on this cipher state. Connection
@ -133,8 +148,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
return nil, errCipherExhausted{}
}
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
setHeader(c.tx.buf[:5], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize)
ret := c.tx.cipher.Seal(c.tx.buf[:5], c.tx.nonce[:], plaintext, nil)
// Safe to increment the nonce here, because we checked for nonce
// wraparound above.
@ -143,18 +158,18 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
return ret, nil
}
// wholeCiphertextLocked returns a slice of one whole Noise frame from
// c.rx.buf, if one whole ciphertext is available, and advances the
// read state to the next Noise frame in the buffer. Returns nil
// without advancing read state if there's not one whole ciphertext in
// c.rx.buf.
func (c *Conn) wholeCiphertextLocked() []byte {
// wholeMessageLocked returns a slice of one whole Noise transport
// message from c.rx.buf, if one whole message is available, and
// advances the read state to the next Noise message in the
// buffer. Returns nil without advancing read state if there isn't one
// whole message in c.rx.buf.
func (c *Conn) wholeMessageLocked() []byte {
available := c.rx.n - c.rx.next
if available < 2 {
if available < headerLen {
return nil
}
bs := c.rx.buf[c.rx.next:c.rx.n]
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
totalSize := hdrLen(bs) + headerLen
if len(bs) < totalSize {
return nil
}
@ -162,16 +177,16 @@ func (c *Conn) wholeCiphertextLocked() []byte {
return bs[:totalSize]
}
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
// and sets c.rx.plaintext to point to the decrypted
// decryptOneLocked decrypts one Noise transport message, reading from
// c.conn as needed, and sets c.rx.plaintext to point to the decrypted
// bytes. c.rx.plaintext is only valid if err == nil.
func (c *Conn) decryptOneLocked() error {
c.rx.plaintext = nil
// Fast path: do we have one whole ciphertext frame buffered
// already?
if bs := c.wholeCiphertextLocked(); bs != nil {
return c.decryptLocked(bs[2:])
if bs := c.wholeMessageLocked(); bs != nil {
return c.decryptLocked(bs)
}
if c.rx.next != 0 {
@ -183,18 +198,20 @@ func (c *Conn) decryptOneLocked() error {
c.rx.next = 0
}
bs, err := c.readNLocked(2)
bs, err := c.readNLocked(headerLen)
if err != nil {
return err
}
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
bs, err = c.readNLocked(totalLen)
// The rest of the header (besides the length field) gets verified
// in decryptLocked, not here.
messageLen := headerLen + hdrLen(bs)
bs, err = c.readNLocked(messageLen)
if err != nil {
return err
}
bs = bs[:messageLen]
c.rx.next = totalLen
bs = bs[2:totalLen]
c.rx.next = len(bs)
return c.decryptLocked(bs)
}