This updates all source files to use a new standard header for the copyright and license declaration. Notably, the copyright line no longer includes a date, and we now use the standard SPDX-License-Identifier header. This commit was made almost entirely mechanically with perl, followed by a few minimal manual fixes.

Updates #6865

Signed-off-by: Will Norris <will@tailscale.com>
		
			
				
	
	
		
			131 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			131 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package tlsdial
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/x509"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"os/exec"
 | 
						|
	"path/filepath"
 | 
						|
	"reflect"
 | 
						|
	"runtime"
 | 
						|
	"sync/atomic"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
func resetOnce() {
 | 
						|
	rv := reflect.ValueOf(&bakedInRootsOnce).Elem()
 | 
						|
	rv.Set(reflect.Zero(rv.Type()))
 | 
						|
}
 | 
						|
 | 
						|
func TestBakedInRoots(t *testing.T) {
 | 
						|
	resetOnce()
 | 
						|
	p := bakedInRoots()
 | 
						|
	got := p.Subjects()
 | 
						|
	if len(got) != 1 {
 | 
						|
		t.Errorf("subjects = %v; want 1", len(got))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestFallbackRootWorks(t *testing.T) {
 | 
						|
	defer resetOnce()
 | 
						|
 | 
						|
	const debug = false
 | 
						|
	if runtime.GOOS != "linux" {
 | 
						|
		t.Skip("test assumes Linux")
 | 
						|
	}
 | 
						|
	d := t.TempDir()
 | 
						|
	crtFile := filepath.Join(d, "tlsdial.test.crt")
 | 
						|
	keyFile := filepath.Join(d, "tlsdial.test.key")
 | 
						|
	caFile := filepath.Join(d, "rootCA.pem")
 | 
						|
	cmd := exec.Command("go",
 | 
						|
		"run", "filippo.io/mkcert",
 | 
						|
		"--cert-file="+crtFile,
 | 
						|
		"--key-file="+keyFile,
 | 
						|
		"tlsdial.test")
 | 
						|
	cmd.Env = append(os.Environ(), "CAROOT="+d)
 | 
						|
	out, err := cmd.CombinedOutput()
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("mkcert: %v, %s", err, out)
 | 
						|
	}
 | 
						|
	if debug {
 | 
						|
		t.Logf("Ran: %s", out)
 | 
						|
		dents, err := os.ReadDir(d)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatal(err)
 | 
						|
		}
 | 
						|
		for _, de := range dents {
 | 
						|
			t.Logf(" - %v", de)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	caPEM, err := os.ReadFile(caFile)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	resetOnce()
 | 
						|
	bakedInRootsOnce.Do(func() {
 | 
						|
		p := x509.NewCertPool()
 | 
						|
		if !p.AppendCertsFromPEM(caPEM) {
 | 
						|
			t.Fatal("failed to add")
 | 
						|
		}
 | 
						|
		bakedInRootsOnce.p = p
 | 
						|
	})
 | 
						|
 | 
						|
	ln, err := net.Listen("tcp", "127.0.0.1:0")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	defer ln.Close()
 | 
						|
	if debug {
 | 
						|
		t.Logf("listener running at %v", ln.Addr())
 | 
						|
	}
 | 
						|
	done := make(chan struct{})
 | 
						|
	defer close(done)
 | 
						|
 | 
						|
	errc := make(chan error, 1)
 | 
						|
	go func() {
 | 
						|
		err := http.ServeTLS(ln, http.HandlerFunc(sayHi), crtFile, keyFile)
 | 
						|
		select {
 | 
						|
		case <-done:
 | 
						|
			return
 | 
						|
		default:
 | 
						|
			t.Logf("ServeTLS: %v", err)
 | 
						|
			errc <- err
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	tr := &http.Transport{
 | 
						|
		Dial: func(network, addr string) (net.Conn, error) {
 | 
						|
			return net.Dial("tcp", ln.Addr().String())
 | 
						|
		},
 | 
						|
		DisableKeepAlives: true, // for test cleanup ease
 | 
						|
	}
 | 
						|
	tr.TLSClientConfig = Config("tlsdial.test", tr.TLSClientConfig)
 | 
						|
	c := &http.Client{Transport: tr}
 | 
						|
 | 
						|
	ctr0 := atomic.LoadInt32(&counterFallbackOK)
 | 
						|
 | 
						|
	res, err := c.Get("https://tlsdial.test/")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	defer res.Body.Close()
 | 
						|
	if res.StatusCode != 200 {
 | 
						|
		t.Fatal(res.Status)
 | 
						|
	}
 | 
						|
 | 
						|
	ctrDelta := atomic.LoadInt32(&counterFallbackOK) - ctr0
 | 
						|
	if ctrDelta != 1 {
 | 
						|
		t.Errorf("fallback root success count = %d; want 1", ctrDelta)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// sayHi is the handler used by the test HTTPS server. It ignores the
// request and replies with the plain body "hi" (status defaults to 200).
func sayHi(w http.ResponseWriter, _ *http.Request) {
	// A failed write only means the client went away; nothing to handle.
	_, _ = io.WriteString(w, "hi")
}
 |