derp: change the protocol framing to always include a length

Addresses one of crawshaw's TODOs.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2020-02-20 12:27:12 -08:00
committed by Brad Fitzpatrick
parent c47f907a27
commit f029c4c82d
3 changed files with 197 additions and 155 deletions

View File

@ -8,6 +8,7 @@ import (
"bufio"
crand "crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -54,34 +55,40 @@ func NewClient(privateKey key.Private, nc net.Conn, brw *bufio.ReadWriter, logf
}
func (c *Client) recvServerKey() error {
gotMagic, err := readUint32(c.br, 0xffffffff)
var buf [40]byte
t, flen, err := readFrame(c.br, 1<<10, buf[:])
if err == io.ErrShortBuffer {
// For future-proofing, allow server to send more in its greeting.
err = nil
}
if err != nil {
return err
}
if gotMagic != magic {
return fmt.Errorf("bad magic %x, want %x", gotMagic, magic)
}
if err := readType(c.br, typeServerKey); err != nil {
return err
}
if _, err := io.ReadFull(c.br, c.serverKey[:]); err != nil {
return err
if flen < uint32(len(buf)) || t != frameServerKey || string(buf[:len(magic)]) != magic {
return errors.New("invalid server greeting")
}
copy(c.serverKey[:], buf[len(magic):])
return nil
}
func (c *Client) recvServerInfo() (*serverInfo, error) {
if err := readType(c.br, typeServerInfo); err != nil {
fl, err := readFrameTypeHeader(c.br, frameServerInfo)
if err != nil {
return nil, err
}
var nonce [24]byte
const maxLength = nonceLen + maxInfoLen
if fl < nonceLen {
return nil, fmt.Errorf("short serverInfo frame")
}
if fl > maxLength {
return nil, fmt.Errorf("long serverInfo frame")
}
// TODO: add a read-nonce-and-box helper
var nonce [nonceLen]byte
if _, err := io.ReadFull(c.br, nonce[:]); err != nil {
return nil, fmt.Errorf("nonce: %v", err)
}
msgLen, err := readUint32(c.br, oneMB)
if err != nil {
return nil, fmt.Errorf("msglen: %v", err)
}
msgLen := fl - nonceLen
msgbox := make([]byte, msgLen)
if _, err := io.ReadFull(c.br, msgbox); err != nil {
return nil, fmt.Errorf("msgbox: %v", err)
@ -98,49 +105,43 @@ func (c *Client) recvServerInfo() (*serverInfo, error) {
}
func (c *Client) sendClientKey() error {
var nonce [24]byte
var nonce [nonceLen]byte
if _, err := crand.Read(nonce[:]); err != nil {
return err
}
msg := []byte("{}") // no clientInfo for now
msgbox := box.Seal(nil, msg, &nonce, c.serverKey.B32(), c.privateKey.B32())
if _, err := c.bw.Write(c.publicKey[:]); err != nil {
return err
}
if _, err := c.bw.Write(nonce[:]); err != nil {
return err
}
if err := putUint32(c.bw, uint32(len(msgbox))); err != nil {
return err
}
if _, err := c.bw.Write(msgbox); err != nil {
return err
}
return c.bw.Flush()
buf := make([]byte, 0, nonceLen+keyLen+len(msgbox))
buf = append(buf, c.publicKey[:]...)
buf = append(buf, nonce[:]...)
buf = append(buf, msgbox...)
return writeFrame(c.bw, frameClientInfo, buf)
}
func (c *Client) Send(dstKey key.Public, msg []byte) (err error) {
// Send sends a packet to the Tailscale node identified by dstKey.
//
// It is an error if the packet is larger than 64KB.
func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKey, pkt) }
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
defer func() {
if err != nil {
err = fmt.Errorf("derp.Send: %v", err)
if ret != nil {
ret = fmt.Errorf("derp.Send: %v", ret)
}
}()
if err := typeSendPacket.Write(c.bw); err != nil {
if len(pkt) > 64<<10 {
return fmt.Errorf("packet too big: %d", len(pkt))
}
if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil {
return err
}
if _, err := c.bw.Write(dstKey[:]); err != nil {
return err
}
msgLen := uint32(len(msg))
if int(msgLen) != len(msg) {
return fmt.Errorf("packet too big: %d", len(msg))
}
if err := putUint32(c.bw, msgLen); err != nil {
return err
}
if _, err := c.bw.Write(msg); err != nil {
if _, err := c.bw.Write(pkt); err != nil {
return err
}
return c.bw.Flush()
@ -160,34 +161,21 @@ func (c *Client) Recv(b []byte) (n int, err error) {
}
}()
loop:
for {
c.nc.SetReadDeadline(time.Now().Add(120 * time.Second))
typ, err := c.br.ReadByte()
t, n, err := readFrame(c.br, 1<<20, b)
if err != nil {
return 0, err
}
switch frameType(typ) {
case typeKeepAlive:
continue
case typeRecvPacket:
break loop
switch t {
default:
return 0, fmt.Errorf("derp.Recv: unknown packet type 0x%X", typ)
continue
case frameKeepAlive:
// TODO: eventually we'll have server->client pings that
// require ack pongs.
continue
case frameRecvPacket:
return int(n), nil
}
}
packetLen, err := readUint32(c.br, oneMB)
if err != nil {
return 0, err
}
if int(packetLen) > len(b) {
// TODO(crawshaw): discard the packet
return 0, io.ErrShortBuffer
}
b = b[:packetLen]
if _, err := io.ReadFull(c.br, b); err != nil {
return 0, err
}
return int(packetLen), nil
}