tailcfg: add DiscoKey, unify some code, add some tests

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2020-06-18 19:32:55 -07:00
committed by Brad Fitzpatrick
parent d9054da86a
commit 88c305c8af
2 changed files with 83 additions and 41 deletions

View File

@ -13,7 +13,9 @@ import (
"time"
"github.com/tailscale/wireguard-go/wgcfg"
"go4.org/mem"
"golang.org/x/oauth2"
"tailscale.com/types/key"
"tailscale.com/types/opt"
"tailscale.com/types/structs"
)
@ -38,6 +40,10 @@ type MachineKey [32]byte
// NodeKey is the curve25519 public key for a node.
type NodeKey [32]byte
// DiscoKey is the curve25519 public key for path discovery key.
// It's never written to disk or reused between network start-ups.
type DiscoKey [32]byte
type Group struct {
ID GroupID
Name string
@ -127,6 +133,7 @@ type Node struct {
Key NodeKey
KeyExpiry time.Time
Machine MachineKey
DiscoKey DiscoKey
Addresses []wgcfg.CIDR // IP addresses of this Node directly
AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node
Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs)
@ -519,59 +526,43 @@ type Debug struct {
LogHeapURL string `json:",omitempty"`
}
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil }
func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) }
func (k MachineKey) MarshalText() ([]byte, error) {
buf := new(bytes.Buffer)
fmt.Fprintf(buf, "mkey:%x", k[:])
return buf.Bytes(), nil
func keyMarshalText(prefix string, k [32]byte) []byte {
buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+64))
fmt.Fprintf(buf, "%s%x", prefix, k[:])
return buf.Bytes()
}
func (k *MachineKey) UnmarshalText(text []byte) error {
s := string(text)
if !strings.HasPrefix(s, "mkey:") {
return errors.New(`MachineKey.UnmarshalText: missing prefix`)
func keyUnmarshalText(dst []byte, prefix string, text []byte) error {
if len(text) < len(prefix) || string(text[:len(prefix)]) != prefix {
return fmt.Errorf("UnmarshalText: missing %q prefix", prefix)
}
s = strings.TrimPrefix(s, `mkey:`)
key, err := wgcfg.ParseHexKey(s)
pub, err := key.NewPublicFromHexMem(mem.B(text[len(prefix):]))
if err != nil {
return fmt.Errorf("MachineKey.UnmarhsalText: %v", err)
return fmt.Errorf("UnmarshalText: after %q: %v", prefix, err)
}
copy(k[:], key[:])
copy(dst[:], pub[:])
return nil
}
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
func (k NodeKey) ShortString() string { return (key.Public(k)).ShortString() }
func (k NodeKey) ShortString() string {
pk := wgcfg.Key(k)
return pk.ShortString()
}
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
func (k NodeKey) MarshalText() ([]byte, error) { return keyMarshalText("nodekey:", k), nil }
func (k *NodeKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "nodekey:", text) }
func (k NodeKey) MarshalText() ([]byte, error) {
buf := new(bytes.Buffer)
fmt.Fprintf(buf, "nodekey:%x", k[:])
return buf.Bytes(), nil
}
// IsZero reports whether k is the zero value.
func (k NodeKey) IsZero() bool { return k == NodeKey{} }
func (k *NodeKey) UnmarshalText(text []byte) error {
s := string(text)
if !strings.HasPrefix(s, "nodekey:") {
return errors.New(`Nodekey.UnmarshalText: missing prefix`)
}
s = strings.TrimPrefix(s, "nodekey:")
key, err := wgcfg.ParseHexKey(s)
if err != nil {
return fmt.Errorf("tailcfg.Ukey.UnmarhsalText: %v", err)
}
copy(k[:], key[:])
return nil
}
func (k DiscoKey) String() string { return fmt.Sprintf("discokey:%x", k[:]) }
func (k DiscoKey) MarshalText() ([]byte, error) { return keyMarshalText("discokey:", k), nil }
func (k *DiscoKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "discokey:", text) }
// IsZero reports whether k is the NodeKey zero value.
func (k NodeKey) IsZero() bool {
return k == NodeKey{}
}
// IsZero reports whether k is the zero value.
func (k DiscoKey) IsZero() bool { return k == DiscoKey{} }
func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) }
func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) }
@ -593,6 +584,7 @@ func (n *Node) Equal(n2 *Node) bool {
n.Key == n2.Key &&
n.KeyExpiry.Equal(n2.KeyExpiry) &&
n.Machine == n2.Machine &&
n.DiscoKey == n2.DiscoKey &&
reflect.DeepEqual(n.Addresses, n2.Addresses) &&
reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) &&
reflect.DeepEqual(n.Endpoints, n2.Endpoints) &&