diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 4e452f894..baa211d1f 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -62,6 +62,12 @@ type LocalClient struct { // machine's tailscaled or equivalent. If nil, a default is used. Dial func(ctx context.Context, network, addr string) (net.Conn, error) + // Transport optionally specified an alternate [http.RoundTripper] + // used to execute HTTP requests. If nil, a default [http.Transport] is used, + // potentially with custom dialing logic from [Dial]. + // It is primarily used for testing. + Transport http.RoundTripper + // Socket specifies an alternate path to the local Tailscale socket. // If empty, a platform-specific default is used. Socket string @@ -129,9 +135,9 @@ func (lc *LocalClient) DoLocalRequest(req *http.Request) (*http.Response, error) req.Header.Set("Tailscale-Cap", strconv.Itoa(int(tailcfg.CurrentCapabilityVersion))) lc.tsClientOnce.Do(func() { lc.tsClient = &http.Client{ - Transport: &http.Transport{ - DialContext: lc.dialer(), - }, + Transport: cmp.Or(lc.Transport, http.RoundTripper( + &http.Transport{DialContext: lc.dialer()}), + ), } }) if !lc.OmitAuth { diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go index 107017268..040d9b522 100644 --- a/ipn/ipnauth/actor.go +++ b/ipn/ipnauth/actor.go @@ -4,6 +4,7 @@ package ipnauth import ( + "encoding/json" "fmt" "tailscale.com/ipn" @@ -76,3 +77,15 @@ func (id ClientID) String() string { } return fmt.Sprint(id.v) } + +// MarshalJSON implements [json.Marshaler]. +// It is primarily used for testing. +func (id ClientID) MarshalJSON() ([]byte, error) { + return json.Marshal(id.v) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +// It is primarily used for testing. +func (id *ClientID) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &id.v) +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 81a62045a..576f01b6b 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -3659,6 +3659,15 @@ func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) (ipn.WindowsUserID, e return uid, nil } +// CurrentUserForTest returns the current user and the associated WindowsUserID. +// It is used for testing only, and will be removed along with the rest of the +// "current user" functionality as we progress on the multi-user improvements (tailscale/corp#18342). +func (b *LocalBackend) CurrentUserForTest() (ipn.WindowsUserID, ipnauth.Actor) { + b.mu.Lock() + defer b.mu.Unlock() + return b.pm.CurrentUserID(), b.currentUser +} + func (b *LocalBackend) CheckPrefs(p *ipn.Prefs) error { b.mu.Lock() defer b.mu.Unlock() diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index b7d5ea144..8a9324fab 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -5,8 +5,32 @@ package ipnserver import ( "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/http/httptest" + "runtime" "sync" + "sync/atomic" "testing" + + "tailscale.com/client/tailscale" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/control/controlclient" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/store/mem" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + "tailscale.com/wgengine" ) func TestWaiterSet(t *testing.T) { @@ -44,3 +68,337 @@ func TestWaiterSet(t *testing.T) { cleanup() wantLen(0, "at end") } + +func TestUserConnectDisconnectNonWindows(t *testing.T) { + enableLogging := false + if runtime.GOOS == "windows" { + setGOOSForTest(t, "linux") + } + + ctx := context.Background() + server := startDefaultTestIPNServer(t, ctx, enableLogging) + + // UserA connects and starts watching the IPN bus. + clientA := server.getClientAs("UserA") + watcherA, _ := clientA.WatchIPNBus(ctx, 0) + + // The concept of "current user" is only relevant on Windows + // and it should not be set on non-Windows platforms. + server.checkCurrentUser(nil) + + // Additionally, a different user should be able to connect and use the LocalAPI. + clientB := server.getClientAs("UserB") + if _, gotErr := clientB.Status(ctx); gotErr != nil { + t.Fatalf("Status(%q): want nil; got %v", clientB.User.Name, gotErr) + } + + // Watching the IPN bus should also work for UserB. + watcherB, _ := clientB.WatchIPNBus(ctx, 0) + + // And if we send a notification, both users should receive it. + wantErrMessage := "test error" + testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)} + server.mustBackend().DebugNotify(testNotify) + + if n, err := watcherA.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.User.Name, err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.User.Name, wantErrMessage, gotErrMessage) + } + + if n, err := watcherB.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.User.Name, err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.User.Name, wantErrMessage, gotErrMessage) + } +} + +func TestUserConnectDisconnectOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := startDefaultTestIPNServer(t, ctx, enableLogging) + + client := server.getClientAs("User") + _, cancelWatcher := client.WatchIPNBus(ctx, 0) + + // On Windows, however, the current user should be set to the user that connected. + server.checkCurrentUser(client.User) + + // Cancel the IPN bus watcher request and wait for the server to unblock. + cancelWatcher() + server.blockWhileInUse(ctx) + + // The current user should not be set after a disconnect, as no one is + // currently using the server. + server.checkCurrentUser(nil) +} + +func TestIPNAlreadyInUseOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := startDefaultTestIPNServer(t, ctx, enableLogging) + + // UserA connects and starts watching the IPN bus. + clientA := server.getClientAs("UserA") + clientA.WatchIPNBus(ctx, 0) + + // While UserA is connected, UserB should not be able to connect. + clientB := server.getClientAs("UserB") + if _, gotErr := clientB.Status(ctx); gotErr == nil { + t.Fatalf("Status(%q): want error; got nil", clientB.User.Name) + } else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError { + t.Fatalf("Status(%q): want %q; got %q", clientB.User.Name, wantError, gotErr.Error()) + } + + // Current user should still be UserA. + server.checkCurrentUser(clientA.User) +} + +func setGOOSForTest(tb testing.TB, goos string) { + tb.Helper() + envknob.Setenv("TS_DEBUG_FAKE_GOOS", goos) + tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") }) +} + +func testLogger(tb testing.TB, enableLogging bool) logger.Logf { + tb.Helper() + if enableLogging { + return tstest.WhileTestRunningLogger(tb) + } + return logger.Discard +} + +// newTestIPNServer creates a new IPN server for testing, using the specified local backend. +func newTestIPNServer(tb testing.TB, lb *ipnlocal.LocalBackend, enableLogging bool) *Server { + tb.Helper() + server := New(testLogger(tb, enableLogging), logid.PublicID{}, lb.NetMon()) + server.lb.Store(lb) + return server +} + +type testIPNClient struct { + tb testing.TB + *tailscale.LocalClient + User *ipnauth.TestActor +} + +func (c *testIPNClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, context.CancelFunc) { + c.tb.Helper() + ctx, cancelWatcher := context.WithCancel(ctx) + c.tb.Cleanup(cancelWatcher) + watcher, err := c.LocalClient.WatchIPNBus(ctx, mask) + if err != nil { + c.tb.Fatalf("WatchIPNBus(%q): %v", c.User.Name, err) + } + c.tb.Cleanup(func() { watcher.Close() }) + return watcher, cancelWatcher +} + +func pumpIPNBus(watcher *tailscale.IPNBusWatcher) { + for { + _, err := watcher.Next() + if err != nil { + break + } + } +} + +type testIPNServer struct { + tb testing.TB + *Server + clientID atomic.Int64 + getClient func(*ipnauth.TestActor) *tailscale.LocalClient + + actorsMu sync.Mutex + actors map[string]*ipnauth.TestActor +} + +func (s *testIPNServer) getClientAs(name string) *testIPNClient { + clientID := fmt.Sprintf("Client-%d", 1+s.clientID.Add(1)) + user := s.makeTestUser(name, clientID) + return &testIPNClient{ + tb: s.tb, + LocalClient: s.getClient(user), + User: user, + } +} + +func (s *testIPNServer) makeTestUser(name string, clientID string) *ipnauth.TestActor { + s.actorsMu.Lock() + defer s.actorsMu.Unlock() + actor := s.actors[name] + if actor == nil { + actor = &ipnauth.TestActor{Name: name} + if envknob.GOOS() == "windows" { + // Historically, as of 2025-01-13, IPN does not distinguish between + // different users on non-Windows devices. Therefore, the UID, which is + // an [ipn.WindowsUserID], should only be populated when the actual or + // fake GOOS is Windows. + actor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actors))) + } + mak.Set(&s.actors, name, actor) + s.tb.Cleanup(func() { delete(s.actors, name) }) + } + actor = ptr.To(*actor) + actor.CID = ipnauth.ClientIDFrom(clientID) + return actor +} + +func (s *testIPNServer) blockWhileInUse(ctx context.Context) error { + ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) + <-ready + cleanup() + return ctx.Err() +} + +func (s *testIPNServer) checkCurrentUser(want *ipnauth.TestActor) { + s.tb.Helper() + var wantUID ipn.WindowsUserID + if want != nil { + wantUID = want.UID + } + gotUID, gotActor := s.mustBackend().CurrentUserForTest() + if gotUID != wantUID { + s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID) + } + if gotActor, ok := gotActor.(*ipnauth.TestActor); ok != (want != nil) || (want != nil && *gotActor != *want) { + s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want) + } +} + +// startTestIPNServer starts a [httptest.Server] that hosts the specified IPN server for the +// duration of the test, using the specified base context for incoming requests. +// It returns a function that creates a [tailscale.LocalClient] as a given [ipnauth.TestActor]. +func startTestIPNServer(tb testing.TB, baseContext context.Context, server *Server) *testIPNServer { + tb.Helper() + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + actor, err := extractActorFromHeader(r.Header) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + tb.Errorf("extractActorFromHeader: %v", err) + return + } + ctx := newTestContextWithActor(r.Context(), actor) + server.serveHTTP(w, r.Clone(ctx)) + })) + ts.Config.Addr = "http://" + apitype.LocalAPIHost + ts.Config.BaseContext = func(_ net.Listener) context.Context { return baseContext } + ts.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(server.logf, "ipnserver: ")) + ts.Start() + tb.Cleanup(ts.Close) + return &testIPNServer{ + tb: tb, + Server: server, + getClient: func(actor *ipnauth.TestActor) *tailscale.LocalClient { + return &tailscale.LocalClient{Transport: newTestRoundTripper(ts, actor)} + }, + } +} + +func startDefaultTestIPNServer(tb testing.TB, ctx context.Context, enableLogging bool) *testIPNServer { + tb.Helper() + lb := newLocalBackendWithTestControl(tb, newUnreachableControlClient, enableLogging) + ctx, stopServer := context.WithCancel(ctx) + tb.Cleanup(stopServer) + return startTestIPNServer(tb, ctx, newTestIPNServer(tb, lb, enableLogging)) +} + +type testRoundTripper struct { + transport http.RoundTripper + actor *ipnauth.TestActor +} + +// newTestRoundTripper creates a new [http.RoundTripper] that sends requests +// to the specified test server as the specified actor. +func newTestRoundTripper(ts *httptest.Server, actor *ipnauth.TestActor) *testRoundTripper { + return &testRoundTripper{ + transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var std net.Dialer + return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String()) + }}, + actor: actor, + } +} + +const testActorHeaderName = "TS-Test-Actor" + +// RoundTrip implements [http.RoundTripper] by forwarding the request to the underlying transport +// and including the test actor's identity in the request headers. +func (rt *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + actorJSON, err := json.Marshal(&rt.actor) + if err != nil { + // An [http.RoundTripper] must always close the request body, including on error. + if r.Body != nil { + r.Body.Close() + } + return nil, err + } + + r = r.Clone(r.Context()) + r.Header.Set(testActorHeaderName, string(actorJSON)) + return rt.transport.RoundTrip(r) +} + +// extractActorFromHeader extracts a test actor from the specified request headers. +func extractActorFromHeader(h http.Header) (*ipnauth.TestActor, error) { + actorJSON := h.Get(testActorHeaderName) + if actorJSON == "" { + return nil, errors.New("missing Test-Actor header") + } + actor := &ipnauth.TestActor{} + if err := json.Unmarshal([]byte(actorJSON), &actor); err != nil { + return nil, fmt.Errorf("invalid Test-Actor header: %v", err) + } + return actor, nil +} + +type newControlClientFn func(tb testing.TB, opts controlclient.Options) controlclient.Client + +func newLocalBackendWithTestControl(tb testing.TB, newControl newControlClientFn, enableLogging bool) *ipnlocal.LocalBackend { + tb.Helper() + + sys := &tsd.System{} + store := &mem.Store{} + sys.Set(store) + + logf := testLogger(tb, enableLogging) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + if err != nil { + tb.Fatalf("NewFakeUserspaceEngine: %v", err) + } + tb.Cleanup(e.Close) + sys.Set(e) + + b, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0) + if err != nil { + tb.Fatalf("NewLocalBackend: %v", err) + } + tb.Cleanup(b.Shutdown) + b.DisablePortMapperForTest() + + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + return newControl(tb, opts), nil + }) + return b +} + +func newUnreachableControlClient(tb testing.TB, opts controlclient.Options) controlclient.Client { + tb.Helper() + opts.ServerURL = "https://127.0.0.1:1" + cc, err := controlclient.New(opts) + if err != nil { + tb.Fatal(err) + } + return cc +} + +// newTestContextWithActor returns a new context that carries the identity +// of the specified actor and can be used for testing. +// It can be retrieved with [actorFromContext]. +func newTestContextWithActor(ctx context.Context, actor ipnauth.Actor) context.Context { + return actorKey.WithValue(ctx, actorOrError{actor: actor}) +}