derp: change the protocol framing to always include a length
Addresses one of crawshaw's TODOs. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:

committed by
Brad Fitzpatrick

parent
c47f907a27
commit
f029c4c82d
@ -8,6 +8,7 @@ import (
|
||||
"bufio"
|
||||
crand "crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -54,34 +55,40 @@ func NewClient(privateKey key.Private, nc net.Conn, brw *bufio.ReadWriter, logf
|
||||
}
|
||||
|
||||
func (c *Client) recvServerKey() error {
|
||||
gotMagic, err := readUint32(c.br, 0xffffffff)
|
||||
var buf [40]byte
|
||||
t, flen, err := readFrame(c.br, 1<<10, buf[:])
|
||||
if err == io.ErrShortBuffer {
|
||||
// For future-proofing, allow server to send more in its greeting.
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gotMagic != magic {
|
||||
return fmt.Errorf("bad magic %x, want %x", gotMagic, magic)
|
||||
}
|
||||
if err := readType(c.br, typeServerKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(c.br, c.serverKey[:]); err != nil {
|
||||
return err
|
||||
if flen < uint32(len(buf)) || t != frameServerKey || string(buf[:len(magic)]) != magic {
|
||||
return errors.New("invalid server greeting")
|
||||
}
|
||||
copy(c.serverKey[:], buf[len(magic):])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) recvServerInfo() (*serverInfo, error) {
|
||||
if err := readType(c.br, typeServerInfo); err != nil {
|
||||
fl, err := readFrameTypeHeader(c.br, frameServerInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var nonce [24]byte
|
||||
const maxLength = nonceLen + maxInfoLen
|
||||
if fl < nonceLen {
|
||||
return nil, fmt.Errorf("short serverInfo frame")
|
||||
}
|
||||
if fl > maxLength {
|
||||
return nil, fmt.Errorf("long serverInfo frame")
|
||||
}
|
||||
// TODO: add a read-nonce-and-box helper
|
||||
var nonce [nonceLen]byte
|
||||
if _, err := io.ReadFull(c.br, nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("nonce: %v", err)
|
||||
}
|
||||
msgLen, err := readUint32(c.br, oneMB)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("msglen: %v", err)
|
||||
}
|
||||
msgLen := fl - nonceLen
|
||||
msgbox := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(c.br, msgbox); err != nil {
|
||||
return nil, fmt.Errorf("msgbox: %v", err)
|
||||
@ -98,49 +105,43 @@ func (c *Client) recvServerInfo() (*serverInfo, error) {
|
||||
}
|
||||
|
||||
func (c *Client) sendClientKey() error {
|
||||
var nonce [24]byte
|
||||
var nonce [nonceLen]byte
|
||||
if _, err := crand.Read(nonce[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
msg := []byte("{}") // no clientInfo for now
|
||||
msgbox := box.Seal(nil, msg, &nonce, c.serverKey.B32(), c.privateKey.B32())
|
||||
|
||||
if _, err := c.bw.Write(c.publicKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(nonce[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := putUint32(c.bw, uint32(len(msgbox))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(msgbox); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bw.Flush()
|
||||
buf := make([]byte, 0, nonceLen+keyLen+len(msgbox))
|
||||
buf = append(buf, c.publicKey[:]...)
|
||||
buf = append(buf, nonce[:]...)
|
||||
buf = append(buf, msgbox...)
|
||||
return writeFrame(c.bw, frameClientInfo, buf)
|
||||
}
|
||||
|
||||
func (c *Client) Send(dstKey key.Public, msg []byte) (err error) {
|
||||
// Send sends a packet to the Tailscale node identified by dstKey.
|
||||
//
|
||||
// It is an error if the packet is larger than 64KB.
|
||||
func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKey, pkt) }
|
||||
|
||||
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("derp.Send: %v", err)
|
||||
if ret != nil {
|
||||
ret = fmt.Errorf("derp.Send: %v", ret)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := typeSendPacket.Write(c.bw); err != nil {
|
||||
if len(pkt) > 64<<10 {
|
||||
return fmt.Errorf("packet too big: %d", len(pkt))
|
||||
}
|
||||
|
||||
if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(dstKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
msgLen := uint32(len(msg))
|
||||
if int(msgLen) != len(msg) {
|
||||
return fmt.Errorf("packet too big: %d", len(msg))
|
||||
}
|
||||
if err := putUint32(c.bw, msgLen); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(msg); err != nil {
|
||||
if _, err := c.bw.Write(pkt); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bw.Flush()
|
||||
@ -160,34 +161,21 @@ func (c *Client) Recv(b []byte) (n int, err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
loop:
|
||||
for {
|
||||
c.nc.SetReadDeadline(time.Now().Add(120 * time.Second))
|
||||
typ, err := c.br.ReadByte()
|
||||
t, n, err := readFrame(c.br, 1<<20, b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch frameType(typ) {
|
||||
case typeKeepAlive:
|
||||
continue
|
||||
case typeRecvPacket:
|
||||
break loop
|
||||
switch t {
|
||||
default:
|
||||
return 0, fmt.Errorf("derp.Recv: unknown packet type 0x%X", typ)
|
||||
continue
|
||||
case frameKeepAlive:
|
||||
// TODO: eventually we'll have server->client pings that
|
||||
// require ack pongs.
|
||||
continue
|
||||
case frameRecvPacket:
|
||||
return int(n), nil
|
||||
}
|
||||
}
|
||||
|
||||
packetLen, err := readUint32(c.br, oneMB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if int(packetLen) > len(b) {
|
||||
// TODO(crawshaw): discard the packet
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
b = b[:packetLen]
|
||||
if _, err := io.ReadFull(c.br, b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(packetLen), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user