diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 207136659..d22dfe443 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -41,6 +41,7 @@ "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" + sshtest "tailscale.com/tempfork/sshtest/ssh" "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/key" @@ -55,6 +56,8 @@ "tailscale.com/wgengine" ) +type _ = sshtest.Client // TODO(bradfitz,percy): sshtest; delete this line + func TestMatchRule(t *testing.T) { someAction := new(tailcfg.SSHAction) tests := []struct { diff --git a/tempfork/sshtest/README.md b/tempfork/sshtest/README.md new file mode 100644 index 000000000..30c74f525 --- /dev/null +++ b/tempfork/sshtest/README.md @@ -0,0 +1,9 @@ +# sshtest + +This contains packages that are forked & locally hacked up for use +in tests. + +Notably, `golang.org/x/crypto/ssh` was copied to +`tailscale.com/tempfork/sshtest/ssh` to permit adding behaviors specific +to testing (for testing Tailscale SSH) that aren't necessarily desirable +to have upstream. diff --git a/tempfork/sshtest/ssh/benchmark_test.go b/tempfork/sshtest/ssh/benchmark_test.go new file mode 100644 index 000000000..b356330b4 --- /dev/null +++ b/tempfork/sshtest/ssh/benchmark_test.go @@ -0,0 +1,127 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + panic(fmt.Sprintf("Client: %v", err)) + } + ch, incoming, err := newCh.Accept() + if err != nil { + panic(fmt.Sprintf("Accept: %v", err)) + } + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + panic(fmt.Sprintf("ReadFull: %v", err)) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/tempfork/sshtest/ssh/buffer.go b/tempfork/sshtest/ssh/buffer.go new file mode 100644 index 000000000..1ab07d078 --- /dev/null +++ b/tempfork/sshtest/ssh/buffer.go @@ -0,0 +1,97 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive io.EOF. +func (b *buffer) eof() { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/tempfork/sshtest/ssh/buffer_test.go b/tempfork/sshtest/ssh/buffer_test.go new file mode 100644 index 000000000..d5781cb3d --- /dev/null +++ b/tempfork/sshtest/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/tempfork/sshtest/ssh/certs.go b/tempfork/sshtest/ssh/certs.go new file mode 100644 index 000000000..27d0e14aa --- /dev/null +++ b/tempfork/sshtest/ssh/certs.go @@ -0,0 +1,611 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear +// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms. +// Unlike key algorithm names, these are not passed to AlgorithmSigner nor +// returned by MultiAlgorithmSigner and don't appear in the Signature.Format +// field. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" + CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com" + + // CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a + // Certificate.Type (or PublicKey.Type), but only in + // ClientConfig.HostKeyAlgorithms. + CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com" + CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com" +) + +const ( + // Deprecated: use CertAlgoRSAv01. + CertSigAlgoRSAv01 = CertAlgoRSAv01 + // Deprecated: use CertAlgoRSASHA256v01. + CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01 + // Deprecated: use CertAlgoRSASHA512v01. + CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01 +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte + Rest []byte `ssh:"rest"` +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the +// PublicKey interface, so it can be unmarshaled using +// ParsePublicKey. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +type algorithmOpenSSHCertSigner struct { + *openSSHCertSigner + algorithmSigner AlgorithmSigner +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) { + return nil, errors.New("ssh: signer and cert have different public key") + } + + switch s := signer.(type) { + case MultiAlgorithmSigner: + return &multiAlgorithmSigner{ + AlgorithmSigner: &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, + supportedAlgorithms: s.Algorithms(), + }, nil + case AlgorithmSigner: + return &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, nil + default: + return &openSSHCertSigner{cert, signer}, nil + } +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsUserAuthority should return true if the key is recognized as an + // authority for the given user certificate. This allows for + // certificates to be signed by other certificates. This must be set + // if this CertChecker will be checking user certificates. + IsUserAuthority func(auth PublicKey) bool + + // IsHostAuthority should report whether the key is recognized as + // an authority for this host. This allows for certificates to be + // signed by other keys, and for those other keys to only be valid + // signers for particular hostnames. This must be set if this + // CertChecker will be checking host certificates. + IsHostAuthority func(auth PublicKey, address string) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback HostKeyCallback + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + if !c.IsHostAuthority(cert.SignatureKey, addr) { + return fmt.Errorf("ssh: no authorities for hostname: %v", addr) + } + + hostname, _, err := net.SplitHostPort(addr) + if err != nil { + return err + } + + // Pass hostname only as principal for host certificates (consistent with OpenSSH) + return c.CheckCert(hostname, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + if !c.IsUserAuthority(cert.SignatureKey) { + return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial) + } + + for opt := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert signs the certificate with an authority, setting the Nonce, +// SignatureKey, and Signature fields. If the authority implements the +// MultiAlgorithmSigner interface the first algorithm in the list is used. This +// is useful if you want to sign with a specific algorithm. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + if v, ok := authority.(MultiAlgorithmSigner); ok { + if len(v.Algorithms()) == 0 { + return errors.New("the provided authority has no signature algorithm") + } + // Use the first algorithm in the list. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0]) + if err != nil { + return err + } + c.Signature = sig + return nil + } else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA { + // Default to KeyAlgoRSASHA512 for ssh-rsa signers. + // TODO: consider using KeyAlgoRSASHA256 as default. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512) + if err != nil { + return err + } + c.Signature = sig + return nil + } + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +// certKeyAlgoNames is a mapping from known certificate algorithm names to the +// corresponding public key signature algorithm. +// +// This map must be kept in sync with the one in agent/client.go. +var certKeyAlgoNames = map[string]string{ + CertAlgoRSAv01: KeyAlgoRSA, + CertAlgoRSASHA256v01: KeyAlgoRSASHA256, + CertAlgoRSASHA512v01: KeyAlgoRSASHA512, + CertAlgoDSAv01: KeyAlgoDSA, + CertAlgoECDSA256v01: KeyAlgoECDSA256, + CertAlgoECDSA384v01: KeyAlgoECDSA384, + CertAlgoECDSA521v01: KeyAlgoECDSA521, + CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256, + CertAlgoED25519v01: KeyAlgoED25519, + CertAlgoSKED25519v01: KeyAlgoSKED25519, +} + +// underlyingAlgo returns the signature algorithm associated with algo (which is +// an advertised or negotiated public key or host key algorithm). These are +// usually the same, except for certificate algorithms. +func underlyingAlgo(algo string) string { + if a, ok := certKeyAlgoNames[algo]; ok { + return a + } + return algo +} + +// certificateAlgo returns the certificate algorithms that uses the provided +// underlying signature algorithm. +func certificateAlgo(algo string) (certAlgo string, ok bool) { + for certName, algoName := range certKeyAlgoNames { + if algoName == algo { + return certName, true + } + } + return "", false +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the certificate algorithm name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + certName, ok := certificateAlgo(c.Key.Type()) + if !ok { + panic("unknown certificate type for key type " + c.Key.Type()) + } + return certName +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + switch out.Format { + case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01: + out.Rest = in + return out, nil, ok + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/tempfork/sshtest/ssh/certs_test.go b/tempfork/sshtest/ssh/certs_test.go new file mode 100644 index 000000000..6208bb37a --- /dev/null +++ b/tempfork/sshtest/ssh/certs_test.go @@ -0,0 +1,406 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "io" + "net" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh/testdata" +) + +func TestParseCert(t *testing.T) { + authKeyBytes := bytes.TrimSuffix(testdata.SSHCertificates["rsa"], []byte(" host.example.com\n")) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// +// Extensions: +// +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey(testdata.SSHCertificates["rsa-user-testcertificate"]) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("testcertificate", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("testcertificate", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsHostAuthority: func(p PublicKey, addr string) bool { + return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, test := range []struct { + addr string + succeed bool + certSignerAlgorithms []string // Empty means no algorithm restrictions. + clientHostKeyAlgorithms []string + }{ + {addr: "hostname:22", succeed: true}, + { + addr: "hostname:22", + succeed: true, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSASHA512v01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSAv01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{KeyAlgoRSASHA512}, // Not a certificate algorithm. + }, + {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22' + {addr: "lasthost:22", succeed: false}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errc := make(chan error) + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + if len(test.certSignerAlgorithms) > 0 { + mas, err := NewSignerWithAlgorithms(certSigner.(AlgorithmSigner), test.certSignerAlgorithms) + if err != nil { + errc <- err + return + } + conf.AddHostKey(mas) + } else { + conf.AddHostKey(certSigner) + } + _, _, _, err := NewServerConn(c1, &conf) + errc <- err + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + HostKeyAlgorithms: test.clientHostKeyAlgorithms, + } + _, _, _, err = NewClientConn(c2, test.addr, config) + + if (err == nil) != test.succeed { + t.Errorf("NewClientConn(%q): %v", test.addr, err) + } + + err = <-errc + if (err == nil) != test.succeed { + t.Errorf("NewServerConn(%q): %v", test.addr, err) + } + } +} + +type legacyRSASigner struct { + Signer +} + +func (s *legacyRSASigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + v, ok := s.Signer.(AlgorithmSigner) + if !ok { + return nil, fmt.Errorf("invalid signer") + } + return v.SignWithAlgorithm(rand, data, KeyAlgoRSA) +} + +func TestCertTypes(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSignerSHA256, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA256: %v", err) + } + // Algorithms are in order of preference, we expect rsa-sha2-512 to be used. + multiAlgoSignerSHA512, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA512: %v", err) + } + + var testVars = []struct { + name string + signer Signer + algo string + }{ + {CertAlgoECDSA256v01, testSigners["ecdsap256"], ""}, + {CertAlgoECDSA384v01, testSigners["ecdsap384"], ""}, + {CertAlgoECDSA521v01, testSigners["ecdsap521"], ""}, + {CertAlgoED25519v01, testSigners["ed25519"], ""}, + {CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA256}, + {"legacyRSASigner", &legacyRSASigner{testSigners["rsa"]}, KeyAlgoRSA}, + {"multiAlgoRSASignerSHA256", multiAlgoSignerSHA256, KeyAlgoRSASHA256}, + {"multiAlgoRSASignerSHA512", multiAlgoSignerSHA512, KeyAlgoRSASHA512}, + {CertAlgoDSAv01, testSigners["dsa"], ""}, + } + + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating host key: %v", err) + } + + signer, err := NewSignerFromKey(k) + if err != nil { + t.Fatalf("error generating signer for ssh listener: %v", err) + } + + conf := &ServerConfig{ + PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) { + return new(Permissions), nil + }, + } + conf.AddHostKey(signer) + + for _, m := range testVars { + t.Run(m.name, func(t *testing.T) { + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, conf) + + priv := m.signer + if err != nil { + t.Fatalf("error generating ssh pubkey: %v", err) + } + + cert := &Certificate{ + CertType: UserCert, + Key: priv.PublicKey(), + } + cert.SignCert(rand.Reader, priv) + + certSigner, err := NewCertSigner(cert, priv) + if err != nil { + t.Fatalf("error generating cert signer: %v", err) + } + + if m.algo != "" && cert.Signature.Format != m.algo { + t.Errorf("expected %q signature format, got %q", m.algo, cert.Signature.Format) + } + + config := &ClientConfig{ + User: "user", + HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil }, + Auth: []AuthMethod{PublicKeys(certSigner)}, + } + + _, _, _, err = NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("error connecting: %v", err) + } + }) + } +} + +func TestCertSignWithMultiAlgorithmSigner(t *testing.T) { + type testcase struct { + sigAlgo string + algorithms []string + } + cases := []testcase{ + { + sigAlgo: KeyAlgoRSA, + algorithms: []string{KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA256, + algorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA512, + algorithms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}, + }, + } + + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + + for _, c := range cases { + t.Run(c.sigAlgo, func(t *testing.T) { + signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algorithms) + if err != nil { + t.Fatalf("NewSignerWithAlgorithms error: %v", err) + } + if err := cert.SignCert(rand.Reader, signer); err != nil { + t.Fatalf("SignCert error: %v", err) + } + if cert.Signature.Format != c.sigAlgo { + t.Fatalf("got signature format %q, want %q", cert.Signature.Format, c.sigAlgo) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/channel.go b/tempfork/sshtest/ssh/channel.go new file mode 100644 index 000000000..cc0bb7ab6 --- /dev/null +++ b/tempfork/sshtest/ssh/channel.go @@ -0,0 +1,645 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window, and myConsumed, + // the number of bytes consumed since we last increased myWindow + windowMu sync.Mutex + myWindow uint32 + myConsumed uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (ch *channel) writePacket(packet []byte) error { + ch.writeMu.Lock() + if ch.sentClose { + ch.writeMu.Unlock() + return io.EOF + } + ch.sentClose = (packet[0] == msgChannelClose) + err := ch.mux.conn.writePacket(packet) + ch.writeMu.Unlock() + return err +} + +func (ch *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], ch.remoteId) + return ch.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if ch.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + ch.writeMu.Lock() + packet := ch.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + ch.writeMu.Unlock() + + for len(data) > 0 { + space := min(ch.maxRemotePayload, len(data)) + if space, err = ch.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], ch.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = ch.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + ch.writeMu.Lock() + ch.packetPool[extendedCode] = packet + ch.writeMu.Unlock() + + return n, err +} + +func (ch *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > ch.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + ch.windowMu.Lock() + if ch.myWindow < length { + ch.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + ch.myWindow -= length + ch.windowMu.Unlock() + + if extended == 1 { + ch.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + ch.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(adj uint32) error { + c.windowMu.Lock() + // Since myConsumed and myWindow are managed on our side, and can never + // exceed the initial window setting, we don't worry about overflow. + c.myConsumed += adj + var sendAdj uint32 + if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) || + (c.myWindow < channelWindowSize/2) { + sendAdj = c.myConsumed + c.myConsumed = 0 + c.myWindow += sendAdj + } + c.windowMu.Unlock() + if sendAdj == 0 { + return nil + } + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: sendAdj, + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (ch *channel) responseMessageReceived() error { + if ch.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if ch.decided { + return errors.New("ssh: duplicate response received for channel") + } + ch.decided = true + return nil +} + +func (ch *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return ch.handleData(packet) + case msgChannelClose: + ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId}) + ch.mux.chanList.remove(ch.localId) + ch.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + ch.extPending.eof() + ch.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + ch.mux.chanList.remove(msg.PeersID) + ch.msg <- msg + case *channelOpenConfirmMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + ch.remoteId = msg.MyID + ch.maxRemotePayload = msg.MaxPacketSize + ch.remoteWin.add(msg.MyWindow) + ch.msg <- msg + case *windowAdjustMsg: + if !ch.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: ch, + } + + ch.incomingRequests <- &req + default: + ch.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, chanSize), + msg: make(chan interface{}, chanSize), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (ch *channel) Accept() (Channel, <-chan *Request, error) { + if ch.decided { + return nil, nil, errDecidedAlready + } + ch.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersID: ch.remoteId, + MyID: ch.localId, + MyWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + } + ch.decided = true + if err := ch.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersID: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersID: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersID: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersID: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersID: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersID: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/tempfork/sshtest/ssh/cipher.go b/tempfork/sshtest/ssh/cipher.go new file mode 100644 index 000000000..0533786f4 --- /dev/null +++ b/tempfork/sshtest/ssh/cipher.go @@ -0,0 +1,789 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type cipherMode struct { + keySize int + ivSize int + create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) +} + +func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + stream, err := createFunc(key, iv) + if err != nil { + return nil, err + } + + var streamDump []byte + if skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + mac := macModes[algs.MAC].new(macKey) + return &streamPacketCipher{ + mac: mac, + etm: macModes[algs.MAC].etm, + macResult: make([]byte, mac.Size()), + cipher: stream, + }, nil + } +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*cipherMode{ + // Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + + // Ciphers from RFC 4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, streamCipherMode(1536, newRC4)}, + "arcfour256": {32, 0, streamCipherMode(1536, newRC4)}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC 4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, streamCipherMode(0, newRC4)}, + + // AEAD ciphers + gcm128CipherID: {16, 12, newGCMCipher}, + gcm256CipherID: {32, 12, newGCMCipher}, + chacha20Poly1305ID: {64, 0, newChaCha20Cipher}, + + // CBC mode is insecure and so is not included in the default config. + // (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, newAESCBCCipher}, + + // 3des-cbc is insecure and is not included in the default + // config. + tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + etm bool + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readCipherPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + var encryptedPaddingLength [1]byte + if s.mac != nil && s.etm { + copy(encryptedPaddingLength[:], s.prefix[4:5]) + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } else { + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + if s.etm { + s.mac.Write(s.prefix[:4]) + s.mac.Write(encryptedPaddingLength[:]) + } else { + s.mac.Write(s.prefix[:]) + } + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + + if s.mac != nil && s.etm { + s.mac.Write(data) + } + + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + if !s.etm { + s.mac.Write(data) + } + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writeCipherPacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + aadlen := 0 + if s.mac != nil && s.etm { + // packet length is not encrypted for EtM modes + aadlen = 4 + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + + if s.etm { + // For EtM algorithms, the packet length must stay unencrypted, + // but the following data (padding length) must be encrypted + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } + + s.mac.Write(s.prefix[:]) + + if !s.etm { + // For non-EtM algorithms, the algorithm is applied on unencrypted data + s.mac.Write(packet) + s.mac.Write(padding) + } + } + + if !(s.mac != nil && s.etm) { + // For EtM algorithms, the padding length has already been encrypted + // and the packet length must remain unencrypted + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if s.mac != nil && s.etm { + // For EtM algorithms, packet and padding must be encrypted + s.mac.Write(packet) + s.mac.Write(padding) + } + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. +type cbcError string + +func (e cbcError) Error() string { return string(e) } + +func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + p, err := c.readCipherPacketLeaky(seqNum, r) + if err != nil { + if _, ok := err.(cbcError); ok { + // Verification error: read a fixed amount of + // data, to make distinguishing between + // failing MAC and failing length check more + // difficult. + io.CopyN(io.Discard, r, int64(c.oracleCamouflage)) + } + } + return p, err +} + +func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, cbcError("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, cbcError("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, cbcError("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, cbcError("ssh: invalid packet length") + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + c.macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + n, err := io.ReadFull(r, c.packetData[firstBlockLength:]) + if err != nil { + return nil, err + } + c.oracleCamouflage -= uint32(n) + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, cbcError("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + c.macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} + +const chacha20Poly1305ID = "chacha20-poly1305@openssh.com" + +// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com +// AEAD, which is described here: +// +// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00 +// +// the methods here also implement padding, which RFC 4253 Section 6 +// also requires of stream ciphers. +type chacha20Poly1305Cipher struct { + lengthKey [32]byte + contentKey [32]byte + buf []byte +} + +func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + if len(key) != 64 { + panic(len(key)) + } + + c := &chacha20Poly1305Cipher{ + buf: make([]byte, 256), + } + + copy(c.contentKey[:], key[:32]) + copy(c.lengthKey[:], key[32:]) + return c, nil +} + +func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return nil, err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + encryptedLength := c.buf[:4] + if _, err := io.ReadFull(r, encryptedLength); err != nil { + return nil, err + } + + var lenBytes [4]byte + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return nil, err + } + ls.XORKeyStream(lenBytes[:], encryptedLength) + + length := binary.BigEndian.Uint32(lenBytes[:]) + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + contentEnd := 4 + length + packetEnd := contentEnd + poly1305.TagSize + if uint32(cap(c.buf)) < packetEnd { + c.buf = make([]byte, packetEnd) + copy(c.buf[:], encryptedLength) + } else { + c.buf = c.buf[:packetEnd] + } + + if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil { + return nil, err + } + + var mac [poly1305.TagSize]byte + copy(mac[:], c.buf[contentEnd:packetEnd]) + if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) { + return nil, errors.New("ssh: MAC failure") + } + + plain := c.buf[4:contentEnd] + s.XORKeyStream(plain, plain) + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding)+1 >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + + plain = plain[1 : len(plain)-int(padding)] + + return plain, nil +} + +func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + // There is no blocksize, so fall back to multiple of 8 byte + // padding, as described in RFC 4253, Sec 6. + const packetSizeMultiple = 8 + + padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple + if padding < 4 { + padding += packetSizeMultiple + } + + // size (4 bytes), padding (1), payload, padding, tag. + totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize + if cap(c.buf) < totalLength { + c.buf = make([]byte, totalLength) + } else { + c.buf = c.buf[:totalLength] + } + + binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding)) + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return err + } + ls.XORKeyStream(c.buf, c.buf[:4]) + c.buf[4] = byte(padding) + copy(c.buf[5:], payload) + packetEnd := 5 + len(payload) + padding + if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil { + return err + } + + s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd]) + + var mac [poly1305.TagSize]byte + poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey) + + copy(c.buf[packetEnd:], mac[:]) + + if _, err := w.Write(c.buf); err != nil { + return err + } + return nil +} diff --git a/tempfork/sshtest/ssh/cipher_test.go b/tempfork/sshtest/ssh/cipher_test.go new file mode 100644 index 000000000..fe339862c --- /dev/null +++ b/tempfork/sshtest/ssh/cipher_test.go @@ -0,0 +1,231 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/rand" + "encoding/binary" + "io" + "testing" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("supported cipher %q is unknown", cipherAlgo) + } + } + for _, cipherAlgo := range preferredCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("preferred cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + defaultMac := "hmac-sha2-256" + defaultCipher := "aes128-ctr" + for cipher := range cipherModes { + t.Run("cipher="+cipher, + func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) }) + } + for mac := range macModes { + t.Run("mac="+mac, + func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) }) + } +} + +func testPacketCipher(t *testing.T, cipher, mac string) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err) + } + + packet, err := server.readCipherPacket(0, buf) + if err != nil { + t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err) + } + + if string(packet) != want { + t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) + } +} + +func TestCBCOracleCounterMeasure(t *testing.T) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: aes128cbcID, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writeCipherPacket: %v", err) + } + + packetSize := buf.Len() + buf.Write(make([]byte, 2*maxPacket)) + + // We corrupt each byte, but this usually will only test the + // 'packet too large' or 'MAC failure' cases. + lastRead := -1 + for i := 0; i < packetSize; i++ { + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + fresh := &bytes.Buffer{} + fresh.Write(buf.Bytes()) + fresh.Bytes()[i] ^= 0x01 + + before := fresh.Len() + _, err = server.readCipherPacket(0, fresh) + if err == nil { + t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i) + continue + } + if _, ok := err.(cbcError); !ok { + t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) + continue + } + + after := fresh.Len() + bytesRead := before - after + if bytesRead < maxPacket { + t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) + continue + } + + if i > 0 && bytesRead != lastRead { + t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) + } + lastRead = bytesRead + } +} + +func TestCVE202143565(t *testing.T) { + tests := []struct { + cipher string + constructPacket func(packetCipher) io.Reader + }{ + { + cipher: gcm128CipherID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*gcmCipher) + b := &bytes.Buffer{} + prefix := [4]byte{} + if _, err := b.Write(prefix[:]); err != nil { + t.Fatal(err) + } + internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:]) + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + internalCipher.incIV() + + return b + }, + }, + { + cipher: chacha20Poly1305ID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*chacha20Poly1305Cipher) + b := &bytes.Buffer{} + + nonce := make([]byte, 12) + s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce) + if err != nil { + t.Fatal(err) + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + internalCipher.buf = make([]byte, 4+poly1305.TagSize) + binary.BigEndian.PutUint32(internalCipher.buf, 0) + ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce) + if err != nil { + t.Fatal(err) + } + ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4]) + if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil { + t.Fatal(err) + } + + s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4]) + + var tag [poly1305.TagSize]byte + poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey) + + copy(internalCipher.buf[4:], tag[:]) + + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + + return b + }, + }, + } + + for _, tc := range tests { + mac := "hmac-sha2-256" + + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: tc.cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + + b := tc.constructPacket(client) + + wantErr := "ssh: empty packet" + _, err = server.readCipherPacket(0, b) + if err == nil { + t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac) + } else if err.Error() != wantErr { + t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr) + } + } +} diff --git a/tempfork/sshtest/ssh/client.go b/tempfork/sshtest/ssh/client.go new file mode 100644 index 000000000..fd8c49749 --- /dev/null +++ b/tempfork/sshtest/ssh/client.go @@ -0,0 +1,282 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "net" + "os" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. +type Client struct { + Conn + + handleForwardsOnce sync.Once // guards calling (*Client).handleForwards + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, chanSize) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.HostKeyCallback == nil { + c.Close() + return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") + } + + conn := &connection{ + sshConn: sshConn{conn: c, user: fullConf.User}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + return err + } + + c.sessionID = c.transport.getSessionID() + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key exchange. +// algo is the negotiated algorithm, and may be a certificate type. +func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + if a := underlyingAlgo(algo); sig.Format != a { + return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a) + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// HostKeyCallback is the function type used for verifying server +// keys. A HostKeyCallback must return nil if the host key is OK, or +// an error to reject it. It receives the hostname as passed to Dial +// or NewClientConn. The remote address is the RemoteAddr of the +// net.Conn underlying the SSH connection. +type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + +// BannerCallback is the function type used for treat the banner sent by +// the server. A BannerCallback receives the message sent by the remote server. +type BannerCallback func(message string) error + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback is called during the cryptographic + // handshake to validate the server's host key. The client + // configuration must supply this callback for the connection + // to succeed. The functions InsecureIgnoreHostKey or + // FixedHostKey can be used for simplistic host key checks. + HostKeyCallback HostKeyCallback + + // BannerCallback is called during the SSH dance to display a custom + // server's message. The client configuration can supply this callback to + // handle it as wished. The function BannerDisplayStderr can be used for + // simplistic display on Stderr. + BannerCallback BannerCallback + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the public key algorithms that the client will + // accept from the server for host key authentication, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from a PublicKey.Type method may be used, or + // any of the CertAlgo and KeyAlgo constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration +} + +// InsecureIgnoreHostKey returns a function that can be used for +// ClientConfig.HostKeyCallback to accept any host key. It should +// not be used for production code. +func InsecureIgnoreHostKey() HostKeyCallback { + return func(hostname string, remote net.Addr, key PublicKey) error { + return nil + } +} + +type fixedHostKey struct { + key PublicKey +} + +func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { + if f.key == nil { + return fmt.Errorf("ssh: required host key was nil") + } + if !bytes.Equal(key.Marshal(), f.key.Marshal()) { + return fmt.Errorf("ssh: host key mismatch") + } + return nil +} + +// FixedHostKey returns a function for use in +// ClientConfig.HostKeyCallback to accept only a specific host key. +func FixedHostKey(key PublicKey) HostKeyCallback { + hk := &fixedHostKey{key} + return hk.check +} + +// BannerDisplayStderr returns a function that can be used for +// ClientConfig.BannerCallback to display banners on os.Stderr. +func BannerDisplayStderr() BannerCallback { + return func(banner string) error { + _, err := os.Stderr.WriteString(banner) + + return err + } +} diff --git a/tempfork/sshtest/ssh/client_auth.go b/tempfork/sshtest/ssh/client_auth.go new file mode 100644 index 000000000..b86dde151 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth.go @@ -0,0 +1,796 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" +) + +type authResult int + +const ( + authFailure authResult = iota + authPartialSuccess + authSuccess +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + // The server may choose to send a SSH_MSG_EXT_INFO at this point (if we + // advertised willingness to receive one, which we always do) or not. See + // RFC 8308, Section 2.4. + extensions := make(map[string][]byte) + if len(packet) > 0 && packet[0] == msgExtInfo { + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return err + } + payload := extInfo.Payload + for i := uint32(0); i < extInfo.NumExtensions; i++ { + name, rest, ok := parseString(payload) + if !ok { + return parseError(msgExtInfo) + } + value, rest, ok := parseString(rest) + if !ok { + return parseError(msgExtInfo) + } + extensions[string(name)] = value + payload = rest + } + packet, err = c.transport.readPacket() + if err != nil { + return err + } + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + var tried []string + var lastMethods []string + + sessionID := c.transport.getSessionID() + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions) + if err != nil { + // On disconnect, return error immediately + if _, ok := err.(*disconnectMsg); ok { + return err + } + // We return the error later if there is no other method left to + // try. + ok = authFailure + } + if ok == authSuccess { + // success + return nil + } else if ok == authFailure { + if m := auth.method(); !contains(tried, m) { + tried = append(tried, m) + } + } + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if contains(tried, candidateMethod) { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + + if auth == nil && err != nil { + // We have an error and there are no other authentication methods to + // try, so we return it. + return err + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried) +} + +func contains(list []string, e string) bool { + for _, s := range list { + if s == e { + return true + } + } + return false +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return authFailure, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) { + var as MultiAlgorithmSigner + keyFormat := signer.PublicKey().Type() + + // If the signer implements MultiAlgorithmSigner we use the algorithms it + // support, if it implements AlgorithmSigner we assume it supports all + // algorithms, otherwise only the key format one. + switch s := signer.(type) { + case MultiAlgorithmSigner: + as = s + case AlgorithmSigner: + as = &multiAlgorithmSigner{ + AlgorithmSigner: s, + supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)), + } + default: + as = &multiAlgorithmSigner{ + AlgorithmSigner: algorithmSignerWrapper{signer}, + supportedAlgorithms: []string{underlyingAlgo(keyFormat)}, + } + } + + getFallbackAlgo := func() (string, error) { + // Fallback to use if there is no "server-sig-algs" extension or a + // common algorithm cannot be found. We use the public key format if the + // MultiAlgorithmSigner supports it, otherwise we return an error. + if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) { + return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v", + underlyingAlgo(keyFormat), keyFormat, as.Algorithms()) + } + return keyFormat, nil + } + + extPayload, ok := extensions["server-sig-algs"] + if !ok { + // If there is no "server-sig-algs" extension use the fallback + // algorithm. + algo, err := getFallbackAlgo() + return as, algo, err + } + + // The server-sig-algs extension only carries underlying signature + // algorithm, but we are trying to select a protocol-level public key + // algorithm, which might be a certificate type. Extend the list of server + // supported algorithms to include the corresponding certificate algorithms. + serverAlgos := strings.Split(string(extPayload), ",") + for _, algo := range serverAlgos { + if certAlgo, ok := certificateAlgo(algo); ok { + serverAlgos = append(serverAlgos, certAlgo) + } + } + + // Filter algorithms based on those supported by MultiAlgorithmSigner. + var keyAlgos []string + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(as.Algorithms(), underlyingAlgo(algo)) { + keyAlgos = append(keyAlgos, algo) + } + } + + algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos) + if err != nil { + // If there is no overlap, return the fallback algorithm to support + // servers that fail to list all supported algorithms. + algo, err := getFallbackAlgo() + return as, algo, err + } + return as, algo, nil +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + // Authentication is performed by sending an enquiry to test if a key is + // acceptable to the remote. If the key is acceptable, the client will + // attempt to authenticate with the valid key. If not the client will repeat + // the process with the remaining keys. + + signers, err := cb() + if err != nil { + return authFailure, nil, err + } + var methods []string + var errSigAlgo error + + origSignersLen := len(signers) + for idx := 0; idx < len(signers); idx++ { + signer := signers[idx] + pub := signer.PublicKey() + as, algo, err := pickSignatureAlgorithm(signer, extensions) + if err != nil && errSigAlgo == nil { + // If we cannot negotiate a signature algorithm store the first + // error so we can return it to provide a more meaningful message if + // no other signers work. + errSigAlgo = err + continue + } + ok, err := validateKey(pub, algo, user, c) + if err != nil { + return authFailure, nil, err + } + // OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512 + // in the "server-sig-algs" extension but doesn't support these + // algorithms for certificate authentication, so if the server rejects + // the key try to use the obtained algorithm as if "server-sig-algs" had + // not been implemented if supported from the algorithm signer. + if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 { + if contains(as.Algorithms(), KeyAlgoRSA) { + // We retry using the compat algorithm after all signers have + // been tried normally. + signers = append(signers, &multiAlgorithmSigner{ + AlgorithmSigner: as, + supportedAlgorithms: []string{KeyAlgoRSA}, + }) + } + } + if !ok { + continue + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, algo, pubKey) + sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo)) + if err != nil { + return authFailure, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: algo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err = handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + + // If authentication succeeds or the list of available methods does not + // contain the "publickey" method, do not attempt to authenticate with any + // other keys. According to RFC 4252 Section 7, the latter can occur when + // additional authentication methods are required. + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + } + + return authFailure, methods, errSigAlgo +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: algo, + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return false, err + } + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + // According to RFC 4252 Section 7 the algorithm in + // SSH_MSG_USERAUTH_PK_OK should match that of the request but some + // servers send the key type instead. OpenSSH allows any algorithm + // that matches the public key, so we do the same. + // https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709 + if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) { + return false, nil + } + if !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (authResult, []string, error) { + gotMsgExtInfo := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. + if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + gotMsgExtInfo = true + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +func handleBannerResponse(c packetConn, packet []byte) error { + var msg userAuthBannerMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + transport, ok := c.(*handshakeTransport) + if !ok { + return nil + } + + if transport.bannerCallback != nil { + return transport.bannerCallback(msg.Message) + } + + return nil +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the name and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns an AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return authFailure, nil, err + } + + gotMsgExtInfo := false + gotUserAuthInfoRequest := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + continue + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. + if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + gotMsgExtInfo = true + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + if !gotUserAuthInfoRequest { + return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + gotUserAuthInfoRequest = true + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return authFailure, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.Name, msg.Instruction, prompts, echos) + if err != nil { + return authFailure, nil, err + } + + if len(answers) != len(prompts) { + return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts)) + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return authFailure, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions) + if ok != authFailure || err != nil { // either success, partial success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} + +// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication. +// See RFC 4462 section 3 +// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details. +// target is the server host you want to log in to. +func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod { + if gssAPIClient == nil { + panic("gss-api client must be not nil with enable gssapi-with-mic") + } + return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target} +} + +type gssAPIWithMICCallback struct { + gssAPIClient GSSAPIClient + target string +} + +func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + m := &userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: g.method(), + } + // The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST. + // See RFC 4462 section 3.2. + m.Payload = appendU32(m.Payload, 1) + m.Payload = appendString(m.Payload, string(krb5OID)) + if err := c.writePacket(Marshal(m)); err != nil { + return authFailure, nil, err + } + // The server responds to the SSH_MSG_USERAUTH_REQUEST with either an + // SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or + // with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE. + // See RFC 4462 section 3.3. + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check + // selected mech if it is valid. + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + userAuthGSSAPIResp := &userAuthGSSAPIResponse{} + if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil { + return authFailure, nil, err + } + // Start the loop into the exchange token. + // See RFC 4462 section 3.4. + var token []byte + defer g.gssAPIClient.DeleteSecContext() + for { + // Initiates the establishment of a security context between the application and a remote peer. + nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false) + if err != nil { + return authFailure, nil, err + } + if len(nextToken) > 0 { + if err := c.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: nextToken, + })); err != nil { + return authFailure, nil, err + } + } + if !needContinue { + break + } + packet, err = c.readPacket() + if err != nil { + return authFailure, nil, err + } + switch packet[0] { + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthGSSAPIError: + userAuthGSSAPIErrorResp := &userAuthGSSAPIError{} + if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil { + return authFailure, nil, err + } + return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+ + "Major Status: %d\n"+ + "Minor Status: %d\n"+ + "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus, + userAuthGSSAPIErrorResp.Message) + case msgUserAuthGSSAPIToken: + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return authFailure, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + } + // Binding Encryption Keys. + // See RFC 4462 section 3.5. + micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic") + micToken, err := g.gssAPIClient.GetMIC(micField) + if err != nil { + return authFailure, nil, err + } + if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{ + MIC: micToken, + })); err != nil { + return authFailure, nil, err + } + return handleAuthResponse(c) +} + +func (g *gssAPIWithMICCallback) method() string { + return "gssapi-with-mic" +} diff --git a/tempfork/sshtest/ssh/client_auth_test.go b/tempfork/sshtest/ssh/client_auth_test.go new file mode 100644 index 000000000..ec27133a3 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth_test.go @@ -0,0 +1,1384 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "runtime" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig. Returns both client and server side errors. +func tryAuth(t *testing.T, config *ClientConfig) error { + err, _ := tryAuthBothSides(t, config, nil) + return err +} + +// tryAuthWithGSSAPIWithMICConfig runs a handshake with a given config against an SSH server +// with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors. +func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error { + err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig) + return err +} + +// tryAuthBothSides runs the handshake and returns the resulting errors from both sides of the connection. +func tryAuthBothSides(t *testing.T, config *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) (clientError error, serverAuthErrors []error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsUserAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + GSSAPIWithMICConfig: gssAPIWithMICConfig, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err, serverAuthErrors +} + +type loggingAlgorithmSigner struct { + used []string + AlgorithmSigner +} + +func (l *loggingAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + l.used = append(l.used, "[Sign]") + return l.AlgorithmSigner.Sign(rand, data) +} + +func (l *loggingAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + l.used = append(l.used, algorithm) + return l.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +func TestClientAuthPublicKey(t *testing.T) { + signer := &loggingAlgorithmSigner{AlgorithmSigner: testSigners["rsa"].(AlgorithmSigner)} + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(signer), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 { + t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used) + } +} + +// TestClientAuthNoSHA2 tests a ssh-rsa Signer that doesn't implement AlgorithmSigner. +func TestClientAuthNoSHA2(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&legacyRSASigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +// TestClientAuthThirdKey checks that the third configured can succeed. If we +// were to do three attempts for each key (rsa-sha2-256, rsa-sha2-512, ssh-rsa), +// we'd hit the six maximum attempts before reaching it. +func TestClientAuthThirdKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa-openssh-format"], + testSigners["rsa-openssh-format"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +type invalidAlgSigner struct { + Signer +} + +func (s *invalidAlgSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + sig, err := s.Signer.Sign(rand, data) + if sig != nil { + sig.Format = "invalid" + } + return sig, err +} + +func TestMethodInvalidAlgorithm(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&invalidAlgSigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + err, serverErrors := tryAuthBothSides(t, config, nil) + if err == nil { + t.Fatalf("login succeeded") + } + + found := false + want := "algorithm \"invalid\"" + + var errStrings []string + for _, err := range serverErrors { + found = found || (err != nil && strings.Contains(err.Error(), want)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q", errStrings, want) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") != "" { + t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198") + } + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"non-existent-kex"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { + t.Errorf("got %v, expected 'common algorithm'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + // should succeed + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + // corrupted signature + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + // revoked + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + // sign with wrong key + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritative key") + } + + // host cert + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + // principal specified + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + // wrong principal specified + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + // added critical option + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + // allowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + // disallowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} + +func TestRetryableAuth(t *testing.T) { + n := 0 + passwords := []string{"WRONG1", "WRONG2"} + + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 2), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if n != 2 { + t.Fatalf("Did not try all passwords") + } +} + +func ExampleRetryableAuthMethod() { + user := "testuser" + NumberOfPrompts := 3 + + // Normally this would be a callback that prompts the user to answer the + // provided questions + Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return []string{"answer1", "answer2"}, nil + } + + config := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), + }, + } + + host := "mysshserver" + netConn, err := net.Dial("tcp", host) + if err != nil { + log.Fatal(err) + } + + sshConn, _, _, err := NewClientConn(netConn, host, config) + if err != nil { + log.Fatal(err) + } + _ = sshConn +} + +// Test if username is received on server side when NoClientAuth is used +func TestClientAuthNone(t *testing.T) { + user := "testuser" + serverConfig := &ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + User: user, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatalf("newServer: %v", err) + } + if serverConn.User() != user { + t.Fatalf("server: got %q, want %q", serverConn.User(), user) + } +} + +// Test if authentication attempts are limited on server when MaxAuthTries is set +func TestClientAuthMaxAuthTries(t *testing.T) { + user := "testuser" + + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == "right" { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + + for tries := 2; tries < 4; tries++ { + n := tries + clientConfig := &ClientConfig{ + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + n-- + if n == 0 { + return "right", nil + } + return "wrong", nil + }), tries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errCh := make(chan error, 1) + + go func() { + _, err := newServer(c1, serverConfig) + errCh <- err + }() + _, _, _, cliErr := NewClientConn(c2, "", clientConfig) + srvErr := <-errCh + + if tries > serverConfig.MaxAuthTries { + if cliErr == nil { + t.Fatalf("client: got no error, want %s", expectedErr) + } else if cliErr.Error() != expectedErr.Error() { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + var authErr *ServerAuthError + if !errors.As(srvErr, &authErr) { + t.Errorf("expected ServerAuthError, got: %v", srvErr) + } + } else { + if cliErr != nil { + t.Fatalf("client: got %s, want no error", cliErr) + } + } + } +} + +// Test if authentication attempts are correctly limited on server +// when more public keys are provided then MaxAuthTries +func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { + signers := []Signer{} + for i := 0; i < 6; i++ { + signers = append(signers, testSigners["dsa"]) + } + + validConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, validConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + invalidConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append(signers, testSigners["rsa"])...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, invalidConfig); err == nil { + t.Fatalf("client: got no error, want %s", expectedErr) + } else if err.Error() != expectedErr.Error() { + // On Windows we can see a WSAECONNABORTED error + // if the client writes another authentication request + // before the client goroutine reads the disconnection + // message. See issue 50805. + if runtime.GOOS == "windows" && strings.Contains(err.Error(), "wsarecv: An established connection was aborted") { + // OK. + } else { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + } +} + +// Test whether authentication errors are being properly logged if all +// authentication methods have been exhausted +func TestClientAuthErrorList(t *testing.T) { + publicKeyErr := errors.New("This is an error from PublicKeyCallback") + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConfig := &ServerConfig{ + PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) { + return nil, publicKeyErr + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + _, err = newServer(c1, serverConfig) + if err == nil { + t.Fatal("newServer: got nil, expected errors") + } + + authErrs, ok := err.(*ServerAuthError) + if !ok { + t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err) + } + for i, e := range authErrs.Errors { + switch i { + case 0: + if e != ErrNoAuth { + t.Fatalf("errors: got error %v, want ErrNoAuth", e) + } + case 1: + if e != publicKeyErr { + t.Fatalf("errors: got %v, want %v", e, publicKeyErr) + } + default: + t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors) + } + } +} + +func TestAuthMethodGSSAPIWithMIC(t *testing.T) { + type testcase struct { + config *ClientConfig + gssConfig *GSSAPIWithMICConfig + clientWantErr string + serverWantErr string + } + testcases := []*testcase{ + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User()+"@DOMAIN" { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + return nil, fmt.Errorf("user is not allowed to login") + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "user is not allowed to login", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-invalid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + clientWantErr: "ssh: handshake failed: got \"server-invalid-token-1\", want token \"server-valid-token-1\"", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("invalid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "got MICToken \"invalid-mic\", want \"valid-mic\"", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + } + + for i, c := range testcases { + clientErr, serverErrs := tryAuthBothSides(t, c.config, c.gssConfig) + if (c.clientWantErr == "") != (clientErr == nil) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + if (c.serverWantErr == "") != (len(serverErrs) == 2 && serverErrs[1] == nil || len(serverErrs) == 1) { + t.Fatalf("server got err %v, want %s", serverErrs, c.serverWantErr) + } + if c.clientWantErr != "" { + if clientErr != nil && !strings.Contains(clientErr.Error(), c.clientWantErr) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + } + found := false + var errStrings []string + if c.serverWantErr != "" { + for _, err := range serverErrs { + found = found || (err != nil && strings.Contains(err.Error(), c.serverWantErr)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q, case %d", errStrings, c.serverWantErr, i) + } + } + } +} + +func TestCompatibleAlgoAndSignatures(t *testing.T) { + type testcase struct { + algo string + sigFormat string + compatible bool + } + testcases := []*testcase{ + { + KeyAlgoRSA, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSA, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA256v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSASHA256v01, + CertAlgoRSAv01, + true, + }, + { + CertAlgoRSAv01, + CertAlgoRSASHA512v01, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoRSA, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA521, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA256, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoED25519, + false, + }, + { + KeyAlgoED25519, + KeyAlgoED25519, + true, + }, + } + + for _, c := range testcases { + if isAlgoCompatible(c.algo, c.sigFormat) != c.compatible { + t.Errorf("algorithm %q, signature format %q, expected compatible to be %t", c.algo, c.sigFormat, c.compatible) + } + } +} + +func TestPickSignatureAlgorithm(t *testing.T) { + type testcase struct { + name string + extensions map[string][]byte + } + cases := []testcase{ + { + name: "server with empty server-sig-algs", + extensions: map[string][]byte{ + "server-sig-algs": []byte(``), + }, + }, + { + name: "server with no server-sig-algs", + extensions: nil, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + signer, ok := testSigners["rsa"].(MultiAlgorithmSigner) + if !ok { + t.Fatalf("rsa test signer does not implement the MultiAlgorithmSigner interface") + } + // The signer supports the public key algorithm which is then returned. + _, algo, err := pickSignatureAlgorithm(signer, c.extensions) + if err != nil { + t.Fatalf("got %v, want no error", err) + } + if algo != signer.PublicKey().Type() { + t.Fatalf("got algo %q, want %q", algo, signer.PublicKey().Type()) + } + // Test a signer that uses a certificate algorithm as the public key + // type. + cert := &Certificate{ + CertType: UserCert, + Key: signer.PublicKey(), + } + cert.SignCert(rand.Reader, signer) + + certSigner, err := NewCertSigner(cert, signer) + if err != nil { + t.Fatalf("error generating cert signer: %v", err) + } + // The signer supports the public key algorithm and the + // public key format is a certificate type so the cerificate + // algorithm matching the key format must be returned + _, algo, err = pickSignatureAlgorithm(certSigner, c.extensions) + if err != nil { + t.Fatalf("got %v, want no error", err) + } + if algo != certSigner.PublicKey().Type() { + t.Fatalf("got algo %q, want %q", algo, certSigner.PublicKey().Type()) + } + signer, err = NewSignerWithAlgorithms(signer.(AlgorithmSigner), []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create signer with algorithms: %v", err) + } + // The signer does not support the public key algorithm so an error + // is returned. + _, _, err = pickSignatureAlgorithm(signer, c.extensions) + if err == nil { + t.Fatal("got no error, no common public key signature algorithm error expected") + } + }) + } +} + +// configurablePublicKeyCallback is a public key callback that allows to +// configure the signature algorithm and format. This way we can emulate the +// behavior of buggy clients. +type configurablePublicKeyCallback struct { + signer AlgorithmSigner + signatureAlgo string + signatureFormat string +} + +func (cb configurablePublicKeyCallback) method() string { + return "publickey" +} + +func (cb configurablePublicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + pub := cb.signer.PublicKey() + + ok, err := validateKey(pub, cb.signatureAlgo, user, c) + if err != nil { + return authFailure, nil, err + } + if !ok { + return authFailure, nil, fmt.Errorf("invalid public key") + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, cb.signatureAlgo, pubKey) + sign, err := cb.signer.SignWithAlgorithm(rand, data, underlyingAlgo(cb.signatureFormat)) + if err != nil { + return authFailure, nil, err + } + + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: cb.signatureAlgo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err := handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + + return authFailure, methods, nil +} + +func TestPublicKeyAndAlgoCompatibility(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA256, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err == nil { + t.Error("cert login passed with incompatible public key type and algorithm") + } +} + +func TestClientAuthGPGAgentCompat(t *testing.T) { + clientConfig := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm rsa-sha2-512 and signature format ssh-rsa. + configurablePublicKeyCallback{ + signer: testSigners["rsa"].(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA512, + signatureFormat: KeyAlgoRSA, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestCertAuthOpenSSHCompat(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm ssh-rsa-cert-v01@openssh.com and signature format + // rsa-sha2-256. + configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: CertAlgoRSAv01, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) { + const maxAuthTries = 2 + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // Start testserver + serverConfig := &ServerConfig{ + MaxAuthTries: maxAuthTries, + KeyboardInteractiveCallback: func(c ConnMetadata, + client KeyboardInteractiveChallenge) (*Permissions, error) { + // Fail keyboard-interactive authentication early before + // any prompt is sent to client. + return nil, errors.New("keyboard-interactive auth failed") + }, + PasswordCallback: func(c ConnMetadata, + pass []byte) (*Permissions, error) { + if string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverDone := make(chan struct{}) + go func() { + defer func() { serverDone <- struct{}{} }() + conn, chans, reqs, err := NewServerConn(c2, serverConfig) + if err != nil { + return + } + _ = conn.Close() + + discarderDone := make(chan struct{}) + go func() { + defer func() { discarderDone <- struct{}{} }() + DiscardRequests(reqs) + }() + for newChannel := range chans { + newChannel.Reject(Prohibited, + "testserver not accepting requests") + } + + <-discarderDone + }() + + // Connect to testserver, expect KeyboardInteractive() to be not called, + // PasswordCallback() to be called and connection to succeed. + passwordCallbackCalled := false + clientConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractive(func(name, + instruction string, questions []string, + echos []bool) ([]string, error) { + t.Errorf("unexpected call to KeyboardInteractive()") + return []string{clientPassword}, nil + }), maxAuthTries), + RetryableAuthMethod(PasswordCallback(func() (secret string, + err error) { + t.Logf("PasswordCallback()") + passwordCallbackCalled = true + return clientPassword, nil + }), maxAuthTries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, _, _, err := NewClientConn(c1, "", clientConfig) + if err != nil { + t.Errorf("unexpected NewClientConn() error: %v", err) + } + if conn != nil { + conn.Close() + } + + // Wait for server to finish. + <-serverDone + + if !passwordCallbackCalled { + t.Errorf("expected PasswordCallback() to be called") + } +} diff --git a/tempfork/sshtest/ssh/client_test.go b/tempfork/sshtest/ssh/client_test.go new file mode 100644 index 000000000..2621f0ea5 --- /dev/null +++ b/tempfork/sshtest/ssh/client_test.go @@ -0,0 +1,367 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "strings" + "testing" +) + +func TestClientVersion(t *testing.T) { + for _, tt := range []struct { + name string + version string + multiLine string + wantErr bool + }{ + { + name: "default version", + version: packageVersion, + }, + { + name: "custom version", + version: "SSH-2.0-CustomClientVersionString", + }, + { + name: "good multi line version", + version: packageVersion, + multiLine: strings.Repeat("ignored\r\n", 20), + }, + { + name: "bad multi line version", + version: packageVersion, + multiLine: "bad multi line version", + wantErr: true, + }, + { + name: "long multi line version", + version: packageVersion, + multiLine: strings.Repeat("long multi line version\r\n", 50)[:256], + wantErr: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go func() { + if tt.multiLine != "" { + c1.Write([]byte(tt.multiLine)) + } + NewClientConn(c1, "", &ClientConfig{ + ClientVersion: tt.version, + HostKeyCallback: InsecureIgnoreHostKey(), + }) + c1.Close() + }() + conf := &ServerConfig{NoClientAuth: true} + conf.AddHostKey(testSigners["rsa"]) + conn, _, _, err := NewServerConn(c2, conf) + if err == nil == tt.wantErr { + t.Fatalf("got err %v; wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + // Don't verify the version on an expected error. + return + } + if got := string(conn.ClientVersion()); got != tt.version { + t.Fatalf("got %q; want %q", got, tt.version) + } + }) + } +} + +func TestHostKeyCheck(t *testing.T) { + for _, tt := range []struct { + name string + wantError string + key PublicKey + }{ + {"no callback", "must specify HostKeyCallback", nil}, + {"correct key", "", testSigners["rsa"].PublicKey()}, + {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + + go NewServerConn(c1, serverConf) + clientConf := ClientConfig{ + User: "user", + } + if tt.key != nil { + clientConf.HostKeyCallback = FixedHostKey(tt.key) + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + } +} + +func TestVerifyHostKeySignature(t *testing.T) { + for _, tt := range []struct { + key string + signAlgo string + verifyAlgo string + wantError string + }{ + {"rsa", KeyAlgoRSA, KeyAlgoRSA, ""}, + {"rsa", KeyAlgoRSASHA256, KeyAlgoRSASHA256, ""}, + {"rsa", KeyAlgoRSA, KeyAlgoRSASHA512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`}, + {"ed25519", KeyAlgoED25519, KeyAlgoED25519, ""}, + } { + key := testSigners[tt.key].PublicKey() + s, ok := testSigners[tt.key].(AlgorithmSigner) + if !ok { + t.Fatalf("needed an AlgorithmSigner") + } + sig, err := s.SignWithAlgorithm(rand.Reader, []byte("test"), tt.signAlgo) + if err != nil { + t.Fatalf("couldn't sign: %q", err) + } + + b := bytes.Buffer{} + writeString(&b, []byte(sig.Format)) + writeString(&b, sig.Blob) + + result := kexResult{Signature: b.Bytes(), H: []byte("test")} + + err = verifyHostKeySignature(key, tt.verifyAlgo, &result) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("got error %q, expecting %q", err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("succeeded, but want error string %q", tt.wantError) + } + } +} + +func TestBannerCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + BannerCallback: func(conn ConnMetadata) string { + return "Hello World" + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + var receivedBanner string + var bannerCount int + clientConf := ClientConfig{ + Auth: []AuthMethod{ + Password("123"), + }, + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(message string) error { + bannerCount++ + receivedBanner = message + return nil + }, + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + + if bannerCount != 1 { + t.Errorf("got %d banners; want 1", bannerCount) + } + + expected := "Hello World" + if receivedBanner != expected { + t.Fatalf("got %s; want %s", receivedBanner, expected) + } +} + +func TestNewClientConn(t *testing.T) { + errHostKeyMismatch := errors.New("host key mismatch") + + for _, tt := range []struct { + name string + user string + simulateHostKeyMismatch HostKeyCallback + }{ + { + name: "good user field for ConnMetadata", + user: "testuser", + }, + { + name: "empty user field for ConnMetadata", + user: "", + }, + { + name: "host key mismatch", + user: "testuser", + simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error { + return fmt.Errorf("%w: %s", errHostKeyMismatch, bytes.TrimSpace(MarshalAuthorizedKey(key))) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: tt.user, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if tt.simulateHostKeyMismatch != nil { + clientConf.HostKeyCallback = tt.simulateHostKeyMismatch + } + + clientConn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) { + return + } + t.Fatal(err) + } + + if userGot := clientConn.User(); userGot != tt.user { + t.Errorf("got user %q; want user %q", userGot, tt.user) + } + }) + } +} + +func TestUnsupportedAlgorithm(t *testing.T) { + for _, tt := range []struct { + name string + config Config + wantError string + }{ + { + "unsupported KEX", + Config{ + KeyExchanges: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported KEXs", + Config{ + KeyExchanges: []string{"unsupported", kexAlgoCurve25519SHA256}, + }, + "", + }, + { + "unsupported cipher", + Config{ + Ciphers: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported ciphers", + Config{ + Ciphers: []string{"unsupported", chacha20Poly1305ID}, + }, + "", + }, + { + "unsupported MAC", + Config{ + MACs: []string{"unsupported"}, + // MAC is used for non AAED ciphers. + Ciphers: []string{"aes256-ctr"}, + }, + "no common algorithm", + }, + { + "unsupported and supported MACs", + Config{ + MACs: []string{"unsupported", "hmac-sha2-256-etm@openssh.com"}, + // MAC is used for non AAED ciphers. + Ciphers: []string{"aes256-ctr"}, + }, + "", + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + Config: tt.config, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "testuser", + Config: tt.config, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/common.go b/tempfork/sshtest/ssh/common.go new file mode 100644 index 000000000..7e9c2cbc6 --- /dev/null +++ b/tempfork/sshtest/ssh/common.go @@ -0,0 +1,476 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "math" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers lists ciphers we support but might not recommend. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "arcfour256", "arcfour128", "arcfour", + aes128cbcID, + tripledescbcID, +} + +// preferredCiphers specifies the default preference for ciphers. +var preferredCiphers = []string{ + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "aes128-ctr", "aes192-ctr", "aes256-ctr", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1, + kexAlgoDH1SHA1, +} + +// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden +// for the server half. +var serverForbiddenKexAlgos = map[string]struct{}{ + kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests + kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests +} + +// preferredKexAlgos specifies the default preference for key-exchange +// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm +// is disabled by default because it is a bit slower than the others. +var preferredKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH14SHA1, +} + +// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported signature algorithms to their +// respective hashes needed for signing and verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoRSASHA256: crypto.SHA256, + KeyAlgoRSASHA512: crypto.SHA512, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + // KeyAlgoED25519 doesn't pre-hash. + KeyAlgoSKECDSA256: crypto.SHA256, + KeyAlgoSKED25519: crypto.SHA256, +} + +// algorithmsForKeyFormat returns the supported signature algorithms for a given +// public key format (PublicKey.Type), in order of preference. See RFC 8332, +// Section 2. See also the note in sendKexInit on backwards compatibility. +func algorithmsForKeyFormat(keyFormat string) []string { + switch keyFormat { + case KeyAlgoRSA: + return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA} + case CertAlgoRSAv01: + return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01} + default: + return []string{keyFormat} + } +} + +// isRSA returns whether algo is a supported RSA algorithm, including certificate +// algorithms. +func isRSA(algo string) bool { + algos := algorithmsForKeyFormat(KeyAlgoRSA) + return contains(algos, underlyingAlgo(algo)) +} + +func isRSACert(algo string) bool { + _, ok := certKeyAlgoNames[algo] + if !ok { + return false + } + return isRSA(algo) +} + +// supportedPubKeyAuthAlgos specifies the supported client public key +// authentication algorithms. Note that this doesn't include certificate types +// since those use the underlying algorithm. This list is sent to the client if +// it supports the server-sig-algs extension. Order is irrelevant. +var supportedPubKeyAuthAlgos = []string{ + KeyAlgoED25519, + KeyAlgoSKED25519, KeyAlgoSKECDSA256, + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA, + KeyAlgoDSA, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +// directionAlgorithms records algorithm choices in one direction (either read or write) +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +// rekeyBytes returns a rekeying intervals in bytes. +func (a *directionAlgorithms) rekeyBytes() int64 { + // According to RFC 4344 block ciphers should rekey after + // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is + // 128. + switch a.Cipher { + case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID: + return 16 * (1 << 32) + + } + + // For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data. + return 1 << 30 +} + +var aeadCiphers = map[string]bool{ + gcm128CipherID: true, + gcm256CipherID: true, + chacha20Poly1305ID: true, +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + stoc, ctos := &result.w, &result.r + if isClient { + ctos, stoc = stoc, ctos + } + + ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + if !aeadCiphers[ctos.Cipher] { + ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + } + + if !aeadCiphers[stoc.Cipher] { + stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + } + + ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, a size suitable for the chosen cipher is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a default set + // of algorithms is used. Unsupported values are silently ignored. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = preferredCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // Ignore the cipher if we have no cipherModes definition. + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = preferredKexAlgos + } + var kexs []string + for _, k := range c.KeyExchanges { + if kexAlgoMap[k] != nil { + // Ignore the KEX if we have no kexAlgoMap definition. + kexs = append(kexs, k) + } + } + c.KeyExchanges = kexs + + if c.MACs == nil { + c.MACs = supportedMACs + } + var macs []string + for _, m := range c.MACs { + if macModes[m] != nil { + // Ignore the MAC if we have no macModes definition. + macs = append(macs, m) + } + } + c.MACs = macs + + if c.RekeyThreshold == 0 { + // cipher specific default + } else if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } else if c.RekeyThreshold >= math.MaxInt64 { + // Avoid weirdness if somebody uses -1 as a threshold. + c.RekeyThreshold = math.MaxInt64 + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. algo is the advertised +// algorithm, and may be a certificate type. +func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo string + PubKey []byte + }{ + sessionID, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/tempfork/sshtest/ssh/common_test.go b/tempfork/sshtest/ssh/common_test.go new file mode 100644 index 000000000..a7beee8e8 --- /dev/null +++ b/tempfork/sshtest/ssh/common_test.go @@ -0,0 +1,176 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "reflect" + "testing" +) + +func TestFindAgreedAlgorithms(t *testing.T) { + initKex := func(k *kexInitMsg) { + if k.KexAlgos == nil { + k.KexAlgos = []string{"kex1"} + } + if k.ServerHostKeyAlgos == nil { + k.ServerHostKeyAlgos = []string{"hostkey1"} + } + if k.CiphersClientServer == nil { + k.CiphersClientServer = []string{"cipher1"} + + } + if k.CiphersServerClient == nil { + k.CiphersServerClient = []string{"cipher1"} + + } + if k.MACsClientServer == nil { + k.MACsClientServer = []string{"mac1"} + + } + if k.MACsServerClient == nil { + k.MACsServerClient = []string{"mac1"} + + } + if k.CompressionClientServer == nil { + k.CompressionClientServer = []string{"compression1"} + + } + if k.CompressionServerClient == nil { + k.CompressionServerClient = []string{"compression1"} + + } + if k.LanguagesClientServer == nil { + k.LanguagesClientServer = []string{"language1"} + + } + if k.LanguagesServerClient == nil { + k.LanguagesServerClient = []string{"language1"} + + } + } + + initDirAlgs := func(a *directionAlgorithms) { + if a.Cipher == "" { + a.Cipher = "cipher1" + } + if a.MAC == "" { + a.MAC = "mac1" + } + if a.Compression == "" { + a.Compression = "compression1" + } + } + + initAlgs := func(a *algorithms) { + if a.kex == "" { + a.kex = "kex1" + } + if a.hostKey == "" { + a.hostKey = "hostkey1" + } + initDirAlgs(&a.r) + initDirAlgs(&a.w) + } + + type testcase struct { + name string + clientIn, serverIn kexInitMsg + wantClient, wantServer algorithms + wantErr bool + } + + cases := []testcase{ + { + name: "standard", + }, + + { + name: "no common hostkey", + serverIn: kexInitMsg{ + ServerHostKeyAlgos: []string{"hostkey2"}, + }, + wantErr: true, + }, + + { + name: "no common kex", + serverIn: kexInitMsg{ + KexAlgos: []string{"kex2"}, + }, + wantErr: true, + }, + + { + name: "no common cipher", + serverIn: kexInitMsg{ + CiphersClientServer: []string{"cipher2"}, + }, + wantErr: true, + }, + + { + name: "client decides cipher", + serverIn: kexInitMsg{ + CiphersClientServer: []string{"cipher1", "cipher2"}, + CiphersServerClient: []string{"cipher2", "cipher3"}, + }, + clientIn: kexInitMsg{ + CiphersClientServer: []string{"cipher2", "cipher1"}, + CiphersServerClient: []string{"cipher3", "cipher2"}, + }, + wantClient: algorithms{ + r: directionAlgorithms{ + Cipher: "cipher3", + }, + w: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + wantServer: algorithms{ + w: directionAlgorithms{ + Cipher: "cipher3", + }, + r: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + }, + + // TODO(hanwen): fix and add tests for AEAD ignoring + // the MACs field + } + + for i := range cases { + initKex(&cases[i].clientIn) + initKex(&cases[i].serverIn) + initAlgs(&cases[i].wantClient) + initAlgs(&cases[i].wantServer) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn) + clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn) + + serverHasErr := serverErr != nil + clientHasErr := clientErr != nil + if c.wantErr != serverHasErr || c.wantErr != clientHasErr { + t.Fatalf("got client/server error (%v, %v), want hasError %v", + clientErr, serverErr, c.wantErr) + + } + if c.wantErr { + return + } + + if !reflect.DeepEqual(serverAlgs, &c.wantServer) { + t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer) + } + if !reflect.DeepEqual(clientAlgs, &c.wantClient) { + t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/connection.go b/tempfork/sshtest/ssh/connection.go new file mode 100644 index 000000000..8f345ee92 --- /dev/null +++ b/tempfork/sshtest/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the session hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC 4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshConn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/tempfork/sshtest/ssh/doc.go b/tempfork/sshtest/ssh/doc.go new file mode 100644 index 000000000..f5d352fe3 --- /dev/null +++ b/tempfork/sshtest/ssh/doc.go @@ -0,0 +1,23 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + + [PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 + +This package does not fall under the stability promise of the Go language itself, +so its API may be changed when pressing needs arise. +*/ +package ssh diff --git a/tempfork/sshtest/ssh/example_test.go b/tempfork/sshtest/ssh/example_test.go new file mode 100644 index 000000000..97b3b6aba --- /dev/null +++ b/tempfork/sshtest/ssh/example_test.go @@ -0,0 +1,400 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh_test + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/rsa" + "fmt" + "log" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func ExampleNewServerConn() { + // Public key authentication is done by comparing + // the public key of a received connection + // with the entries in the authorized_keys file. + authorizedKeysBytes, err := os.ReadFile("authorized_keys") + if err != nil { + log.Fatalf("Failed to load authorized_keys, err: %v", err) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + log.Fatal(err) + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable password auth. + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake: ", err) + } + log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) + + var wg sync.WaitGroup + defer wg.Wait() + + // The incoming Request channel must be serviced. + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + wg.Add(1) + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + wg.Done() + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + wg.Add(1) + go func() { + defer func() { + channel.Close() + wg.Done() + }() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleServerConfig_AddHostKey() { + // Minimal ServerConfig supporting only password authentication. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + // Restrict host key algorithms to disable ssh-rsa. + signer, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512}) + if err != nil { + log.Fatal("Failed to create private key with restricted algorithms: ", err) + } + config.AddHostKey(signer) +} + +func ExampleClientConfig_HostKeyCallback() { + // Every client must provide a host key check. Here is a + // simple-minded parse of OpenSSH's known_hosts file + host := "hostname" + file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + var hostKey ssh.PublicKey + for scanner.Scan() { + fields := strings.Split(scanner.Text(), " ") + if len(fields) != 3 { + continue + } + if strings.Contains(fields[0], host) { + var err error + hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) + if err != nil { + log.Fatalf("error parsing %q: %v", fields[2], err) + } + break + } + } + + if hostKey == nil { + log.Fatalf("no hostkey for %s", host) + } + + config := ssh.ClientConfig{ + User: os.Getenv("USER"), + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + _, err = ssh.Dial("tcp", host+":22", &config) + log.Println(err) +} + +func ExampleDial() { + var hostKey ssh.PublicKey + // An SSH client is represented with a ClientConn. + // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig, + // and provide a HostKeyCallback. + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + client, err := ssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + defer client.Close() + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + log.Fatal("Failed to create session: ", err) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + log.Fatal("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExamplePublicKeys() { + var hostKey ssh.PublicKey + // A public key may be used to authenticate against the remote + // server by using an unencrypted PEM-encoded private key file. + // + // If you have an encrypted private key, the crypto/x509 package + // can be used to decrypt it. + key, err := os.ReadFile("/home/user/.ssh/id_rsa") + if err != nil { + log.Fatalf("unable to read private key: %v", err) + } + + // Create the Signer for this private key. + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + log.Fatalf("unable to parse private key: %v", err) + } + + config := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + // Connect to the remote server and perform the SSH handshake. + client, err := ssh.Dial("tcp", "host.com:22", config) + if err != nil { + log.Fatalf("unable to connect: %v", err) + } + defer client.Close() +} + +func ExampleClient_Listen() { + var hostKey ssh.PublicKey + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Dial your ssh server. + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + var hostKey ssh.PublicKey + // Create client config + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Connect to ssh server + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatal("unable to create session: ", err) + } + defer session.Close() + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 40, 80, modes); err != nil { + log.Fatal("request for pseudo terminal failed: ", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatal("failed to start shell: ", err) + } +} + +func ExampleCertificate_SignCert() { + // Sign a certificate with a specific algorithm. + privateKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate RSA key: ", err) + } + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + log.Fatal("unable to get RSA public key: ", err) + } + caKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate CA key: ", err) + } + signer, err := ssh.NewSignerFromKey(caKey) + if err != nil { + log.Fatal("unable to generate signer from key: ", err) + } + mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256}) + if err != nil { + log.Fatal("unable to create signer with algorithms: ", err) + } + certificate := ssh.Certificate{ + Key: publicKey, + CertType: ssh.UserCert, + } + if err := certificate.SignCert(rand.Reader, mas); err != nil { + log.Fatal("unable to sign certificate: ", err) + } + // Save the public key to a file and check that rsa-sha-256 is used for + // signing: + // ssh-keygen -L -f + fmt.Println(string(ssh.MarshalAuthorizedKey(&certificate))) +} diff --git a/tempfork/sshtest/ssh/handshake.go b/tempfork/sshtest/ssh/handshake.go new file mode 100644 index 000000000..fef687db0 --- /dev/null +++ b/tempfork/sshtest/ssh/handshake.go @@ -0,0 +1,816 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "strings" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// chanSize sets the amount of buffering SSH connections. This is +// primarily for testing: setting chanSize=0 uncovers deadlocks more +// quickly. +const chanSize = 16 + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error + + // setStrictMode sets the strict KEX mode, notably triggering + // sequence number resets on sending or receiving msgNewKeys. + // If the sequence number is already > 1 when setStrictMode + // is called, an error is returned. + setStrictMode() error + + // setInitialKEXDone indicates to the transport that the initial key exchange + // was completed + setInitialKEXDone() +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys []Signer + + // publicKeyAuthAlgorithms is non-empty if we are the server. In that case, + // it contains the supported client public key authentication algorithms. + publicKeyAuthAlgorithms []string + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. + writePacketsLeft uint32 + writeBytesLeft int64 + userAuthComplete bool // whether the user authentication phase is complete + + // If the read loop wants to schedule a kex, it pings this + // channel, and the write loop will send out a kex + // message. + requestKex chan struct{} + + // If the other side requests or confirms a kex, its kexInit + // packet is sent here for the write loop to find it. + startKex chan *pendingKex + kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits + + // data for host key checking + hostKeyCallback HostKeyCallback + dialAddress string + remoteAddr net.Addr + + // bannerCallback is non-empty if we are the client and it has been set in + // ClientConfig. In that case it is called during the user authentication + // dance to handle a custom server's message. + bannerCallback BannerCallback + + // Algorithms agreed in the last key exchange. + algorithms *algorithms + + // Counters exclusively owned by readLoop. + readPacketsLeft uint32 + readBytesLeft int64 + + // The session ID or nil if first kex did not complete yet. + sessionID []byte + + // strictMode indicates if the other side of the handshake indicated + // that we should be following the strict KEX protocol restrictions. + strictMode bool +} + +type pendingKex struct { + otherInit []byte + done chan error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, chanSize), + requestKex: make(chan struct{}, 1), + startKex: make(chan *pendingKex), + kexLoopDone: make(chan struct{}), + + config: config, + } + t.resetReadThresholds() + t.resetWriteThresholds() + + // We always start with a mandatory key exchange. + t.requestKex <- struct{}{} + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + t.bannerCallback = config.BannerCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + go t.kexLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms + go t.readLoop() + go t.kexLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +// waitSession waits for the session to be established. This should be +// the first thing to call after instantiating handshakeTransport. +func (t *handshakeTransport) waitSession() error { + p, err := t.readPacket() + if err != nil { + return err + } + if p[0] != msgNewKeys { + return fmt.Errorf("ssh: first packet should be msgNewKeys") + } + + return nil +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) printPacket(p []byte, write bool) { + action := "got" + if write { + action = "sent" + } + + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) + } else { + msg, err := decode(p) + log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) + } +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + first := true + for { + p, err := t.readOnePacket(first) + first = false + if err != nil { + t.readError = err + close(t.incoming) + break + } + // If this is the first kex, and strict KEX mode is enabled, + // we don't ignore any messages, as they may be used to manipulate + // the packet sequence numbers. + if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) { + continue + } + t.incoming <- p + } + + // Stop writers too. + t.recordWriteError(t.readError) + + // Unblock the writer should it wait for this. + close(t.startKex) + + // Don't close t.requestKex; it's also written to from writePacket. +} + +func (t *handshakeTransport) pushPacket(p []byte) error { + if debugHandshake { + t.printPacket(p, true) + } + return t.conn.writePacket(p) +} + +func (t *handshakeTransport) getWriteError() error { + t.mu.Lock() + defer t.mu.Unlock() + return t.writeError +} + +func (t *handshakeTransport) recordWriteError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil && err != nil { + t.writeError = err + } +} + +func (t *handshakeTransport) requestKeyExchange() { + select { + case t.requestKex <- struct{}{}: + default: + // something already requested a kex, so do nothing. + } +} + +func (t *handshakeTransport) resetWriteThresholds() { + t.writePacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.writeBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.writeBytesLeft = t.algorithms.w.rekeyBytes() + } else { + t.writeBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) kexLoop() { + +write: + for t.getWriteError() == nil { + var request *pendingKex + var sent bool + + for request == nil || !sent { + var ok bool + select { + case request, ok = <-t.startKex: + if !ok { + break write + } + case <-t.requestKex: + break + } + + if !sent { + if err := t.sendKexInit(); err != nil { + t.recordWriteError(err) + break + } + sent = true + } + } + + if err := t.getWriteError(); err != nil { + if request != nil { + request.done <- err + } + break + } + + // We're not servicing t.requestKex, but that is OK: + // we never block on sending to t.requestKex. + + // We're not servicing t.startKex, but the remote end + // has just sent us a kexInitMsg, so it can't send + // another key change request, until we close the done + // channel on the pendingKex request. + + err := t.enterKeyExchange(request.otherInit) + + t.mu.Lock() + t.writeError = err + t.sentInitPacket = nil + t.sentInitMsg = nil + + t.resetWriteThresholds() + + // we have completed the key exchange. Since the + // reader is still blocked, it is safe to clear out + // the requestKex channel. This avoids the situation + // where: 1) we consumed our own request for the + // initial kex, and 2) the kex from the remote side + // caused another send on the requestKex channel, + clear: + for { + select { + case <-t.requestKex: + // + default: + break clear + } + } + + request.done <- t.writeError + + // kex finished. Push packets that we received while + // the kex was in progress. Don't look at t.startKex + // and don't increment writtenSinceKex: if we trigger + // another kex while we are still busy with the last + // one, things will become very confusing. + for _, p := range t.pendingPackets { + t.writeError = t.pushPacket(p) + if t.writeError != nil { + break + } + } + t.pendingPackets = t.pendingPackets[:0] + t.mu.Unlock() + } + + // Unblock reader. + t.conn.Close() + + // drain startKex channel. We don't service t.requestKex + // because nobody does blocking sends there. + for request := range t.startKex { + request.done <- t.getWriteError() + } + + // Mark that the loop is done so that Close can return. + close(t.kexLoopDone) +} + +// The protocol uses uint32 for packet counters, so we can't let them +// reach 1<<32. We will actually read and write more packets than +// this, though: the other side may send more packets, and after we +// hit this limit on writing we will send a few more packets for the +// key exchange itself. +const packetRekeyThreshold = (1 << 31) + +func (t *handshakeTransport) resetReadThresholds() { + t.readPacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.readBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.readBytesLeft = t.algorithms.r.rekeyBytes() + } else { + t.readBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + if t.readPacketsLeft > 0 { + t.readPacketsLeft-- + } else { + t.requestKeyExchange() + } + + if t.readBytesLeft > 0 { + t.readBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if debugHandshake { + t.printPacket(p, false) + } + + if first && p[0] != msgKexInit { + return nil, fmt.Errorf("ssh: first packet should be msgKexInit") + } + + if p[0] != msgKexInit { + return p, nil + } + + firstKex := t.sessionID == nil + + kex := pendingKex{ + done: make(chan error, 1), + otherInit: p, + } + t.startKex <- &kex + err = <-kex.done + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + if err != nil { + return nil, err + } + + t.resetReadThresholds() + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +const ( + kexStrictClient = "kex-strict-c-v00@openssh.com" + kexStrictServer = "kex-strict-s-v00@openssh.com" +) + +// sendKexInit sends a key change message. +func (t *handshakeTransport) sendKexInit() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.sentInitMsg != nil { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + return nil + } + + msg := &kexInitMsg{ + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + // We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm, + // and possibly to add the ext-info extension algorithm. Since the slice may be the + // user owned KeyExchanges, we create our own slice in order to avoid using user + // owned memory by mistake. + msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info + msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) + + isServer := len(t.hostKeys) > 0 + if isServer { + for _, k := range t.hostKeys { + // If k is a MultiAlgorithmSigner, we restrict the signature + // algorithms. If k is a AlgorithmSigner, presume it supports all + // signature algorithms associated with the key format. If k is not + // an AlgorithmSigner, we can only assume it only supports the + // algorithms that matches the key format. (This means that Sign + // can't pick a different default). + keyFormat := k.PublicKey().Type() + + switch s := k.(type) { + case MultiAlgorithmSigner: + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(s.Algorithms(), underlyingAlgo(algo)) { + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo) + } + } + case AlgorithmSigner: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...) + default: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) + } + } + + if t.sessionID == nil { + msg.KexAlgos = append(msg.KexAlgos, kexStrictServer) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + + // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what + // algorithms the server supports for public key authentication. See RFC + // 8308, Section 2.1. + // + // We also send the strict KEX mode extension algorithm, in order to opt + // into the strict KEX mode. + if firstKeyExchange := t.sessionID == nil; firstKeyExchange { + msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, kexStrictClient) + } + + } + + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.pushPacket(packetCopy); err != nil { + return err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + + return nil +} + +var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase") + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + case msgUserAuthBanner: + if t.userAuthComplete { + return errSendBannerPhase + } + case msgUserAuthSuccess: + t.userAuthComplete = true + } + + if t.writeError != nil { + return t.writeError + } + + if t.sentInitMsg != nil { + // Copy the packet so the writer can reuse the buffer. + cp := make([]byte, len(p)) + copy(cp, p) + t.pendingPackets = append(t.pendingPackets, cp) + return nil + } + + if t.writeBytesLeft > 0 { + t.writeBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if t.writePacketsLeft > 0 { + t.writePacketsLeft-- + } else { + t.requestKeyExchange() + } + + if err := t.pushPacket(p); err != nil { + t.writeError = err + } + + return nil +} + +func (t *handshakeTransport) Close() error { + // Close the connection. This should cause the readLoop goroutine to wake up + // and close t.startKex, which will shut down kexLoop if running. + err := t.conn.Close() + + // Wait for the kexLoop goroutine to complete. + // At that point we know that the readLoop goroutine is complete too, + // because kexLoop itself waits for readLoop to close the startKex channel. + <-t.kexLoopDone + + return err +} + +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: t.sentInitPacket, + } + + clientInit := otherInit + serverInit := t.sentInitMsg + isClient := len(t.hostKeys) == 0 + if isClient { + clientInit, serverInit = serverInit, clientInit + + magics.clientKexInit = t.sentInitPacket + magics.serverKexInit = otherInitPacket + } + + var err error + t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit) + if err != nil { + return err + } + + if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) { + t.strictMode = true + if err := t.conn.setStrictMode(); err != nil { + return err + } + } + + // We don't send FirstKexFollows, but we handle receiving it. + // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[t.algorithms.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, &magics) + } else { + result, err = t.client(kex, &magics) + } + + if err != nil { + return err + } + + firstKeyExchange := t.sessionID == nil + if firstKeyExchange { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { + return err + } + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + + // On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO + // message with the server-sig-algs extension if the client supports it. See + // RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9. + if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") { + supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",") + extInfo := &extInfoMsg{ + NumExtensions: 2, + Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1), + } + extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs")) + extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...) + extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList)) + extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...) + extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com")) + extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...) + extInfo.Payload = appendInt(extInfo.Payload, 1) + extInfo.Payload = append(extInfo.Payload, "0"...) + if err := t.conn.writePacket(Marshal(extInfo)); err != nil { + return err + } + } + + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + if firstKeyExchange { + // Indicates to the transport that the first key exchange is completed + // after receiving SSH_MSG_NEWKEYS. + t.conn.setInitialKEXDone() + } + + return nil +} + +// algorithmSignerWrapper is an AlgorithmSigner that only supports the default +// key format algorithm. +// +// This is technically a violation of the AlgorithmSigner interface, but it +// should be unreachable given where we use this. Anyway, at least it returns an +// error instead of panicing or producing an incorrect signature. +type algorithmSignerWrapper struct { + Signer +} + +func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != underlyingAlgo(a.PublicKey().Type()) { + return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm") + } + return a.Sign(rand, data) +} + +func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner { + for _, k := range hostKeys { + if s, ok := k.(MultiAlgorithmSigner); ok { + if !contains(s.Algorithms(), underlyingAlgo(algo)) { + continue + } + } + + if algo == k.PublicKey().Type() { + return algorithmSignerWrapper{k} + } + + k, ok := k.(AlgorithmSigner) + if !ok { + continue + } + for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) { + if algo == a { + return k + } + } + } + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey) + if hostKey == nil { + return nil, errors.New("ssh: internal error: negotiated unsupported signature type") + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil { + return nil, err + } + + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/tempfork/sshtest/ssh/handshake_test.go b/tempfork/sshtest/ssh/handshake_test.go new file mode 100644 index 000000000..2bc607b64 --- /dev/null +++ b/tempfork/sshtest/ssh/handshake_test.go @@ -0,0 +1,1021 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + listener, err = net.Listen("tcp", "[::1]:0") + if err != nil { + return nil, nil, err + } + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +// noiseTransport inserts ignore messages to check that the read loop +// and the key exchange filters out these messages. +type noiseTransport struct { + keyingTransport +} + +func (t *noiseTransport) writePacket(p []byte) error { + ignore := []byte{msgIgnore} + if err := t.keyingTransport.writePacket(ignore); err != nil { + return err + } + debug := []byte{msgDebug, 1, 2, 3} + if err := t.keyingTransport.writePacket(debug); err != nil { + return err + } + + return t.keyingTransport.writePacket(p) +} + +func addNoiseTransport(t keyingTransport) keyingTransport { + return &noiseTransport{t} +} + +// handshakePair creates two handshakeTransports connected with each +// other. If the noise argument is true, both transports will try to +// confuse the other side by sending ignore and debug messages. +func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + if noise { + trC = addNoiseTransport(trC) + trS = addNoiseTransport(trS) + } + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + if err := server.waitSession(); err != nil { + return nil, nil, fmt.Errorf("server.waitSession: %v", err) + } + if err := client.waitSession(); err != nil { + return nil, nil, fmt.Errorf("client.waitSession: %v", err) + } + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // Let first kex complete normally. + <-checker.called + + clientDone := make(chan int, 0) + gotHalf := make(chan int, 0) + const N = 20 + errorCh := make(chan error, 1) + + go func() { + defer close(clientDone) + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress. We do this twice, so we test + // that the packet buffer is reset correctly. + for i := 0; i < N; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + errorCh <- err + trC.Close() + return + } + if (i % 10) == 5 { + <-gotHalf + // halfway through, we request a key change. + trC.requestKeyExchange() + + // Wait until we can be sure the key + // change has really started before we + // write more. + <-checker.called + } + if (i % 10) == 7 { + // write some packets until the kex + // completes, to test buffering of + // packets. + checker.waitCall <- 1 + } + } + errorCh <- nil + }() + + // Server checks that client messages come in cleanly + i := 0 + for ; i < N; i++ { + p, err := trS.readPacket() + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) + } + if (i % 10) == 5 { + gotHalf <- 1 + } + + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %v, want %v", i, p, want) + } + } + <-clientDone + if err := <-errorCh; err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i != N { + t.Errorf("received %d messages, want 10.", i) + } + + close(checker.called) + if _, ok := <-checker.called; ok { + // If all went well, we registered exactly 2 key changes: one + // that establishes the session, and one that we requested + // additionally. + t.Fatalf("got another host key checks after 2 handshakes") + } +} + +func TestForceFirstKex(t *testing.T) { + // like handshakePair, but must access the keyingTransport. + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + + // This is the disallowed packet: + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + + // Rest of the setup. + trS = newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server := newServerTransport(trS, v, v, serverConf) + + defer client.Close() + defer server.Close() + + // We setup the initial key exchange, but the remote side + // tries to send serviceRequestMsg in cleartext, which is + // disallowed. + + if err := server.waitSession(); err == nil { + t.Errorf("server first kex init should reject unexpected packet") + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &syncChecker{ + called: make(chan int, 10), + waitCall: nil, + } + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + input := make([]byte, 251) + input[0] = msgRequestSuccess + + done := make(chan int, 1) + const numPacket = 5 + go func() { + defer close(done) + j := 0 + for ; j < numPacket; j++ { + if p, err := trS.readPacket(); err != nil { + break + } else if !bytes.Equal(input, p) { + t.Errorf("got packet type %d, want %d", p[0], input[0]) + } + } + + if j != numPacket { + t.Errorf("got %d, want 5 messages", j) + } + }() + + <-checker.called + + for i := 0; i < numPacket; i++ { + p := make([]byte, len(input)) + copy(p, input) + if err := trC.writePacket(p); err != nil { + t.Errorf("writePacket: %v", err) + } + if i == 2 { + // Make sure the kex is in progress. + <-checker.called + } + + } + <-done +} + +type syncChecker struct { + waitCall chan int + called chan int +} + +func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + c.called <- 1 + if c.waitCall != nil { + <-c.waitCall + } + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{ + called: make(chan int, 2), + waitCall: nil, + } + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + + // While we read out the packet, a key change will be + // initiated. + errorCh := make(chan error, 1) + go func() { + _, err := trC.readPacket() + errorCh <- err + }() + + if err := <-errorCh; err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} + +// errorKeyingTransport generates errors after a given number of +// read/write operations. +type errorKeyingTransport struct { + packetConn + readLeft, writeLeft int +} + +func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { + return nil +} + +func (n *errorKeyingTransport) getSessionID() []byte { + return nil +} + +func (n *errorKeyingTransport) writePacket(packet []byte) error { + if n.writeLeft == 0 { + n.Close() + return errors.New("barf") + } + + n.writeLeft-- + return n.packetConn.writePacket(packet) +} + +func (n *errorKeyingTransport) readPacket() ([]byte, error) { + if n.readLeft == 0 { + n.Close() + return nil, errors.New("barf") + } + + n.readLeft-- + return n.packetConn.readPacket() +} + +func (n *errorKeyingTransport) setStrictMode() error { return nil } + +func (n *errorKeyingTransport) setInitialKEXDone() {} + +func TestHandshakeErrorHandlingRead(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, false) + } +} + +func TestHandshakeErrorHandlingWrite(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, false) + } +} + +func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, true) + } +} + +func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, true) + } +} + +// testHandshakeErrorHandlingN runs handshakes, injecting errors. If +// handshakeTransport deadlocks, the go runtime will detect it and +// panic. +func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { + if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" { + t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS) + } + msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) + + a, b := memPipe() + defer a.Close() + defer b.Close() + + key := testSigners["ecdsa"] + serverConf := Config{RekeyThreshold: minRekeyThreshold} + serverConf.SetDefaults() + serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) + serverConn.hostKeys = []Signer{key} + go serverConn.readLoop() + go serverConn.kexLoop() + + clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} + clientConf.SetDefaults() + clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) + clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} + clientConn.hostKeyCallback = InsecureIgnoreHostKey() + go clientConn.readLoop() + go clientConn.kexLoop() + + var wg sync.WaitGroup + + for _, hs := range []packetConn{serverConn, clientConn} { + if !coupled { + wg.Add(2) + go func(c packetConn) { + for i := 0; ; i++ { + str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) + err := c.writePacket(Marshal(&serviceRequestMsg{str})) + if err != nil { + break + } + } + wg.Done() + c.Close() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + } + wg.Done() + }(hs) + } else { + wg.Add(1) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + if err := c.writePacket(msg); err != nil { + break + } + + } + wg.Done() + }(hs) + } + } + wg.Wait() +} + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +} + +func TestHandshakeRekeyDefault(t *testing.T) { + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{"aes128-ctr"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + trC.Close() + + rgb := (1024 + trC.readBytesLeft) >> 30 + wgb := (1024 + trC.writeBytesLeft) >> 30 + + if rgb != 64 { + t.Errorf("got rekey after %dG read, want 64G", rgb) + } + if wgb != 64 { + t.Errorf("got rekey after %dG write, want 64G", wgb) + } +} + +func TestHandshakeAEADCipherNoMAC(t *testing.T) { + for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} { + checker := &syncChecker{ + called: make(chan int, 1), + } + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{cipher}, + MACs: []string{}, + }, + HostKeyCallback: checker.Check, + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + <-checker.called + } +} + +// TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and +// therefore can't do SHA-2 signatures. Ensures the server does not advertise +// support for them in this case. +func TestNoSHA2Support(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]}) + go func() { + _, _, _, err := NewServerConn(c1, serverConf) + if err != nil { + t.Error(err) + } + }() + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + } + + if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerHandshake(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSASHA512}, + } + + if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // ssh-rsa is disabled server side + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + // the client only supports ssh-rsa + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSA}, + } + + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with no common hostkey algorithm") + } +} + +func TestPickIncompatibleHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA) + if signer != nil { + t.Fatal("incompatible signer returned") + } +} + +func TestStrictKEXResetSeqFirstKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + // check that the sequence number counters. We reset after msgNewKeys, but + // then the server immediately writes msgExtInfo, and we close the + // transports so we expect read 2, write 0 on the client and read 1, write 1 + // on the server. + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // Request a key exchange, which should cause the sequence numbers to reset + checker.waitCall <- 1 + trC.requestKeyExchange() + <-checker.called + + // write a packet on the client, and then read it, to verify the key change has actually happened, since + // the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake + // finishing. + dummyPacket := []byte{99} + if err := trS.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if p, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestSeqNumIncrease(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 || + trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXUnexpectedMsg(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + // Check that unexpected messages during the handshake cause failure + _, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true) + if err == nil { + t.Fatal("handshake should fail when there are unexpected messages during the handshake") + } + + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + + // Check that ignore/debug pacekts are still ignored outside of the handshake + if err := trC.writePacket([]byte{msgIgnore}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if err := trC.writePacket([]byte{msgDebug}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + dummyPacket := []byte{99} + if err := trC.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + + if p, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } +} + +func TestStrictKEXMixed(t *testing.T) { + // Test that we still support a mixed connection, where one side sends kex-strict but the other + // side doesn't. + + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe failed: %s", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + trS = addNoiseTransport(trS) + + clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }} + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + + transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version")) + transport.hostKeys = serverConf.hostKeys + transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms + + readOneFailure := make(chan error, 1) + go func() { + if _, err := transport.readOnePacket(true); err != nil { + readOneFailure <- err + } + }() + + // Basically sendKexInit, but without the kex-strict extension algorithm + msg := &kexInitMsg{ + KexAlgos: transport.config.KeyExchanges, + CiphersClientServer: transport.config.Ciphers, + CiphersServerClient: transport.config.Ciphers, + MACsClientServer: transport.config.MACs, + MACsServerClient: transport.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + ServerHostKeyAlgos: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}, + } + packet := Marshal(msg) + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + if err := transport.pushPacket(packetCopy); err != nil { + t.Fatalf("pushPacket: %s", err) + } + transport.sentInitMsg = msg + transport.sentInitPacket = packet + + if err := transport.getWriteError(); err != nil { + t.Fatalf("getWriteError failed: %s", err) + } + var request *pendingKex + select { + case err = <-readOneFailure: + t.Fatalf("server readOnePacket failed: %s", err) + case request = <-transport.startKex: + break + } + + // We expect the following calls to fail if the side which does not support + // kex-strict sends unexpected/ignored packets during the handshake, even if + // the other side does support kex-strict. + + if err := transport.enterKeyExchange(request.otherInit); err != nil { + t.Fatalf("enterKeyExchange failed: %s", err) + } + if err := client.waitSession(); err != nil { + t.Fatalf("client.waitSession: %v", err) + } +} diff --git a/tempfork/sshtest/ssh/kex.go b/tempfork/sshtest/ssh/kex.go new file mode 100644 index 000000000..8a05f7990 --- /dev/null +++ b/tempfork/sshtest/ssh/kex.go @@ -0,0 +1,786 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256" + kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org" + kexAlgoCurve25519SHA256 = "curve25519-sha256" + + // For the following kex only the client half contains a production + // ready implementation. The server half only consists of a minimal + // implementation to satisfy the automated tests. + kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1" + kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. algo is the negotiated algorithm, and may + // be a certificate type. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int + hashFunc crypto.Hash +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + ki, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := group.hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: group.hashFunc, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + ki, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := group.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: group.hashFunc, + }, err +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in + // RFC 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA1, + } + + // This are the groups called diffie-hellman-group14-sha1 and + // diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268, + // and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + group14 := &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA1, + } + kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA256, + } + + // This is the group called diffie-hellman-group16-sha512 in RFC + // 8268 and Oakley Group 16 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA512, + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} + kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{} + kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} + kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} +} + +// curve25519sha256 implements the curve25519-sha256 (formerly known as +// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731. +type curve25519sha256 struct{} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} + +// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and +// diffie-hellman-group-exchange-sha256 key agreement protocols, +// as described in RFC 4419 +type dhGEXSHA struct { + hashFunc crypto.Hash +} + +const ( + dhGroupExchangeMinimumBits = 2048 + dhGroupExchangePreferredBits = 2048 + dhGroupExchangeMaximumBits = 8192 +) + +func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + // Send GexRequest + kexDHGexRequest := kexDHGexRequestMsg{ + MinBits: dhGroupExchangeMinimumBits, + PreferedBits: dhGroupExchangePreferredBits, + MaxBits: dhGroupExchangeMaximumBits, + } + if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil { + return nil, err + } + + // Receive GexGroup + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var msg kexDHGexGroupMsg + if err = Unmarshal(packet, &msg); err != nil { + return nil, err + } + + // reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits + if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits { + return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen()) + } + + // Check if g is safe by verifying that 1 < g < p-1 + pMinusOne := new(big.Int).Sub(msg.P, bigOne) + if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: server provided gex g is not safe") + } + + // Send GexInit + pHalf := new(big.Int).Rsh(msg.P, 1) + x, err := rand.Int(randSource, pHalf) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(msg.G, x, msg.P) + kexDHGexInit := kexDHGexInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil { + return nil, err + } + + // Receive GexReply + packet, err = c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexReply kexDHGexReplyMsg + if err = Unmarshal(packet, &kexDHGexReply); err != nil { + return nil, err + } + + if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P) + + // Check if k is safe by verifying that k > 1 and k < p - 1 + if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: derived k is not safe") + } + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, kexDHGexReply.HostKey) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, msg.P) + writeInt(h, msg.G) + writeInt(h, X) + writeInt(h, kexDHGexReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHGexReply.HostKey, + Signature: kexDHGexReply.Signature, + Hash: gex.hashFunc, + }, nil +} + +// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256. +// +// This is a minimal implementation to satisfy the automated tests. +func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + // Receive GexRequest + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHGexRequest kexDHGexRequestMsg + if err = Unmarshal(packet, &kexDHGexRequest); err != nil { + return + } + + // Send GexGroup + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + g := big.NewInt(2) + + msg := &kexDHGexGroupMsg{ + P: p, + G: g, + } + if err := c.writePacket(Marshal(msg)); err != nil { + return nil, err + } + + // Receive GexInit + packet, err = c.readPacket() + if err != nil { + return + } + var kexDHGexInit kexDHGexInitMsg + if err = Unmarshal(packet, &kexDHGexInit); err != nil { + return + } + + pHalf := new(big.Int).Rsh(p, 1) + + y, err := rand.Int(randSource, pHalf) + if err != nil { + return + } + Y := new(big.Int).Exp(g, y, p) + + pMinusOne := new(big.Int).Sub(p, bigOne) + if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexInit.X, y, p) + + hostKeyBytes := priv.PublicKey().Marshal() + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, p) + writeInt(h, g) + writeInt(h, kexDHGexInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHGexReply := kexDHGexReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHGexReply) + + err = c.writePacket(packet) + + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: gex.hashFunc, + }, err +} diff --git a/tempfork/sshtest/ssh/kex_test.go b/tempfork/sshtest/ssh/kex_test.go new file mode 100644 index 000000000..cb7f66a50 --- /dev/null +++ b/tempfork/sshtest/ssh/kex_test.go @@ -0,0 +1,106 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Key exchange tests. + +import ( + "crypto/rand" + "fmt" + "reflect" + "sync" + "testing" +) + +// Runs multiple key exchanges concurrent to detect potential data races with +// kex obtained from the global kexAlgoMap. +// This test needs to be executed using the race detector in order to detect +// race conditions. +func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + t.Run(name, func(t *testing.T) { + wg := sync.WaitGroup{} + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + a.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + b.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + }() + } + wg.Wait() + }) + } +} + +func BenchmarkKexes(b *testing.B) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + t1, t2 := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + + go func() { + r, e := kex.Client(t1, rand.Reader, &magics) + t1.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(t2, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + t2.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + + if clientRes.err != nil { + panic(fmt.Sprintf("client: %v", clientRes.err)) + } + if serverRes.err != nil { + panic(fmt.Sprintf("server: %v", serverRes.err)) + } + } + }) + } +} diff --git a/tempfork/sshtest/ssh/keys.go b/tempfork/sshtest/ssh/keys.go new file mode 100644 index 000000000..4a3d769d9 --- /dev/null +++ b/tempfork/sshtest/ssh/keys.go @@ -0,0 +1,1626 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" +) + +// Public key algorithms names. These values can appear in PublicKey.Type, +// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner +// arguments. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" + KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com" + + // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not + // public key formats, so they can't appear as a PublicKey.Type. The + // corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. + KeyAlgoRSASHA256 = "rsa-sha2-256" + KeyAlgoRSASHA512 = "rsa-sha2-512" +) + +const ( + // Deprecated: use KeyAlgoRSA. + SigAlgoRSA = KeyAlgoRSA + // Deprecated: use KeyAlgoRSASHA256. + SigAlgoRSASHA2256 = KeyAlgoRSASHA256 + // Deprecated: use KeyAlgoRSASHA512. + SigAlgoRSASHA2512 = KeyAlgoRSASHA512 +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoSKECDSA256: + return parseSKECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case KeyAlgoSKED25519: + return parseSKEd25519(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + cert, err := parseCert(in, certKeyAlgoNames[algo]) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKey parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// MarshalPrivateKey returns a PEM block with the private key serialized in the +// OpenSSH format. +func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) { + return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler) +} + +// PublicKey represents a public key using an unspecified algorithm. +// +// Some PublicKeys provided by this package also implement CryptoPublicKey. +type PublicKey interface { + // Type returns the key format name, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, with the name + // prefix. To unmarshal the returned data, use the ParsePublicKey function. + Marshal() []byte + + // Verify that sig is a signature on the given data using this key. This + // method will hash the data appropriately first. sig.Format is allowed to + // be any signature algorithm compatible with the key type, the caller + // should check if it has more stringent requirements. + Verify(data []byte, sig *Signature) error +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. +// +// Some Signers provided by this package also implement MultiAlgorithmSigner. +type Signer interface { + // PublicKey returns the associated PublicKey. + PublicKey() PublicKey + + // Sign returns a signature for the given data. This method will hash the + // data appropriately first. The signature algorithm is expected to match + // the key format returned by the PublicKey.Type method (and not to be any + // alternative algorithm supported by the key format). + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +// An AlgorithmSigner is a Signer that also supports specifying an algorithm to +// use for signing. +// +// An AlgorithmSigner can't advertise the algorithms it supports, unless it also +// implements MultiAlgorithmSigner, so it should be prepared to be invoked with +// every algorithm supported by the public key format. +type AlgorithmSigner interface { + Signer + + // SignWithAlgorithm is like Signer.Sign, but allows specifying a desired + // signing algorithm. Callers may pass an empty string for the algorithm in + // which case the AlgorithmSigner will use a default algorithm. This default + // doesn't currently control any behavior in this package. + SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) +} + +// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms +// supported by that signer. +type MultiAlgorithmSigner interface { + AlgorithmSigner + + // Algorithms returns the available algorithms in preference order. The list + // must not be empty, and it must not include certificate types. + Algorithms() []string +} + +// NewSignerWithAlgorithms returns a signer restricted to the specified +// algorithms. The algorithms must be set in preference order. The list must not +// be empty, and it must not include certificate types. An error is returned if +// the specified algorithms are incompatible with the public key type. +func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) { + if len(algorithms) == 0 { + return nil, errors.New("ssh: please specify at least one valid signing algorithm") + } + var signerAlgos []string + supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type())) + if s, ok := signer.(*multiAlgorithmSigner); ok { + signerAlgos = s.Algorithms() + } else { + signerAlgos = supportedAlgos + } + + for _, algo := range algorithms { + if !contains(supportedAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q", + algo, signer.PublicKey().Type()) + } + if !contains(signerAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo) + } + } + return &multiAlgorithmSigner{ + AlgorithmSigner: signer, + supportedAlgorithms: algorithms, + }, nil +} + +type multiAlgorithmSigner struct { + AlgorithmSigner + supportedAlgorithms []string +} + +func (s *multiAlgorithmSigner) Algorithms() []string { + return s.supportedAlgorithms +} + +func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool { + if algorithm == "" { + algorithm = underlyingAlgo(s.PublicKey().Type()) + } + for _, algo := range s.supportedAlgorithms { + if algorithm == algo { + return true + } + } + return false +} + +func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if !s.isAlgorithmSupported(algorithm) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms) + } + return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the x/crypto/ssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + supportedAlgos := algorithmsForKeyFormat(r.Type()) + if !contains(supportedAlgos, sig.Format) { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + hash := hashFuncs[sig.Format] + h := hash.New() + h.Write(data) + digest := h.Sum(nil) + + // Signatures in PKCS1v15 must match the key's modulus in + // length. However with SSH, some signers provide RSA + // signatures which are missing the MSB 0's of the bignum + // represented. With ssh-rsa signatures, this is encouraged by + // the spec (even though e.g. OpenSSH will give the full + // length unconditionally). With rsa-sha2-* signatures, the + // verifier is allowed to support these, even though they are + // out of spec. See RFC 4253 Section 6.6 for ssh-rsa and RFC + // 8332 Section 3 for rsa-sha2-* details. + // + // In practice: + // * OpenSSH always allows "short" signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L526 + // but always generates padded signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L439 + // + // * PuTTY versions 0.81 and earlier will generate short + // signatures for all RSA signature variants. Note that + // PuTTY is embedded in other software, such as WinSCP and + // FileZilla. At the time of writing, a patch has been + // applied to PuTTY to generate padded signatures for + // rsa-sha2-*, but not yet released: + // https://git.tartarus.org/?p=simon/putty.git;a=commitdiff;h=a5bcf3d384e1bf15a51a6923c3724cbbee022d8e + // + // * SSH.NET versions 2024.0.0 and earlier will generate short + // signatures for all RSA signature variants, fixed in 2024.1.0: + // https://github.com/sshnet/SSH.NET/releases/tag/2024.1.0 + // + // As a result, we pad these up to the key size by inserting + // leading 0's. + // + // Note that support for short signatures with rsa-sha2-* may + // be removed in the future due to such signatures not being + // allowed by the spec. + blob := sig.Blob + keySize := (*rsa.PublicKey)(r).Size() + if len(blob) < keySize { + padded := make([]byte, keySize) + copy(padded[keySize-len(blob):], blob) + blob = padded + } + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (k *dsaPublicKey) Type() string { + return "ssh-dss" +} + +func checkDSAParams(param *dsa.Parameters) error { + // SSH specifies FIPS 186-2, which only provided a single size + // (1024 bits) DSA key. FIPS 186-3 allows for larger key + // sizes, which would confuse SSH. + if l := param.P.BitLen(); l != 1024 { + return fmt.Errorf("ssh: unsupported DSA key size %d", l) + } + + return nil +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + param := dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + } + if err := checkDSAParams(¶m); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: param, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + return k.SignWithAlgorithm(rand, data, k.PublicKey().Type()) +} + +func (k *dsaPrivateKey) Algorithms() []string { + return []string{k.PublicKey().Type()} +} + +func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != "" && algorithm != k.PublicKey().Type() { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + + h := hashFuncs[k.PublicKey().Type()].New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (k *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + k.nistID() +} + +func (k *ecdsaPublicKey) nistID() string { + switch k.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (k ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + return ed25519PublicKey(w.KeyBytes), w.Rest, nil +} + +func (k ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(k), + } + return Marshal(&w) +} + +func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k); l != ed25519.PublicKeySize { + return fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + + if ok := ed25519.Verify(ed25519.PublicKey(k), b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (k *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + k.Type(), + k.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// skFields holds the additional fields present in U2F/FIDO2 signatures. +// See openssh/PROTOCOL.u2f 'SSH U2F Signatures' for details. +type skFields struct { + // Flags contains U2F/FIDO2 flags such as 'user present' + Flags byte + // Counter is a monotonic signature counter which can be + // used to detect concurrent use of a private key, should + // it be extracted from hardware. + Counter uint32 +} + +type skECDSAPublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ecdsa.PublicKey +} + +func (k *skECDSAPublicKey) Type() string { + return KeyAlgoSKECDSA256 +} + +func (k *skECDSAPublicKey) nistID() string { + return "nistp256" +} + +func parseSKECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(skECDSAPublicKey) + key.application = w.Application + + if w.Curve != "nistp256" { + return nil, nil, errors.New("ssh: unsupported curve") + } + key.Curve = elliptic.P256() + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + + return key, w.Rest, nil +} + +func (k *skECDSAPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + w := struct { + Name string + ID string + Key []byte + Application string + }{ + k.Type(), + k.nistID(), + keyBytes, + k.application, + } + + return Marshal(&w) +} + +func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var ecSig struct { + R *big.Int + S *big.Int + } + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + h.Reset() + h.Write(original) + digest := h.Sum(nil) + + if ecdsa.Verify((*ecdsa.PublicKey)(&k.PublicKey), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey { + return &k.PublicKey +} + +type skEd25519PublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ed25519.PublicKey +} + +func (k *skEd25519PublicKey) Type() string { + return KeyAlgoSKED25519 +} + +func parseSKEd25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + key := new(skEd25519PublicKey) + key.application = w.Application + key.PublicKey = ed25519.PublicKey(w.KeyBytes) + + return key, w.Rest, nil +} + +func (k *skEd25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + Application string + }{ + KeyAlgoSKED25519, + []byte(k.PublicKey), + k.application, + } + return Marshal(&w) +} + +func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k.PublicKey); l != ed25519.PublicKeySize { + return fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var edSig struct { + Signature []byte `ssh:"rest"` + } + + if err := Unmarshal(sig.Blob, &edSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + if ok := ed25519.Verify(k.PublicKey, original, edSig.Signature); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return k.PublicKey +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a +// corresponding Signer instance. ECDSA keys must use P-256, P-384 or +// P-521. DSA keys must use parameter size L1024N160. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return newDSAPrivateKey(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) { + if err := checkDSAParams(&key.PublicKey.Parameters); err != nil { + return nil, err + } + + return &dsaPrivateKey{key}, nil +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. +func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { + pubKey, err := NewPublicKey(signer.Public()) + if err != nil { + return nil, err + } + + return &wrappedSigner{signer, pubKey}, nil +} + +func (s *wrappedSigner) PublicKey() PublicKey { + return s.pubKey +} + +func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.SignWithAlgorithm(rand, data, s.pubKey.Type()) +} + +func (s *wrappedSigner) Algorithms() []string { + return algorithmsForKeyFormat(s.pubKey.Type()) +} + +func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm == "" { + algorithm = s.pubKey.Type() + } + + if !contains(s.Algorithms(), algorithm) { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type()) + } + + hashFunc := hashFuncs[algorithm] + var digest []byte + if hashFunc != 0 { + h := hashFunc.New() + h.Write(data) + digest = h.Sum(nil) + } else { + digest = data + } + + signature, err := s.signer.Sign(rand, digest, hashFunc) + if err != nil { + return nil, err + } + + // crypto.Signer.Sign is expected to return an ASN.1-encoded signature + // for ECDSA and DSA, but that's not the encoding expected by SSH, so + // re-encode. + switch s.pubKey.(type) { + case *ecdsaPublicKey, *dsaPublicKey: + type asn1Signature struct { + R, S *big.Int + } + asn1Sig := new(asn1Signature) + _, err := asn1.Unmarshal(signature, asn1Sig) + if err != nil { + return nil, err + } + + switch s.pubKey.(type) { + case *ecdsaPublicKey: + signature = Marshal(asn1Sig) + + case *dsaPublicKey: + signature = make([]byte, 40) + r := asn1Sig.R.Bytes() + s := asn1Sig.S.Bytes() + copy(signature[20-len(r):20], r) + copy(signature[40-len(s):40], s) + } + } + + return &Signature{ + Format: algorithm, + Blob: signature, + }, nil +} + +// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, +// or ed25519.PublicKey returns a corresponding PublicKey instance. +// ECDSA keys must use P-256, P-384 or P-521. +func NewPublicKey(key interface{}) (PublicKey, error) { + switch key := key.(type) { + case *rsa.PublicKey: + return (*rsaPublicKey)(key), nil + case *ecdsa.PublicKey: + if !supportedEllipticCurve(key.Curve) { + return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported") + } + return (*ecdsaPublicKey)(key), nil + case *dsa.PublicKey: + return (*dsaPublicKey)(key), nil + case ed25519.PublicKey: + if l := len(key); l != ed25519.PublicKeySize { + return nil, fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + return ed25519PublicKey(key), nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. If the private key is encrypted, it +// will return a PassphraseMissingError. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// encryptedBlock tells whether a private key is +// encrypted by examining its Proc-Type header +// for a mention of ENCRYPTED +// according to RFC 1421 Section 4.6.1.1. +func encryptedBlock(block *pem.Block) bool { + return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") +} + +// A PassphraseMissingError indicates that parsing this private key requires a +// passphrase. Use ParsePrivateKeyWithPassphrase. +type PassphraseMissingError struct { + // PublicKey will be set if the private key format includes an unencrypted + // public key along with the encrypted private key. + PublicKey PublicKey +} + +func (*PassphraseMissingError) Error() string { + return "ssh: this private key is passphrase protected" +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports +// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH +// formats. If the private key is encrypted, it will return a PassphraseMissingError. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if encryptedBlock(block) { + return nil, &PassphraseMissingError{} + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + // RFC5208 - https://tools.ietf.org/html/rfc5208 + case "PRIVATE KEY": + return x509.ParsePKCS8PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + case "OPENSSH PRIVATE KEY": + return parseOpenSSHPrivateKey(block.Bytes, unencryptedOpenSSHKey) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Pub *big.Int + Priv *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Pub, + }, + X: k.Priv, + }, nil +} + +func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) { + if kdfName != "none" || cipherName != "none" { + return nil, &PassphraseMissingError{} + } + if kdfOpts != "" { + return nil, errors.New("ssh: invalid openssh private key") + } + return privKeyBlock, nil +} + +func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) { + key := generateOpenSSHPadding(privKeyBlock, 8) + return key, "none", "none", "", nil +} + +const privateKeyAuthMagic = "openssh-key-v1\x00" + +type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error) +type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error) + +type openSSHEncryptedPrivateKey struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte +} + +type openSSHPrivateKey struct { + Check1 uint32 + Check2 uint32 + Keytype string + Rest []byte `ssh:"rest"` +} + +type openSSHRSAPrivateKey struct { + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int + P *big.Int + Q *big.Int + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHEd25519PrivateKey struct { + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHECDSAPrivateKey struct { + Curve string + Pub []byte + D *big.Int + Comment string + Pad []byte `ssh:"rest"` +} + +// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt +// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used +// as the decrypt function to parse an unencrypted private key. See +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key. +func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) { + if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(privateKeyAuthMagic):] + + var w openSSHEncryptedPrivateKey + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + if w.NumKeys != 1 { + // We only support single key files, and so does OpenSSH. + // https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171 + return nil, errors.New("ssh: multi-key files are not supported") + } + + privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock) + if err != nil { + if err, ok := err.(*PassphraseMissingError); ok { + pub, errPub := ParsePublicKey(w.PubKey) + if errPub != nil { + return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub) + } + err.PublicKey = pub + } + return nil, err + } + + var pk1 openSSHPrivateKey + if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 { + if w.CipherName != "none" { + return nil, x509.IncorrectPasswordError + } + return nil, errors.New("ssh: malformed OpenSSH key") + } + + switch pk1.Keytype { + case KeyAlgoRSA: + var key openSSHRSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: key.N, + E: int(key.E.Int64()), + }, + D: key.D, + Primes: []*big.Int{key.P, key.Q}, + } + + if err := pk.Validate(); err != nil { + return nil, err + } + + pk.Precompute() + + return pk, nil + case KeyAlgoED25519: + var key openSSHEd25519PrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if len(key.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, key.Priv) + return &pk, nil + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + var key openSSHECDSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + var curve elliptic.Curve + switch key.Curve { + case "nistp256": + curve = elliptic.P256() + case "nistp384": + curve = elliptic.P384() + case "nistp521": + curve = elliptic.P521() + default: + return nil, errors.New("ssh: unhandled elliptic curve: " + key.Curve) + } + + X, Y := elliptic.Unmarshal(curve, key.Pub) + if X == nil || Y == nil { + return nil, errors.New("ssh: failed to unmarshal public key") + } + + if key.D.Cmp(curve.Params().N) >= 0 { + return nil, errors.New("ssh: scalar is out of range") + } + + x, y := curve.ScalarBaseMult(key.D.Bytes()) + if x.Cmp(X) != 0 || y.Cmp(Y) != 0 { + return nil, errors.New("ssh: public key does not match private key") + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: X, + Y: Y, + }, + D: key.D, + }, nil + default: + return nil, errors.New("ssh: unhandled key type") + } +} + +func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) { + var w openSSHEncryptedPrivateKey + var pk1 openSSHPrivateKey + + // Random check bytes. + var check uint32 + if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil { + return nil, err + } + + pk1.Check1 = check + pk1.Check2 = check + w.NumKeys = 1 + + // Use a []byte directly on ed25519 keys. + if k, ok := key.(*ed25519.PrivateKey); ok { + key = *k + } + + switch k := key.(type) { + case *rsa.PrivateKey: + E := new(big.Int).SetInt64(int64(k.PublicKey.E)) + // Marshal public key: + // E and N are in reversed order in the public and private key. + pubKey := struct { + KeyType string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + E, k.PublicKey.N, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHRSAPrivateKey{ + N: k.PublicKey.N, + E: E, + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comment: comment, + } + pk1.Keytype = KeyAlgoRSA + pk1.Rest = Marshal(key) + case ed25519.PrivateKey: + pub := make([]byte, ed25519.PublicKeySize) + priv := make([]byte, ed25519.PrivateKeySize) + copy(pub, k[32:]) + copy(priv, k) + + // Marshal public key. + pubKey := struct { + KeyType string + Pub []byte + }{ + KeyAlgoED25519, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHEd25519PrivateKey{ + Pub: pub, + Priv: priv, + Comment: comment, + } + pk1.Keytype = KeyAlgoED25519 + pk1.Rest = Marshal(key) + case *ecdsa.PrivateKey: + var curve, keyType string + switch name := k.Curve.Params().Name; name { + case "P-256": + curve = "nistp256" + keyType = KeyAlgoECDSA256 + case "P-384": + curve = "nistp384" + keyType = KeyAlgoECDSA384 + case "P-521": + curve = "nistp521" + keyType = KeyAlgoECDSA521 + default: + return nil, errors.New("ssh: unhandled elliptic curve " + name) + } + + pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y) + + // Marshal public key. + pubKey := struct { + KeyType string + Curve string + Pub []byte + }{ + keyType, curve, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHECDSAPrivateKey{ + Curve: curve, + Pub: pub, + D: k.D, + Comment: comment, + } + pk1.Keytype = keyType + pk1.Rest = Marshal(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", k) + } + + var err error + // Add padding and encrypt the key if necessary. + w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1)) + if err != nil { + return nil, err + } + + b := Marshal(w) + block := &pem.Block{ + Type: "OPENSSH PRIVATE KEY", + Bytes: append([]byte(privateKeyAuthMagic), b...), + } + return block, nil +} + +func checkOpenSSHKeyPadding(pad []byte) error { + for i, b := range pad { + if int(b) != i+1 { + return errors.New("ssh: padding not as expected") + } + } + return nil +} + +func generateOpenSSHPadding(block []byte, blockSize int) []byte { + for i, l := 0, len(block); (l+i)%blockSize != 0; i++ { + block = append(block, byte(i+1)) + } + return block +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. +// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/tempfork/sshtest/ssh/keys_test.go b/tempfork/sshtest/ssh/keys_test.go new file mode 100644 index 000000000..bf1f0be1b --- /dev/null +++ b/tempfork/sshtest/ssh/keys_test.go @@ -0,0 +1,724 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "golang.org/x/crypto/ssh/testdata" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case ed25519PublicKey: + return (ed25519.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestKeySignWithAlgorithmVerify(t *testing.T) { + for k, priv := range testSigners { + if algorithmSigner, ok := priv.(MultiAlgorithmSigner); !ok { + t.Errorf("Signers %q constructed by ssh package should always implement the MultiAlgorithmSigner interface: %T", k, priv) + } else { + pub := priv.PublicKey() + data := []byte("sign me") + + signWithAlgTestCase := func(algorithm string, expectedAlg string) { + sig, err := algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + if sig.Format != expectedAlg { + t.Errorf("signature format did not match requested signature algorithm: %s != %s", sig.Format, expectedAlg) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } + + // Using the empty string as the algorithm name should result in the same signature format as the algorithm-free Sign method. + defaultSig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + signWithAlgTestCase("", defaultSig.Format) + + // RSA keys are the only ones which currently support more than one signing algorithm + if pub.Type() == KeyAlgoRSA { + for _, algorithm := range []string{KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512} { + signWithAlgTestCase(algorithm, algorithm) + } + } + } + } +} + +func TestKeySignWithShortSignature(t *testing.T) { + signer := testSigners["rsa"].(AlgorithmSigner) + pub := signer.PublicKey() + // Note: data obtained by empirically trying until a result + // starting with 0 appeared + tests := []struct { + algorithm string + data []byte + }{ + { + algorithm: KeyAlgoRSA, + data: []byte("sign me92"), + }, + { + algorithm: KeyAlgoRSASHA256, + data: []byte("sign me294"), + }, + { + algorithm: KeyAlgoRSASHA512, + data: []byte("sign me60"), + }, + } + + for _, tt := range tests { + sig, err := signer.SignWithAlgorithm(rand.Reader, tt.data, tt.algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", signer, err) + } + if sig.Blob[0] != 0 { + t.Errorf("%s: Expected signature with a leading 0", tt.algorithm) + } + sig.Blob = sig.Blob[1:] + if err := pub.Verify(tt.data, sig); err != nil { + t.Errorf("publicKey.Verify(%s): %v", tt.algorithm, err) + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +func TestMarshalPrivateKey(t *testing.T) { + tests := []struct { + name string + }{ + {"rsa-openssh-format"}, + {"ed25519"}, + {"p256-openssh-format"}, + {"p384-openssh-format"}, + {"p521-openssh-format"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, ok := testPrivateKeys[tt.name] + if !ok { + t.Fatalf("cannot find key %s", tt.name) + } + + block, err := MarshalPrivateKey(expected, "test@golang.org") + if err != nil { + t.Fatalf("cannot marshal %s: %v", tt.name, err) + } + + key, err := ParseRawPrivateKey(pem.EncodeToMemory(block)) + if err != nil { + t.Fatalf("cannot parse %s: %v", tt.name, err) + } + + if !reflect.DeepEqual(expected, key) { + t.Errorf("unexpected marshaled key %s", tt.name) + } + }) + } +} + +type testAuthResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []testAuthResult) { + rest := authKeys + var values []testAuthResult + for len(rest) > 0 { + var r testAuthResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []testAuthResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []testAuthResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []testAuthResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} + +var knownHostsParseTests = []struct { + input string + err string + + marker string + comment string + hosts []string + rest string +}{ + { + "", + "EOF", + + "", "", nil, "", + }, + { + "# Just a comment", + "EOF", + + "", "", nil, "", + }, + { + " \t ", + "EOF", + + "", "", nil, "", + }, + { + "localhost ssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line", + "", + + "", "comment comment", []string{"localhost"}, "next line", + }, + { + "localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}", + "", + + "marker", "", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd", + "short read", + + "", "", nil, "", + }, +} + +func TestKnownHostsParsing(t *testing.T) { + rsaPub, rsaPubSerialized := getTestKey() + + for i, test := range knownHostsParseTests { + var expectedKey PublicKey + const rsaKeyToken = "{RSAPUB}" + + input := test.input + if strings.Contains(input, rsaKeyToken) { + expectedKey = rsaPub + input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1) + } + + marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input)) + if err != nil { + if len(test.err) == 0 { + t.Errorf("#%d: unexpectedly failed with %q", i, err) + } else if !strings.Contains(err.Error(), test.err) { + t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err) + } + continue + } else if len(test.err) != 0 { + t.Errorf("#%d: succeeded but expected error including %q", i, test.err) + continue + } + + if !reflect.DeepEqual(expectedKey, pubKey) { + t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey) + } + + if marker != test.marker { + t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker) + } + + if comment != test.comment { + t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment) + } + + if !reflect.DeepEqual(test.hosts, hosts) { + t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts) + } + + if rest := string(rest); rest != test.rest { + t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest) + } + } +} + +func TestFingerprintLegacyMD5(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintLegacyMD5(pub) + want := "b7:ef:d3:d5:89:29:52:96:9f:df:47:41:4d:15:37:f4" // ssh-keygen -lf -E md5 rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestFingerprintSHA256(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintSHA256(pub) + want := "SHA256:fi5+D7UmDZDE9Q2sAVvvlpcQSIakN4DERdINgXd2AnE" // ssh-keygen -lf rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestInvalidKeys(t *testing.T) { + keyTypes := []string{ + "RSA PRIVATE KEY", + "PRIVATE KEY", + "EC PRIVATE KEY", + "DSA PRIVATE KEY", + "OPENSSH PRIVATE KEY", + } + + for _, keyType := range keyTypes { + for _, dataLen := range []int{0, 1, 2, 5, 10, 20} { + data := make([]byte, dataLen) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + pem.Encode(&buf, &pem.Block{ + Type: keyType, + Bytes: data, + }) + + // This test is just to ensure that the function + // doesn't panic so the return value is ignored. + ParseRawPrivateKey(buf.Bytes()) + } + } +} + +func TestSKKeys(t *testing.T) { + for _, d := range testdata.SKData { + pk, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + + sigBuf := make([]byte, hex.DecodedLen(len(d.HexSignature))) + if _, err := hex.Decode(sigBuf, d.HexSignature); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + dataBuf := make([]byte, hex.DecodedLen(len(d.HexData))) + if _, err := hex.Decode(dataBuf, d.HexData); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + sig, _, ok := parseSignature(sigBuf) + if !ok { + t.Fatalf("parseSignature(%v) failed", sigBuf) + } + + // Test that good data and signature pass verification + if err := pk.Verify(dataBuf, sig); err != nil { + t.Errorf("%s: PublicKey.Verify(%v, %v) failed: %v", d.Name, dataBuf, sig, err) + } + + // Invalid data being passed in + invalidData := []byte("INVALID DATA") + if err := pk.Verify(invalidData, sig); err == nil { + t.Errorf("%s with invalid data: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, invalidData, sig) + } + + // Change byte in blob to corrup signature + sig.Blob[5] = byte('A') + // Corrupted data being passed in + if err := pk.Verify(dataBuf, sig); err == nil { + t.Errorf("%s with corrupted signature: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, dataBuf, sig) + } + } +} + +func TestNewSignerWithAlgos(t *testing.T) { + algorithSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + _, err := NewSignerWithAlgorithms(algorithSigner, nil) + if err == nil { + t.Error("signer with algos created with no algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoED25519}) + if err == nil { + t.Error("signer with algos created with invalid algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{CertAlgoRSASHA256v01}) + if err == nil { + t.Error("signer with algos created with certificate algorithms") + } + + mas, err := NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Errorf("unable to create signer with valid algorithms: %v", err) + } + + _, err = NewSignerWithAlgorithms(mas, []string{KeyAlgoRSA}) + if err == nil { + t.Error("signer with algos created with restricted algorithms") + } +} + +func TestCryptoPublicKey(t *testing.T) { + for _, priv := range testSigners { + p1 := priv.PublicKey() + key, ok := p1.(CryptoPublicKey) + if !ok { + continue + } + p2, err := NewPublicKey(key.CryptoPublicKey()) + if err != nil { + t.Fatalf("NewPublicKey(CryptoPublicKey) failed for %s, got: %v", p1.Type(), err) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v in NewPublicKey, want %#v", p2, p1) + } + } + for _, d := range testdata.SKData { + p1, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + k1, ok := p1.(CryptoPublicKey) + if !ok { + t.Fatalf("%T does not implement CryptoPublicKey", p1) + } + + var p2 PublicKey + switch pub := k1.CryptoPublicKey().(type) { + case *ecdsa.PublicKey: + p2 = &skECDSAPublicKey{ + application: "ssh:", + PublicKey: *pub, + } + case ed25519.PublicKey: + p2 = &skEd25519PublicKey{ + application: "ssh:", + PublicKey: pub, + } + default: + t.Fatalf("unexpected type %T from CryptoPublicKey()", pub) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v, want %#v", p2, p1) + } + } +} diff --git a/tempfork/sshtest/ssh/mac.go b/tempfork/sshtest/ssh/mac.go new file mode 100644 index 000000000..06a1b2750 --- /dev/null +++ b/tempfork/sshtest/ssh/mac.go @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "hash" +) + +type macMode struct { + keySize int + etm bool + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha2-512": {64, false, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256": {32, false, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, false, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, false, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/tempfork/sshtest/ssh/mempipe_test.go b/tempfork/sshtest/ssh/mempipe_test.go new file mode 100644 index 000000000..f27339c51 --- /dev/null +++ b/tempfork/sshtest/ssh/mempipe_test.go @@ -0,0 +1,124 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + writeCount uint64 + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + t.writeCount++ + return nil +} + +func (t *memTransport) getWriteCount() uint64 { + t.write.Lock() + defer t.write.Unlock() + return t.writeCount +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestMemPipe(t *testing.T) { + a, b := memPipe() + if err := a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if wc := a.(*memTransport).getWriteCount(); wc != 1 { + t.Fatalf("got %v, want 1", wc) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } + if wc := b.(*memTransport).getWriteCount(); wc != 0 { + t.Fatalf("got %v, want 0", wc) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/tempfork/sshtest/ssh/messages.go b/tempfork/sshtest/ssh/messages.go new file mode 100644 index 000000000..b55f86056 --- /dev/null +++ b/tempfork/sshtest/ssh/messages.go @@ -0,0 +1,891 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Hellman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4419, section 5. +const msgKexDHGexGroup = 31 + +type kexDHGexGroupMsg struct { + P *big.Int `sshtype:"31"` + G *big.Int +} + +const msgKexDHGexInit = 32 + +type kexDHGexInitMsg struct { + X *big.Int `sshtype:"32"` +} + +const msgKexDHGexReply = 33 + +type kexDHGexReplyMsg struct { + HostKey []byte `sshtype:"33"` + Y *big.Int + Signature []byte +} + +const msgKexDHGexRequest = 34 + +type kexDHGexRequestMsg struct { + MinBits uint32 `sshtype:"34"` + PreferedBits uint32 + MaxBits uint32 +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 8308, section 2.3 +const msgExtInfo = 7 + +type extInfoMsg struct { + NumExtensions uint32 `sshtype:"7"` + Payload []byte `ssh:"rest"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4252, section 5.1 +const msgUserAuthSuccess = 52 + +// See RFC 4252, section 5.4 +const msgUserAuthBanner = 53 + +type userAuthBannerMsg struct { + Message string `sshtype:"53"` + // unused, but required to allow message parsing + Language string +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + Name string `sshtype:"60"` + Instruction string + Language string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersID uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// See RFC 4462, section 3 +const msgUserAuthGSSAPIResponse = 60 + +type userAuthGSSAPIResponse struct { + SupportMech []byte `sshtype:"60"` +} + +const msgUserAuthGSSAPIToken = 61 + +type userAuthGSSAPIToken struct { + Token []byte `sshtype:"61"` +} + +const msgUserAuthGSSAPIMIC = 66 + +type userAuthGSSAPIMIC struct { + MIC []byte `sshtype:"66"` +} + +// See RFC 4462, section 3.9 +const msgUserAuthGSSAPIErrTok = 64 + +type userAuthGSSAPIErrTok struct { + ErrorToken []byte `sshtype:"64"` +} + +// See RFC 4462, section 3.8 +const msgUserAuthGSSAPIError = 65 + +type userAuthGSSAPIError struct { + MajorStatus uint32 `sshtype:"65"` + MinorStatus uint32 + Message string + LanguageTag string +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPing = 192 + +type pingMsg struct { + Data string `sshtype:"192"` +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPong = 193 + +type pongMsg struct { + Data string `sshtype:"193"` +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgExtInfo: + msg = new(extInfoMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + case msgUserAuthGSSAPIToken: + msg = new(userAuthGSSAPIToken) + case msgUserAuthGSSAPIMIC: + msg = new(userAuthGSSAPIMIC) + case msgUserAuthGSSAPIErrTok: + msg = new(userAuthGSSAPIErrTok) + case msgUserAuthGSSAPIError: + msg = new(userAuthGSSAPIError) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +var packetTypeNames = map[byte]string{ + msgDisconnect: "disconnectMsg", + msgServiceRequest: "serviceRequestMsg", + msgServiceAccept: "serviceAcceptMsg", + msgExtInfo: "extInfoMsg", + msgKexInit: "kexInitMsg", + msgKexDHInit: "kexDHInitMsg", + msgKexDHReply: "kexDHReplyMsg", + msgUserAuthRequest: "userAuthRequestMsg", + msgUserAuthSuccess: "userAuthSuccessMsg", + msgUserAuthFailure: "userAuthFailureMsg", + msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg", + msgGlobalRequest: "globalRequestMsg", + msgRequestSuccess: "globalRequestSuccessMsg", + msgRequestFailure: "globalRequestFailureMsg", + msgChannelOpen: "channelOpenMsg", + msgChannelData: "channelDataMsg", + msgChannelOpenConfirm: "channelOpenConfirmMsg", + msgChannelOpenFailure: "channelOpenFailureMsg", + msgChannelWindowAdjust: "windowAdjustMsg", + msgChannelEOF: "channelEOFMsg", + msgChannelClose: "channelCloseMsg", + msgChannelRequest: "channelRequestMsg", + msgChannelSuccess: "channelRequestSuccessMsg", + msgChannelFailure: "channelRequestFailureMsg", +} diff --git a/tempfork/sshtest/ssh/messages_test.go b/tempfork/sshtest/ssh/messages_test.go new file mode 100644 index 000000000..e79076412 --- /dev/null +++ b/tempfork/sshtest/ssh/messages_test.go @@ -0,0 +1,288 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + +func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func TestUnmarshalShortKexInitPacket(t *testing.T) { + // This used to panic. + // Issue 11348 + packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} + kim := &kexInitMsg{} + if err := Unmarshal(packet, kim); err == nil { + t.Error("truncated packet unmarshaled without error") + } +} + +func TestMarshalMultiTag(t *testing.T) { + var res struct { + A uint32 `sshtype:"1|2"` + } + + good1 := struct { + A uint32 `sshtype:"1"` + }{ + 1, + } + good2 := struct { + A uint32 `sshtype:"2"` + }{ + 1, + } + + if e := Unmarshal(Marshal(good1), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + if e := Unmarshal(Marshal(good2), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + bad1 := struct { + A uint32 `sshtype:"3"` + }{ + 1, + } + if e := Unmarshal(Marshal(bad1), &res); e == nil { + t.Errorf("bad struct unmarshaled without error") + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/tempfork/sshtest/ssh/mux.go b/tempfork/sshtest/ssh/mux.go new file mode 100644 index 000000000..d2d24c635 --- /dev/null +++ b/tempfork/sshtest/ssh/mux.go @@ -0,0 +1,357 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, chanSize), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, chanSize), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + case msgPing: + var msg pingMsg + if err := Unmarshal(packet, &msg); err != nil { + return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err) + } + return m.sendMessage(pongMsg(msg)) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return m.handleUnknownChannelPacket(id, packet) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersID: msg.PeersID, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersID + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersID: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} + +func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + // RFC 4254 section 5.4 says unrecognized channel requests should + // receive a failure response. + case *channelRequestMsg: + if msg.WantReply { + return m.sendMessage(channelRequestFailureMsg{ + PeersID: msg.PeersID, + }) + } + return nil + default: + return fmt.Errorf("ssh: invalid channel %d", id) + } +} diff --git a/tempfork/sshtest/ssh/mux_test.go b/tempfork/sshtest/ssh/mux_test.go new file mode 100644 index 000000000..21f0ac3e3 --- /dev/null +++ b/tempfork/sshtest/ssh/mux_test.go @@ -0,0 +1,839 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Error("no incoming channel") + close(res) + return + } + if newCh.ChannelType() != "chan" { + t.Errorf("got type %q want chan", newCh.ChannelType()) + newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType())) + close(res) + return + } + ch, _, err := newCh.Accept() + if err != nil { + t.Errorf("accept: %v", err) + close(res) + return + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + w := <-res + if w == nil { + t.Fatal("unable to get write channel") + } + + return w, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := io.ReadAll(reader) + if string(c) != magic { + t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := io.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write([]byte(magic)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } +} + +func TestMuxChannelReadUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != nil { + t.Errorf("Write: %v", err) + } + writer.Close() + }() + + writer.remoteWin.waitWriterBlocked() + + buf := make([]byte, 32768) + for { + _, err := reader.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read: %v", err) + } + } +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + + ch, ok := <-server.incomingChannels + if !ok { + t.Error("cannot accept channel") + return + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String()) + return + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxUnknownChannelRequests(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Ignore unknown channel messages that don't want a reply. + err := serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 1, + Request: "keepalive@openssh.com", + WantReply: false, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive, which should get a channel failure message + // in response. + err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 2, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + decoded, err := decode(packet) + if err != nil { + kDone <- fmt.Errorf("decode failed: %w", err) + return + } + + switch msg := decoded.(type) { + case *channelRequestFailureMsg: + if msg.PeersID != 2 { + kDone <- fmt.Errorf("received response to wrong message: %v", msg) + return + + } + default: + kDone <- fmt.Errorf("unexpected channel message: %v", msg) + return + } + + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Wait for the server to send the keepalive message and receive back a + // response. + if err := <-kDone; err != nil { + t.Fatal(err) + } + + // Confirm client hasn't closed. + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. + if err := <-kDone; err != nil { + t.Fatal(err) + } +} + +func TestMuxClosedChannel(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Open the channel. + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelOpen { + kDone <- errors.New("expected chan open") + return + } + + var openMsg channelOpenMsg + if err := Unmarshal(packet, &openMsg); err != nil { + kDone <- fmt.Errorf("unmarshal: %w", err) + return + } + + // Send back the opened channel confirmation. + err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{ + PeersID: openMsg.PeersID, + MyID: 0, + MyWindow: 0, + MaxPacketSize: channelMaxPacket, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Close the channel. + err = serverPipe.writePacket(Marshal(channelCloseMsg{ + PeersID: openMsg.PeersID, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive message on the channel we just closed. + err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: openMsg.PeersID, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Receive the channel closed response. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelClose { + kDone <- errors.New("expected channel close") + return + } + + // Receive the keepalive response failure. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelFailure { + kDone <- errors.New("expected channel failure") + return + } + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Open a channel. + ch, err := client.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + defer ch.Close() + + // Wait for the server to close the channel and send the keepalive. + <-kDone + + // Make sure the channel closed. + if _, ok := <-ch.incomingRequests; ok { + t.Fatalf("channel not closed") + } + + // Confirm client hasn't closed + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. + <-kDone +} + +func TestMuxGlobalRequest(t *testing.T) { + var sawPeek bool + var wg sync.WaitGroup + defer func() { + wg.Wait() + if !sawPeek { + t.Errorf("never saw 'peek' request") + } + }() + + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + wg.Add(1) + go func() { + defer wg.Done() + for r := range serverMux.incomingRequests { + sawPeek = sawPeek || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := io.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + a.SendRequest("hello", false, nil) + wg.Done() + }() + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +func TestMuxChannelWindowDeferredUpdates(t *testing.T) { + s, c, mux := channelPair(t) + cTransport := mux.conn.(*memTransport) + defer s.Close() + defer c.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + data := make([]byte, 1024) + + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write(data) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + cWritesInit := cTransport.getWriteCount() + buf := make([]byte, 1) + for i := 0; i < len(data); i++ { + n, err := c.Read(buf) + if n != len(buf) || err != nil { + t.Fatalf("Read: %v, %v", n, err) + } + } + cWrites := cTransport.getWriteCount() - cWritesInit + // reading 1 KiB should not cause any window updates to be sent, but allow + // for some unexpected writes + if cWrites > 30 { + t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites) + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } + if debugTransport { + t.Error("transport debug switched on") + } +} diff --git a/tempfork/sshtest/ssh/server.go b/tempfork/sshtest/ssh/server.go new file mode 100644 index 000000000..1839ddc6a --- /dev/null +++ b/tempfork/sshtest/ssh/server.go @@ -0,0 +1,933 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a user. +// The Permissions value for a successful authentication attempt is +// available in ServerConn, so it can be used to pass information from +// the user-authentication phase to the application layer. +type Permissions struct { + // CriticalOptions indicate restrictions to the default + // permissions, and are typically used in conjunction with + // user certificates. The standard for SSH certificates + // defines "force-command" (only allow the given command to + // execute) and "source-address" (only allow connections from + // the given address). The SSH package currently only enforces + // the "source-address" critical option. It is up to server + // implementations to enforce other critical options, such as + // "force-command", by checking them after the SSH handshake + // is successful. In general, SSH servers should reject + // connections that specify critical options that are unknown + // or not supported. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Lack of support for an + // extension does not preclude authenticating a user. Common + // extensions are "permit-agent-forwarding", + // "permit-X11-forwarding". The Go SSH library currently does + // not act on any extension, and it is up to server + // implementations to honor them. Extensions can be used to + // pass data from the authentication callbacks to the server + // application layer. + Extensions map[string]string +} + +type GSSAPIWithMICConfig struct { + // AllowLogin, must be set, is called when gssapi-with-mic + // authentication is selected (RFC 4462 section 3). The srcName is from the + // results of the GSS-API authentication. The format is username@DOMAIN. + // GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions. + // This callback is called after the user identity is established with GSSAPI to decide if the user can login with + // which permissions. If the user is allowed to login, it should return a nil error. + AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error) + + // Server must be set. It's the implementation + // of the GSSAPIServer interface. See GSSAPIServer interface for details. + Server GSSAPIServer +} + +// SendAuthBanner implements [ServerPreAuthConn]. +func (s *connection) SendAuthBanner(msg string) error { + return s.transport.writePacket(Marshal(&userAuthBannerMsg{ + Message: msg, + })) +} + +func (*connection) unexportedMethodForFutureProofing() {} + +// ServerPreAuthConn is the interface available on an incoming server +// connection before authentication has completed. +type ServerPreAuthConn interface { + unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB + + ConnMetadata + + // SendAuthBanner sends a banner message to the client. + // It returns an error once the authentication phase has ended. + SendAuthBanner(string) error +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + // PublicKeyAuthAlgorithms specifies the supported client public key + // authentication algorithms. Note that this should not include certificate + // types since those use the underlying algorithm. This list is sent to the + // client if it supports the server-sig-algs extension. Order is irrelevant. + // If unspecified then a default set of algorithms is used. + PublicKeyAuthAlgorithms []string + + hostKeys []Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + // To determine NoClientAuth at runtime, set NoClientAuth to true + // and the optional NoClientAuthCallback to a non-nil value. + NoClientAuth bool + + // NoClientAuthCallback, if non-nil, is called when a user + // attempts to authenticate with auth method "none". + // NoClientAuth must also be set to true for this be used, or + // this func is unused. + NoClientAuthCallback func(ConnMetadata) (*Permissions, error) + + // MaxAuthTries specifies the maximum number of authentication attempts + // permitted per connection. If set to a negative number, the number of + // attempts are unlimited. If set to zero, the number of attempts are limited + // to 6. + MaxAuthTries int + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client + // offers a public key for authentication. It must return a nil error + // if the given public key can be used to authenticate the + // given user. For example, see CertChecker.Authenticate. A + // call to this function does not guarantee that the key + // offered is in fact used to authenticate. To record any data + // depending on the public key, store it inside a + // Permissions.Extensions entry. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // PreAuthConnCallback, if non-nil, is called upon receiving a new connection + // before any authentication has started. The provided ServerPreAuthConn + // can be used at any time before authentication is complete, including + // after this callback has returned. + PreAuthConnCallback func(ServerPreAuthConn) + + // ServerVersion is the version identification string to announce in + // the public handshake. + // If empty, a reasonable default is used. + // Note that RFC 4253 section 4.2 requires that this string start with + // "SSH-2.0-". + ServerVersion string + + // BannerCallback, if present, is called and the return string is sent to + // the client after key exchange completed but before authentication. + BannerCallback func(conn ConnMetadata) string + + // GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used + // when gssapi-with-mic authentication is selected (RFC 4462 section 3). + GSSAPIWithMICConfig *GSSAPIWithMICConfig +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same public key format, it is replaced. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + for i, k := range s.hostKeys { + if k.PublicKey().Type() == key.PublicKey().Type() { + s.hostKeys[i] = key + return + } + } + + s.hostKeys = append(s.hostKeys, key) +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. This is a FIFO cache. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +// maxCachedPubKeys is the number of cache entries we store. +// +// Due to consistent misuse of the PublicKeyCallback API, we have reduced this +// to 1, such that the only key in the cache is the most recently seen one. This +// forces the behavior that the last call to PublicKeyCallback will always be +// with the key that is used for authentication. +const maxCachedPubKeys = 1 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) >= maxCachedPubKeys { + c.keys = c.keys[1:] + } + c.keys = append(c.keys, candidate) +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +// +// The returned error may be of type *ServerAuthError for +// authentication errors. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.MaxAuthTries == 0 { + fullConf.MaxAuthTries = 6 + } + if len(fullConf.PublicKeyAuthAlgorithms) == 0 { + fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos + } else { + for _, algo := range fullConf.PublicKeyAuthAlgorithms { + if !contains(supportedPubKeyAuthAlgos, algo) { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) + } + } + } + // Check if the config contains any unsupported key exchanges + for _, kex := range fullConf.KeyExchanges { + if _, ok := serverForbiddenKexAlgos[kex]; ok { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) + } + } + + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. algo is the negotiate +// algorithm and may be a certificate type. +func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) { + sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo)) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && + config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil || + config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.waitSession(); err != nil { + return nil, err + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func checkSourceAddress(addr net.Addr, sourceAddrs string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + for _, sourceAddr := range strings.Split(sourceAddrs, ",") { + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if allowedIP.Equal(tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection, + sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) { + gssAPIServer := gssapiConfig.Server + defer gssAPIServer.DeleteSecContext() + var srcName string + for { + var ( + outToken []byte + needContinue bool + ) + outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token) + if err != nil { + return err, nil, nil + } + if len(outToken) != 0 { + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: outToken, + })); err != nil { + return nil, nil, err + } + } + if !needContinue { + break + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{} + if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil { + return nil, nil, err + } + mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method) + if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil { + return err, nil, nil + } + perms, authErr = gssapiConfig.AllowLogin(s, srcName) + return authErr, perms, nil +} + +// isAlgoCompatible checks if the signature format is compatible with the +// selected algorithm taking into account edge cases that occur with old +// clients. +func isAlgoCompatible(algo, sigFormat string) bool { + // Compatibility for old clients. + // + // For certificate authentication with OpenSSH 7.2-7.7 signature format can + // be rsa-sha2-256 or rsa-sha2-512 for the algorithm + // ssh-rsa-cert-v01@openssh.com. + // + // With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512 + // for signature format ssh-rsa. + if isRSA(algo) && isRSA(sigFormat) { + return true + } + // Standard case: the underlying algorithm must match the signature format. + return underlyingAlgo(algo) == sigFormat +} + +// ServerAuthError represents server authentication errors and is +// sometimes returned by NewServerConn. It appends any authentication +// errors that may occur, and is returned if all of the authentication +// methods provided by the user failed to authenticate. +type ServerAuthError struct { + // Errors contains authentication errors returned by the authentication + // callback methods. The first entry is typically ErrNoAuth. + Errors []error +} + +func (l ServerAuthError) Error() string { + var errs []string + for _, err := range l.Errors { + errs = append(errs, err.Error()) + } + return "[" + strings.Join(errs, ", ") + "]" +} + +// ServerAuthCallbacks defines server-side authentication callbacks. +type ServerAuthCallbacks struct { + // PasswordCallback behaves like [ServerConfig.PasswordCallback]. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback]. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback]. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig]. + GSSAPIWithMICConfig *GSSAPIWithMICConfig +} + +// PartialSuccessError can be returned by any of the [ServerConfig] +// authentication callbacks to indicate to the client that authentication has +// partially succeeded, but further steps are required. +type PartialSuccessError struct { + // Next defines the authentication callbacks to apply to further steps. The + // available methods communicated to the client are based on the non-nil + // ServerAuthCallbacks fields. + Next ServerAuthCallbacks +} + +func (p *PartialSuccessError) Error() string { + return "ssh: authenticated with partial success" +} + +// ErrNoAuth is the error value returned if no +// authentication method has been passed yet. This happens as a normal +// part of the authentication loop, since the client first tries +// 'none' authentication to discover available methods. +// It is returned in ServerAuthError.Errors from NewServerConn. +var ErrNoAuth = errors.New("ssh: no auth passed yet") + +// BannerError is an error that can be returned by authentication handlers in +// ServerConfig to send a banner message to the client. +type BannerError struct { + Err error + Message string +} + +func (b *BannerError) Unwrap() error { + return b.Err +} + +func (b *BannerError) Error() string { + if b.Err == nil { + return b.Message + } + return b.Err.Error() +} + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + if config.PreAuthConnCallback != nil { + config.PreAuthConnCallback(s) + } + + sessionID := s.transport.getSessionID() + var cache pubKeyCache + var perms *Permissions + + authFailures := 0 + noneAuthCount := 0 + var authErrs []error + var calledBannerCallback bool + partialSuccessReturned := false + // Set the initial authentication callbacks from the config. They can be + // changed if a PartialSuccessError is returned. + authConfig := ServerAuthCallbacks{ + PasswordCallback: config.PasswordCallback, + PublicKeyCallback: config.PublicKeyCallback, + KeyboardInteractiveCallback: config.KeyboardInteractiveCallback, + GSSAPIWithMICConfig: config.GSSAPIWithMICConfig, + } + +userAuthLoop: + for { + if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 { + discMsg := &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + } + + if err := s.transport.writePacket(Marshal(discMsg)); err != nil { + return nil, err + } + authErrs = append(authErrs, discMsg) + return nil, &ServerAuthError{Errors: authErrs} + } + + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + if err == io.EOF { + return nil, &ServerAuthError{Errors: authErrs} + } + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + if s.user != userAuthReq.User && partialSuccessReturned { + return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q", + s.user, userAuthReq.User) + } + + s.user = userAuthReq.User + + if !calledBannerCallback && config.BannerCallback != nil { + calledBannerCallback = true + if msg := config.BannerCallback(s); msg != "" { + if err := s.SendAuthBanner(msg); err != nil { + return nil, err + } + } + } + + perms = nil + authErr := ErrNoAuth + + switch userAuthReq.Method { + case "none": + noneAuthCount++ + // We don't allow none authentication after a partial success + // response. + if config.NoClientAuth && !partialSuccessReturned { + if config.NoClientAuthCallback != nil { + perms, authErr = config.NoClientAuthCallback(s) + } else { + authErr = nil + } + } + case "password": + if authConfig.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = authConfig.PasswordCallback(s, password) + case "keyboard-interactive": + if authConfig.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configured") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if authConfig.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey) + _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + + if (candidate.result == nil || isPartialSuccessError) && + candidate.perms != nil && + candidate.perms.CriticalOptions != nil && + candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + if err := checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil { + candidate.result = err + } + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + if candidate.result == nil || isPartialSuccessError { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the declared public key algo is compatible with the + // decoded one. This check will ensure we don't accept e.g. + // ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public + // key type. The algorithm and public key type must be + // consistent: both must be certificate algorithms, or neither. + if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) { + authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q", + pubKey.Type(), algo) + break + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !contains(config.PublicKeyAuthAlgorithms, sig.Format) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format) + break + } + if !isAlgoCompatible(algo, sig.Format) { + authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo) + break + } + + signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + case "gssapi-with-mic": + if authConfig.GSSAPIWithMICConfig == nil { + authErr = errors.New("ssh: gssapi-with-mic auth not configured") + break + } + gssapiConfig := authConfig.GSSAPIWithMICConfig + userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload) + if err != nil { + return nil, parseError(msgUserAuthRequest) + } + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication. + if userAuthRequestGSSAPI.N == 0 { + authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported") + break + } + var i uint32 + present := false + for i = 0; i < userAuthRequestGSSAPI.N; i++ { + if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) { + present = true + break + } + } + if !present { + authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism") + break + } + // Initial server response, see RFC 4462 section 3.3. + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{ + SupportMech: krb5OID, + })); err != nil { + return nil, err + } + // Exchange token, see RFC 4462 section 3.4. + packet, err := s.transport.readPacket() + if err != nil { + return nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, err + } + authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID, + userAuthReq) + if err != nil { + return nil, err + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + authErrs = append(authErrs, authErr) + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + var bannerErr *BannerError + if errors.As(authErr, &bannerErr) { + if bannerErr.Message != "" { + if err := s.SendAuthBanner(bannerErr.Message); err != nil { + return nil, err + } + } + } + + if authErr == nil { + break userAuthLoop + } + + var failureMsg userAuthFailureMsg + + if partialSuccess, ok := authErr.(*PartialSuccessError); ok { + // After a partial success error we don't allow changing the user + // name and execute the NoClientAuthCallback. + partialSuccessReturned = true + + // In case a partial success is returned, the server may send + // a new set of authentication methods. + authConfig = partialSuccess.Next + + // Reset pubkey cache, as the new PublicKeyCallback might + // accept a different set of public keys. + cache = pubKeyCache{} + + // Send back a partial success message to the user. + failureMsg.PartialSuccess = true + } else { + // Allow initial attempt of 'none' without penalty. + if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 { + authFailures++ + } + if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries { + // If we have hit the max attempts, don't bother sending the + // final SSH_MSG_USERAUTH_FAILURE message, since there are + // no more authentication methods which can be attempted, + // and this message may cause the client to re-attempt + // authentication while we send the disconnect message. + // Continue, and trigger the disconnect at the start of + // the loop. + // + // The SSH specification is somewhat confusing about this, + // RFC 4252 Section 5.1 requires each authentication failure + // be responded to with a respective SSH_MSG_USERAUTH_FAILURE + // message, but Section 4 says the server should disconnect + // after some number of attempts, but it isn't explicit which + // message should take precedence (i.e. should there be a failure + // message than a disconnect message, or if we are going to + // disconnect, should we only send that message.) + // + // Either way, OpenSSH disconnects immediately after the last + // failed authentication attempt, and given they are typically + // considered the golden implementation it seems reasonable + // to match that behavior. + continue + } + } + + if authConfig.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if authConfig.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if authConfig.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil && + authConfig.GSSAPIWithMICConfig.AllowLogin != nil { + failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods available") + } + + if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Name: name, + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/tempfork/sshtest/ssh/server_multi_auth_test.go b/tempfork/sshtest/ssh/server_multi_auth_test.go new file mode 100644 index 000000000..3b3980243 --- /dev/null +++ b/tempfork/sshtest/ssh/server_multi_auth_test.go @@ -0,0 +1,412 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "strings" + "testing" +) + +func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err == nil { + c.Close() + } + return serverAuthErrors, err +} + +func TestMultiStepAuth(t *testing.T) { + // This user can login with password, public key or public key + password. + username := "testuser" + // This user can login with public key + password only. + usernameSecondFactor := "testuser_second_factor" + errPwdAuthFailed := errors.New("password auth failed") + errWrongSequence := errors.New("wrong sequence") + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == usernameSecondFactor { + return nil, errWrongSequence + } + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == usernameSecondFactor { + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + }, + } + } + return nil, nil + } + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + } + + clientConfig := &ClientConfig{ + User: usernameSecondFactor, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1]) + } + // Now test a wrong sequence. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong sequence must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + // - partial success + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2]) + } + // Now test using a correct sequence but a wrong password before the right + // one. + n := 0 + passwords := []string{"WRONG", "WRONG", clientPassword} + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 3), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - wrong password + // - wrong password + // - nil + if len(serverAuthErrors) != 5 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + if serverAuthErrors[3] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + // Only password authentication should fail. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with password only must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + + // Only public key authentication should fail. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with public key only must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // Public key and wrong password. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password("WRONG"), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong password after public key must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - password auth failed + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + + // Public key, public key again and then correct password. Public key + // authentication is attempted only once because the partial success error + // returns only "password" as the allowed authentication method. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // The unrestricted username can do anything + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } +} + +func TestDynamicAuthCallbacks(t *testing.T) { + user1 := "user1" + user2 := "user2" + errInvalidCredentials := errors.New("invalid credentials") + + serverConfig := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + switch conn.User() { + case user1: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == user1 && string(password) == clientPassword { + return nil, nil + } + return nil, errInvalidCredentials + }, + }, + } + case user2: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == user2 { + return nil, nil + } + } + return nil, errInvalidCredentials + }, + }, + } + default: + return nil, errInvalidCredentials + } + }, + } + + clientConfig := &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // user1 cannot login with public key + clientConfig = &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user1 login with public key must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + // user2 cannot login with password + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user2 login with password must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } +} diff --git a/tempfork/sshtest/ssh/server_test.go b/tempfork/sshtest/ssh/server_test.go new file mode 100644 index 000000000..c2b24f47c --- /dev/null +++ b/tempfork/sshtest/ssh/server_test.go @@ -0,0 +1,478 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { + for _, tt := range []struct { + name string + key Signer + wantError bool + }{ + {"rsa", testSigners["rsa"], false}, + {"dsa", testSigners["dsa"], true}, + {"ed25519", testSigners["ed25519"], true}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(tt.key), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if !tt.wantError { + t.Errorf("%s: got unexpected error %q", tt.name, err.Error()) + } + } else if tt.wantError { + t.Errorf("%s: succeeded, but want error", tt.name) + } + <-done + } +} + +func TestMaxAuthTriesNoneMethod(t *testing.T) { + username := "testuser" + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errors.New("invalid credentials") + }, + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + go newServer(c1, serverConfig) + + clientConfig := ClientConfig{ + User: username, + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.SetDefaults() + // Our client will send 'none' auth only once, so we need to send the + // requests manually. + c := &connection{ + sshConn: sshConn{ + conn: c2, + user: username, + clientVersion: []byte(packageVersion), + }, + } + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + t.Fatalf("unable to exchange version: %v", err) + } + c.transport = newClientTransport( + newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */), + c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + t.Fatalf("unable to wait session: %v", err) + } + c.sessionID = c.transport.getSessionID() + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + t.Fatalf("unable to send ssh-userauth message: %v", err) + } + packet, err := c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) > 0 && packet[0] == msgExtInfo { + packet, err = c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + t.Fatal(err) + } + for i := 0; i <= serverConfig.MaxAuthTries; i++ { + auth := new(noneAuth) + _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil) + if i < serverConfig.MaxAuthTries { + if err != nil { + t.Fatal(err) + } + continue + } + if err == nil { + t.Fatal("client: got no error") + } else if !strings.Contains(err.Error(), "too many authentication failures") { + t.Fatalf("client: got unexpected error: %v", err) + } + } + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + for _, err := range serverAuthErrors { + if !errors.Is(err, ErrNoAuth) { + t.Errorf("go error: %v; want: %v", err, ErrNoAuth) + } + } +} + +func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) { + username := "testuser" + serverConfig := &ServerConfig{ + MaxAuthTries: 1, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errors.New("invalid credentials") + }, + } + clientConfig := &ClientConfig{ + User: username, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if !errors.Is(serverAuthErrors[0], ErrNoAuth) { + t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth) + } + if serverAuthErrors[1] != nil { + t.Errorf("unexpected error: %v", serverAuthErrors[1]) + } +} + +func TestNewServerConnValidationErrors(t *testing.T) { + serverConf := &ServerConfig{ + PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01}, + } + c := &markerConn{} + _, _, _, err := NewServerConn(c, serverConf) + if err == nil { + t.Fatal("NewServerConn with invalid public key auth algorithms succeeded") + } + if !c.isClosed() { + t.Fatal("NewServerConn with invalid public key auth algorithms left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with invalid public key auth algorithms used connection") + } + + serverConf = &ServerConfig{ + Config: Config{ + KeyExchanges: []string{kexAlgoDHGEXSHA256}, + }, + } + c = &markerConn{} + _, _, _, err = NewServerConn(c, serverConf) + if err == nil { + t.Fatal("NewServerConn with unsupported key exchange succeeded") + } + if !c.isClosed() { + t.Fatal("NewServerConn with unsupported key exchange left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with unsupported key exchange used connection") + } +} + +func TestBannerError(t *testing.T) { + serverConfig := &ServerConfig{ + BannerCallback: func(ConnMetadata) string { + return "banner from BannerCallback" + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + err := &BannerError{ + Err: errors.New("error from NoClientAuthCallback"), + Message: "banner from NoClientAuthCallback", + } + return nil, fmt.Errorf("wrapped: %w", err) + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, &BannerError{ + Err: errors.New("error from PublicKeyCallback"), + Message: "banner from PublicKeyCallback", + } + }, + KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { + return nil, &BannerError{ + Err: nil, // make sure that a nil inner error is allowed + Message: "banner from KeyboardInteractiveCallback", + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{"letmein"}, nil + }), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "banner from BannerCallback", + "banner from NoClientAuthCallback", + "banner from PublicKeyCallback", + "banner from KeyboardInteractiveCallback", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } +} + +func TestPublicKeyCallbackLastSeen(t *testing.T) { + var lastSeenKey PublicKey + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + lastSeenKey = key + fmt.Printf("seen %#v\n", key) + if _, ok := key.(*dsaPublicKey); !ok { + return nil, errors.New("nope") + } + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + <-done + + expectedPublicKey := testSigners["dsa"].PublicKey().Marshal() + lastSeenMarshalled := lastSeenKey.Marshal() + if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) { + t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey()) + } +} + +func TestPreAuthConnAndBanners(t *testing.T) { + testDone := make(chan struct{}) + defer close(testDone) + + authConnc := make(chan ServerPreAuthConn, 1) + serverConfig := &ServerConfig{ + PreAuthConnCallback: func(c ServerPreAuthConn) { + t.Logf("got ServerPreAuthConn: %v", c) + authConnc <- c // for use later in the test + for _, s := range []string{"hello1", "hello2"} { + if err := c.SendAuthBanner(s); err != nil { + t.Errorf("failed to send banner %q: %v", s, err) + } + } + // Now start a goroutine to spam SendAuthBanner in hopes + // of hitting a race. + go func() { + for { + select { + case <-testDone: + return + default: + if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase { + t.Errorf("unexpected error from SendAuthBanner: %v", err) + } + time.Sleep(5 * time.Millisecond) + } + } + }() + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + t.Logf("got NoClientAuthCallback") + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + if msg != "attempted-race" { + banners = append(banners, msg) + } + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "hello1", + "hello2", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } + + // Now that we're authenticated, verify that use of SendBanner + // is an error. + var bc ServerPreAuthConn + select { + case bc = <-authConnc: + default: + t.Fatal("expected ServerPreAuthConn") + } + if err := bc.SendAuthBanner("wrong-phase"); err == nil { + t.Error("unexpected success of SendAuthBanner after authentication") + } else if err != errSendBannerPhase { + t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase) + } +} + +type markerConn struct { + closed uint32 + used uint32 +} + +func (c *markerConn) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *markerConn) isUsed() bool { + return atomic.LoadUint32(&c.used) != 0 +} + +func (c *markerConn) Close() error { + atomic.StoreUint32(&c.closed, 1) + return nil +} + +func (c *markerConn) Read(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.EOF + } +} + +func (c *markerConn) Write(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.ErrClosedPipe + } +} + +func (*markerConn) LocalAddr() net.Addr { return nil } +func (*markerConn) RemoteAddr() net.Addr { return nil } + +func (*markerConn) SetDeadline(t time.Time) error { return nil } +func (*markerConn) SetReadDeadline(t time.Time) error { return nil } +func (*markerConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/tempfork/sshtest/ssh/session.go b/tempfork/sshtest/ssh/session.go new file mode 100644 index 000000000..acef62259 --- /dev/null +++ b/tempfork/sshtest/ssh/session.go @@ -0,0 +1,647 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + IUTF8 = 42 // RFC 8160 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of io.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.7. +type ptyWindowChangeMsg struct { + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 +} + +// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. +func (s *Session) WindowChange(h, w int) error { + req := ptyWindowChangeMsg{ + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + } + _, err := s.ch.SendRequest("window-change", false, Marshal(&req)) + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. +type ExitMissingError struct{} + +func (e *ExitMissingError) Error() string { + return "wait: remote command exited without exit status or exit signal" +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = io.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = io.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/tempfork/sshtest/ssh/session_test.go b/tempfork/sshtest/ssh/session_test.go new file mode 100644 index 000000000..807a913e5 --- /dev/null +++ b/tempfork/sshtest/ssh/session_test.go @@ -0,0 +1,892 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "errors" + "io" + "math/rand" + "net" + "sync" + "testing" + + "golang.org/x/crypto/ssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer func() { + c1.Close() + wg.Done() + }() + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(testSigners["rsa"]) + + conn, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Errorf("Unable to handshake: %v", err) + return + } + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue + } + + ch, inReqs, err := newCh.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + continue + } + wg.Add(1) + go func() { + handler(ch, inReqs, t) + wg.Done() + }() + } + if err := conn.Wait(); err != io.EOF { + t.Logf("server exit reason: %v", err) + } + }() + + config := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, chans, reqs, err := NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("unable to dial remote side: %v", err) + } + + return NewClient(conn, chans, reqs) +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// Test that a simple string is returned via the Output helper, +// and that stderr is discarded. +func TestSessionOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.Output("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + w := "this-is-stdout." + g := string(buf) + if g != w { + t.Error("Remote command did not return expected string:") + t.Logf("want %q", w) + t.Logf("got %q", g) + } +} + +// Test that both stdout and stderr are returned +// via the CombinedOutput helper. +func TestSessionCombinedOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + const stdout = "this-is-stdout." + const stderr = "this-is-stderr." + g := string(buf) + if g != stdout+stderr && g != stderr+stdout { + t.Error("Remote command did not return expected string:") + t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) + t.Logf("got %q", g) + } +} + +// Test non-0 exit status is returned correctly. +func TestExitStatusNonZero(t *testing.T) { + conn := dial(exitStatusNonZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) + } +} + +// Test 0 exit status is returned correctly. +func TestExitStatusZero(t *testing.T) { + conn := dial(exitStatusZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +// Test exit signal and status are both returned correctly. +func TestExitSignalAndStatus(t *testing.T) { + conn := dial(exitSignalAndStatusHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestKnownExitSignalOnly(t *testing.T) { + conn := dial(exitSignalHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 143 { + t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestUnknownExitSignal(t *testing.T) { + conn := dial(exitSignalUnknownHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "SYS" || e.ExitStatus() != 128 { + t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +func TestExitWithoutStatusOrSignal(t *testing.T) { + conn := dial(exitWithoutSignalOrStatus, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + if _, ok := err.(*ExitMissingError); !ok { + t.Fatalf("got %T want *ExitMissingError", err) + } +} + +// windowTestBytes is the number of bytes that we'll send to the SSH server. +const windowTestBytes = 16000 * 200 + +// TestServerWindow writes random data to the server. The server is expected to echo +// the same data back, which is compared against the original. +func TestServerWindow(t *testing.T) { + origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) + origBytes := origBuf.Bytes() + + conn := dial(echoHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + + result := make(chan []byte) + go func() { + defer close(result) + echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + serverStdout, err := session.StdoutPipe() + if err != nil { + t.Errorf("StdoutPipe failed: %v", err) + return + } + n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) + if err != nil && err != io.EOF { + t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) + } + result <- echoedBuf.Bytes() + }() + + written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) + if err != nil { + t.Errorf("failed to copy origBuf to serverStdin: %v", err) + } else if written != windowTestBytes { + t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes) + } + + echoedBytes := <-result + + if !bytes.Equal(origBytes, echoedBytes) { + t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) + } +} + +// Verify the client can handle a keepalive packet from the server. +func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(io.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := io.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := io.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: "user", + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + srvErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewServerConn(c1, serverConf) + srvErrCh <- err + if err != nil { + return + } + serverID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + cliErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + cliErrCh <- err + if err != nil { + return + } + clientID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + if err := <-srvErrCh; err != nil { + t.Fatalf("server handshake: %v", err) + } + + if err := <-cliErrCh; err != nil { + t.Fatalf("client handshake: %v", err) + } + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} + +type noReadConn struct { + readSeen bool + net.Conn +} + +func (c *noReadConn) Close() error { + return nil +} + +func (c *noReadConn) Read(b []byte) (int, error) { + c.readSeen = true + return 0, errors.New("noReadConn error") +} + +func TestInvalidServerConfiguration(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serveConn := noReadConn{Conn: c1} + serverConf := &ServerConfig{} + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") + } + + serverConf.AddHostKey(testSigners["ecdsa"]) + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") + } +} + +func TestHostKeyAlgorithms(t *testing.T) { + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.AddHostKey(testSigners["ecdsa"]) + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + connect := func(clientConf *ClientConfig, want string) { + var alg string + clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { + alg = key.Type() + return nil + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + if alg != want { + t.Errorf("selected key algorithm %s, want %s", alg, want) + } + } + + // By default, we get the preferred algorithm, which is ECDSA 256. + + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + } + connect(clientConf, KeyAlgoECDSA256) + + // Client asks for RSA explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} + connect(clientConf, KeyAlgoRSA) + + // Client asks for RSA-SHA2-512 explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512} + // We get back an "ssh-rsa" key but the verification happened + // with an RSA-SHA2-512 signature. + connect(clientConf, KeyAlgoRSA) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with unknown hostkey algorithm") + } +} + +func TestServerClientAuthCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + userCh := make(chan string, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + userCh <- conn.User() + return nil, nil + }, + } + const someUsername = "some-username" + + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: someUsername, + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Errorf("server handshake: %v", err) + userCh <- "error" + return + } + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + conn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + return + } + conn.Close() + + got := <-userCh + if got != someUsername { + t.Errorf("username = %q; want %q", got, someUsername) + } +} diff --git a/tempfork/sshtest/ssh/ssh_gss.go b/tempfork/sshtest/ssh/ssh_gss.go new file mode 100644 index 000000000..24bd7c8e8 --- /dev/null +++ b/tempfork/sshtest/ssh/ssh_gss.go @@ -0,0 +1,139 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/asn1" + "errors" +) + +var krb5OID []byte + +func init() { + krb5OID, _ = asn1.Marshal(krb5Mesh) +} + +// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins. +type GSSAPIClient interface { + // InitSecContext initiates the establishment of a security context for GSS-API between the + // ssh client and ssh server. Initially the token parameter should be specified as nil. + // The routine may return a outputToken which should be transferred to + // the ssh server, where the ssh server will present it to + // AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting + // needContinue to false. To complete the context + // establishment, one or more reply tokens may be required from the ssh + // server;if so, InitSecContext will return a needContinue which is true. + // In this case, InitSecContext should be called again when the + // reply token is received from the ssh server, passing the reply + // token to InitSecContext via the token parameters. + // See RFC 2743 section 2.2.1 and RFC 4462 section 3.4. + InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) + // GetMIC generates a cryptographic MIC for the SSH2 message, and places + // the MIC in a token for transfer to the ssh server. + // The contents of the MIC field are obtained by calling GSS_GetMIC() + // over the following, using the GSS-API context that was just + // established: + // string session identifier + // byte SSH_MSG_USERAUTH_REQUEST + // string user name + // string service + // string "gssapi-with-mic" + // See RFC 2743 section 2.3.1 and RFC 4462 3.5. + GetMIC(micFiled []byte) ([]byte, error) + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins. +type GSSAPIServer interface { + // AcceptSecContext allows a remotely initiated security context between the application + // and a remote peer to be established by the ssh client. The routine may return a + // outputToken which should be transferred to the ssh client, + // where the ssh client will present it to InitSecContext. + // If no token need be sent, AcceptSecContext will indicate this + // by setting the needContinue to false. To + // complete the context establishment, one or more reply tokens may be + // required from the ssh client. if so, AcceptSecContext + // will return a needContinue which is true, in which case it + // should be called again when the reply token is received from the ssh + // client, passing the token to AcceptSecContext via the + // token parameters. + // The srcName return value is the authenticated username. + // See RFC 2743 section 2.2.2 and RFC 4462 section 3.4. + AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) + // VerifyMIC verifies that a cryptographic MIC, contained in the token parameter, + // fits the supplied message is received from the ssh client. + // See RFC 2743 section 2.3.2. + VerifyMIC(micField []byte, micToken []byte) error + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +var ( + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication, + // so we also support the krb5 mechanism only. + // See RFC 1964 section 1. + krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2} +) + +// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST +// See RFC 4462 section 3.2. +type userAuthRequestGSSAPI struct { + N uint32 + OIDS []asn1.ObjectIdentifier +} + +func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { + n, rest, ok := parseUint32(payload) + if !ok { + return nil, errors.New("parse uint32 failed") + } + s := &userAuthRequestGSSAPI{ + N: n, + OIDS: make([]asn1.ObjectIdentifier, n), + } + for i := 0; i < int(n); i++ { + var ( + desiredMech []byte + err error + ) + desiredMech, rest, ok = parseString(rest) + if !ok { + return nil, errors.New("parse string failed") + } + if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil { + return nil, err + } + + } + return s, nil +} + +// See RFC 4462 section 3.6. +func buildMIC(sessionID string, username string, service string, authMethod string) []byte { + out := make([]byte, 0, 0) + out = appendString(out, sessionID) + out = append(out, msgUserAuthRequest) + out = appendString(out, username) + out = appendString(out, service) + out = appendString(out, authMethod) + return out +} diff --git a/tempfork/sshtest/ssh/ssh_gss_test.go b/tempfork/sshtest/ssh/ssh_gss_test.go new file mode 100644 index 000000000..39a111288 --- /dev/null +++ b/tempfork/sshtest/ssh/ssh_gss_test.go @@ -0,0 +1,109 @@ +package ssh + +import ( + "fmt" + "testing" +) + +func TestParseGSSAPIPayload(t *testing.T) { + payload := []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x06, 0x09, + 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02} + res, err := parseGSSAPIPayload(payload) + if err != nil { + t.Fatal(err) + } + if ok := res.OIDS[0].Equal(krb5Mesh); !ok { + t.Fatalf("got %v, want %v", res, krb5Mesh) + } +} + +func TestBuildMIC(t *testing.T) { + sessionID := []byte{134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, + 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121} + username := "testuser" + service := "ssh-connection" + authMethod := "gssapi-with-mic" + expected := []byte{0, 0, 0, 32, 134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121, 50, 0, 0, 0, 8, 116, 101, 115, 116, 117, 115, 101, 114, 0, 0, 0, 14, 115, 115, 104, 45, 99, 111, 110, 110, 101, 99, 116, 105, 111, 110, 0, 0, 0, 15, 103, 115, 115, 97, 112, 105, 45, 119, 105, 116, 104, 45, 109, 105, 99} + result := buildMIC(string(sessionID), username, service, authMethod) + if string(result) != string(expected) { + t.Fatalf("buildMic: got %v, want %v", result, expected) + } +} + +type exchange struct { + outToken string + expectedToken string +} + +type FakeClient struct { + exchanges []*exchange + round int + mic []byte + maxRound int +} + +func (f *FakeClient) InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) { + if token == nil { + if f.exchanges[f.round].expectedToken != "" { + err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } else { + if string(token) != string(f.exchanges[f.round].expectedToken) { + err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } + f.round++ + needContinue = f.round < f.maxRound + return +} + +func (f *FakeClient) GetMIC(micField []byte) ([]byte, error) { + return f.mic, nil +} + +func (f *FakeClient) DeleteSecContext() error { + return nil +} + +type FakeServer struct { + exchanges []*exchange + round int + expectedMIC []byte + srcName string + maxRound int +} + +func (f *FakeServer) AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) { + if token == nil { + if f.exchanges[f.round].expectedToken != "" { + err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } else { + if string(token) != string(f.exchanges[f.round].expectedToken) { + err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } + f.round++ + needContinue = f.round < f.maxRound + srcName = f.srcName + return +} + +func (f *FakeServer) VerifyMIC(micField []byte, micToken []byte) error { + if string(micToken) != string(f.expectedMIC) { + return fmt.Errorf("got MICToken %q, want %q", micToken, f.expectedMIC) + } + return nil +} + +func (f *FakeServer) DeleteSecContext() error { + return nil +} diff --git a/tempfork/sshtest/ssh/streamlocal.go b/tempfork/sshtest/ssh/streamlocal.go new file mode 100644 index 000000000..b171b330b --- /dev/null +++ b/tempfork/sshtest/ssh/streamlocal.go @@ -0,0 +1,116 @@ +package ssh + +import ( + "errors" + "io" + "net" +) + +// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "direct-streamlocal@openssh.com" string. +// +// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235 +type streamLocalChannelOpenDirectMsg struct { + socketPath string + reserved0 string + reserved1 uint32 +} + +// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "forwarded-streamlocal@openssh.com" string. +type forwardedStreamLocalPayload struct { + SocketPath string + Reserved0 string +} + +// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message +// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string. +type streamLocalChannelForwardMsg struct { + socketPath string +} + +// ListenUnix is similar to ListenTCP but uses a Unix domain socket. +func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + m := streamLocalChannelForwardMsg{ + socketPath, + } + // send message + ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") + } + ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) + + return &unixListener{socketPath, c, ch}, nil +} + +func (c *Client) dialStreamLocal(socketPath string) (Channel, error) { + msg := streamLocalChannelOpenDirectMsg{ + socketPath: socketPath, + } + ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type unixListener struct { + socketPath string + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *unixListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + }, nil +} + +// Close closes the listener. +func (l *unixListener) Close() error { + // this also closes the listener. + l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) + m := streamLocalChannelForwardMsg{ + l.socketPath, + } + ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *unixListener) Addr() net.Addr { + return &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + } +} diff --git a/tempfork/sshtest/ssh/tcpip.go b/tempfork/sshtest/ssh/tcpip.go new file mode 100644 index 000000000..ef5059a11 --- /dev/null +++ b/tempfork/sshtest/ssh/tcpip.go @@ -0,0 +1,509 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +// N must be "tcp", "tcp4", "tcp6", or "unix". +func (c *Client) Listen(n, addr string) (net.Listener, error) { + switch n { + case "tcp", "tcp4", "tcp6": + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) + case "unix": + return c.ListenUnix(addr) + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// handleForwards starts goroutines handling forwarded connections. +// It's called on first use by (*Client).ListenTCP to not launch +// goroutines until needed. +func (c *Client) handleForwards() { + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip")) + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com")) +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.Addr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr net.Addr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.Addr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + laddr: addr, + c: make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var ( + laddr net.Addr + raddr net.Addr + err error + ) + switch channelType := ch.ChannelType(); channelType { + case "forwarded-tcpip": + var payload forwardedTCPPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err = parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + case "forwarded-streamlocal@openssh.com": + var payload forwardedStreamLocalPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) + continue + } + laddr = &net.UnixAddr{ + Name: payload.SocketPath, + Net: "unix", + } + raddr = &net.UnixAddr{ + Name: "@", + Net: "unix", + } + default: + panic(fmt.Errorf("ssh: unknown channel type %s", channelType)) + } + if ok := l.forward(laddr, raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.Addr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { + f.c <- forward{newCh: ch, raddr: raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// DialContext initiates a connection to the addr from the remote host. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + var ch Channel + switch n { + case "tcp", "tcp4", "tcp6": + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + return &chanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil + case "unix": + var err error + ch, err = c.dialStreamLocal(addr) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: addr, + Net: "unix", + }, + }, nil + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// chanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type chanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *chanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *chanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *chanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *chanConn) SetReadDeadline(deadline time.Time) error { + // for compatibility with previous version, + // the error message contains "tcpChan" + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *chanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/tempfork/sshtest/ssh/tcpip_test.go b/tempfork/sshtest/ssh/tcpip_test.go new file mode 100644 index 000000000..4d8511472 --- /dev/null +++ b/tempfork/sshtest/ssh/tcpip_test.go @@ -0,0 +1,53 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "context" + "net" + "testing" + "time" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. + var _ ContextDialer = &net.Dialer{} + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/tempfork/sshtest/ssh/testdata_test.go b/tempfork/sshtest/ssh/testdata_test.go new file mode 100644 index 000000000..2da8c79dc --- /dev/null +++ b/tempfork/sshtest/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/tempfork/sshtest/ssh/transport.go b/tempfork/sshtest/ssh/transport.go new file mode 100644 index 000000000..0424d2d37 --- /dev/null +++ b/tempfork/sshtest/ssh/transport.go @@ -0,0 +1,380 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "bytes" + "errors" + "io" + "log" +) + +// debugTransport if set, will print packet types as they go over the +// wire. No message decoding is done, to minimize the impact on timing. +const debugTransport = false + +const ( + gcm128CipherID = "aes128-gcm@openssh.com" + gcm256CipherID = "aes256-gcm@openssh.com" + aes128cbcID = "aes128-cbc" + tripledescbcID = "3des-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection. The read is blocking, + // i.e. if error is nil, then the returned byte slice is + // always non-empty. + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + isClient bool + io.Closer + + strictMode bool + initialKEXDone bool +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writeCipherPacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readCipherPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +func (t *transport) setStrictMode() error { + if t.reader.seqNum != 1 { + return errors.New("ssh: sequence number != 1 when strict KEX mode requested") + } + t.strictMode = true + return nil +} + +func (t *transport) setInitialKEXDone() { + t.initialKEXDone = true +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult) + if err != nil { + return err + } + t.reader.pendingKeyChange <- ciph + + ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult) + if err != nil { + return err + } + t.writer.pendingKeyChange <- ciph + + return nil +} + +func (t *transport) printPacket(p []byte, write bool) { + if len(p) == 0 { + return + } + who := "server" + if t.isClient { + who = "client" + } + what := "read" + if write { + what = "write" + } + + log.Println(what, who, p[0]) +} + +// Read and decrypt next packet. +func (t *transport) readPacket() (p []byte, err error) { + for { + p, err = t.reader.readPacket(t.bufReader, t.strictMode) + if err != nil { + break + } + // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX + if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) { + break + } + } + if debugTransport { + t.printPacket(p, false) + } + + return p, err +} + +func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) { + packet, err := s.packetCipher.readCipherPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } + default: + return nil, errors.New("ssh: got bogus newkeys message") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + if debugTransport { + t.printPacket(packet, true) + } + return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + t.isClient = isClient + + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + cipherMode := cipherModes[algs.Cipher] + + iv := make([]byte, cipherMode.ivSize) + key := make([]byte, cipherMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + + var macKey []byte + if !aeadCiphers[algs.Cipher] { + macMode := macModes[algs.MAC] + macKey = make([]byte, macMode.keySize) + generateKeyMaterial(macKey, d.macKeyTag, kex) + } + + return cipherModes[algs.Cipher].create(key, iv, macKey, algs) +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for length := 0; length < maxVersionStringBytes; length++ { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + if !bytes.HasPrefix(versionString, []byte("SSH-")) { + // RFC 4253 says we need to ignore all version string lines + // except the one containing the SSH version (provided that + // all the lines do not exceed 255 bytes in total). + versionString = versionString[:0] + continue + } + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/tempfork/sshtest/ssh/transport_test.go b/tempfork/sshtest/ssh/transport_test.go new file mode 100644 index 000000000..8445e1e56 --- /dev/null +++ b/tempfork/sshtest/ssh/transport_test.go @@ -0,0 +1,113 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n" + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + multiLineVersion: "SSH-2.0-bla", + longVersion + "\r\n": longVersion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n" + cases := []string{ + longVersion + "too-long\r\n", + multiLineVersion, + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\x01\r\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +}