diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go
index 4d00687ee..0a2dfa1bf 100644
--- a/cmd/containerboot/kube.go
+++ b/cmd/containerboot/kube.go
@@ -8,15 +8,22 @@ import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"log"
 	"net/http"
 	"net/netip"
 	"os"
+	"strings"
+	"time"
 
+	"tailscale.com/ipn"
 	"tailscale.com/kube/kubeapi"
 	"tailscale.com/kube/kubeclient"
 	"tailscale.com/kube/kubetypes"
+	"tailscale.com/logtail/backoff"
 	"tailscale.com/tailcfg"
+	"tailscale.com/types/logger"
 )
 
 // kubeClient is a wrapper around Tailscale's internal kube client that knows how to talk to the kube API server. We use
@@ -126,3 +133,62 @@ func (kc *kubeClient) storeCapVerUID(ctx context.Context, podUID string) error {
 	}
 	return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container")
 }
+
+// waitForConsistentState waits for tailscaled to finish writing state if it
+// looks like it's started. It is designed to reduce the likelihood that
+// tailscaled gets shut down in the window between authenticating to control
+// and finishing writing state. However, it's not bullet proof because we can't
+// atomically authenticate and write state.
+func (kc *kubeClient) waitForConsistentState(ctx context.Context) error {
+	var logged bool
+
+	bo := backoff.NewBackoff("", logger.Discard, 2*time.Second)
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		secret, err := kc.GetSecret(ctx, kc.stateSecret)
+		if ctx.Err() != nil || kubeclient.IsNotFoundErr(err) {
+			return nil
+		}
+		if err != nil {
+			return fmt.Errorf("getting Secret %q: %v", kc.stateSecret, err)
+		}
+
+		if hasConsistentState(secret.Data) {
+			return nil
+		}
+
+		if !logged {
+			log.Printf("Waiting for tailscaled to finish writing state to Secret %q", kc.stateSecret)
+			logged = true
+		}
+		bo.BackOff(ctx, errors.New("")) // Fake error to trigger actual sleep.
+	}
+}
+
+// hasConsistentState returns true if there is either no state or the full set
+// of expected keys is present.
+func hasConsistentState(d map[string][]byte) bool {
+	var (
+		_, hasCurrent = d[string(ipn.CurrentProfileStateKey)]
+		_, hasKnown   = d[string(ipn.KnownProfilesStateKey)]
+		_, hasMachine = d[string(ipn.MachineKeyStateKey)]
+		hasProfile    bool
+	)
+
+	for k := range d {
+		if strings.HasPrefix(k, "profile-") {
+			if hasProfile {
+				return false // We only expect one profile.
+			}
+			hasProfile = true
+		}
+	}
+
+	// Approximate check, we don't want to reimplement all of profileManager.
+	return (hasCurrent && hasKnown && hasMachine && hasProfile) ||
+		(!hasCurrent && !hasKnown && !hasMachine && !hasProfile)
+}
diff --git a/cmd/containerboot/kube_test.go b/cmd/containerboot/kube_test.go
index 2ba69af7c..413971bc6 100644
--- a/cmd/containerboot/kube_test.go
+++ b/cmd/containerboot/kube_test.go
@@ -9,8 +9,10 @@ import (
 	"context"
 	"errors"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
+	"tailscale.com/ipn"
 	"tailscale.com/kube/kubeapi"
 	"tailscale.com/kube/kubeclient"
 )
@@ -205,3 +207,34 @@ func TestSetupKube(t *testing.T) {
 		})
 	}
 }
+
+func TestWaitForConsistentState(t *testing.T) {
+	data := map[string][]byte{
+		// Missing _current-profile.
+		string(ipn.KnownProfilesStateKey): []byte(""),
+		string(ipn.MachineKeyStateKey):    []byte(""),
+		"profile-foo":                     []byte(""),
+	}
+	kc := &kubeClient{
+		Client: &kubeclient.FakeClient{
+			GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) {
+				return &kubeapi.Secret{
+					Data: data,
+				}, nil
+			},
+		},
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if err := kc.waitForConsistentState(ctx); err != context.DeadlineExceeded {
+		t.Fatalf("expected DeadlineExceeded, got %v", err)
+	}
+
+	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	data[string(ipn.CurrentProfileStateKey)] = []byte("")
+	if err := kc.waitForConsistentState(ctx); err != nil {
+		t.Fatalf("expected nil, got %v", err)
+	}
+}
diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go
index 0aca27f5f..cf4bd8620 100644
--- a/cmd/containerboot/main.go
+++ b/cmd/containerboot/main.go
@@ -137,53 +137,83 @@ func newNetfilterRunner(logf logger.Logf) (linuxfw.NetfilterRunner, error) {
 }
 
 func main() {
+	if err := run(); err != nil && !errors.Is(err, context.Canceled) {
+		log.Fatal(err)
+	}
+}
+
+func run() error {
 	log.SetPrefix("boot: ")
 	tailscale.I_Acknowledge_This_API_Is_Unstable = true
 
 	cfg, err := configFromEnv()
 	if err != nil {
-		log.Fatalf("invalid configuration: %v", err)
+		return fmt.Errorf("invalid configuration: %w", err)
 	}
 
 	if !cfg.UserspaceMode {
 		if err := ensureTunFile(cfg.Root); err != nil {
-			log.Fatalf("Unable to create tuntap device file: %v", err)
+			return fmt.Errorf("unable to create tuntap device file: %w", err)
 		}
 		if cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.Routes != nil || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" {
 			if err := ensureIPForwarding(cfg.Root, cfg.ProxyTargetIP, cfg.TailnetTargetIP, cfg.TailnetTargetFQDN, cfg.Routes); err != nil {
 				log.Printf("Failed to enable IP forwarding: %v", err)
 				log.Printf("To run tailscale as a proxy or router container, IP forwarding must be enabled.")
 				if cfg.InKubernetes {
-					log.Fatalf("You can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.")
+					return fmt.Errorf("you can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.")
 				} else {
-					log.Fatalf("You can fix this by running the container with privileged=true, or the equivalent in your container runtime that permits access to sysctls.")
+					return fmt.Errorf("you can fix this by running the container with privileged=true, or the equivalent in your container runtime that permits access to sysctls.")
 				}
 			}
 		}
 	}
 
-	// Context is used for all setup stuff until we're in steady
+	// Root context for the whole containerboot process, used to make sure
+	// shutdown signals are promptly and cleanly handled.
+	ctx, cancel := contextWithExitSignalWatch()
+	defer cancel()
+
+	// bootCtx is used for all setup stuff until we're in steady
 	// state, so that if something is hanging we eventually time out
 	// and crashloop the container.
-	bootCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	bootCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
 	defer cancel()
 
 	var kc *kubeClient
 	if cfg.InKubernetes {
 		kc, err = newKubeClient(cfg.Root, cfg.KubeSecret)
 		if err != nil {
-			log.Fatalf("error initializing kube client: %v", err)
+			return fmt.Errorf("error initializing kube client: %w", err)
 		}
 		if err := cfg.setupKube(bootCtx, kc); err != nil {
-			log.Fatalf("error setting up for running on Kubernetes: %v", err)
+			return fmt.Errorf("error setting up for running on Kubernetes: %w", err)
 		}
 	}
 
 	client, daemonProcess, err := startTailscaled(bootCtx, cfg)
 	if err != nil {
-		log.Fatalf("failed to bring up tailscale: %v", err)
+		return fmt.Errorf("failed to bring up tailscale: %w", err)
 	}
 	killTailscaled := func() {
+		if hasKubeStateStore(cfg) {
+			// Check we're not shutting tailscaled down while it's still writing
+			// state. If we authenticate and fail to write all the state, we'll
+			// never recover automatically.
+			//
+			// The default termination grace period for a Pod is 30s. We wait 25s at
+			// most so that we still reserve some of that budget for tailscaled
+			// to receive and react to a SIGTERM before the SIGKILL that k8s
+			// will send at the end of the grace period.
+			ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
+			defer cancel()
+
+			log.Printf("Checking for consistent state")
+			err := kc.waitForConsistentState(ctx)
+			if err != nil {
+				log.Printf("Error waiting for consistent state on shutdown: %v", err)
+			}
+		}
+		log.Printf("Sending SIGTERM to tailscaled")
 		if err := daemonProcess.Signal(unix.SIGTERM); err != nil {
 			log.Fatalf("error shutting tailscaled down: %v", err)
 		}
@@ -231,7 +261,7 @@ func main() {
 
 	w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState)
 	if err != nil {
-		log.Fatalf("failed to watch tailscaled for updates: %v", err)
+		return fmt.Errorf("failed to watch tailscaled for updates: %w", err)
 	}
 
 	// Now that we've started tailscaled, we can symlink the socket to the
@@ -267,18 +297,18 @@ func main() {
 		didLogin = true
 		w.Close()
 		if err := tailscaleUp(bootCtx, cfg); err != nil {
-			return fmt.Errorf("failed to auth tailscale: %v", err)
+			return fmt.Errorf("failed to auth tailscale: %w", err)
 		}
 		w, err = client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState)
 		if err != nil {
-			return fmt.Errorf("rewatching tailscaled for updates after auth: %v", err)
+			return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err)
 		}
 		return nil
 	}
 
 	if isTwoStepConfigAlwaysAuth(cfg) {
 		if err := authTailscale(); err != nil {
-			log.Fatalf("failed to auth tailscale: %v", err)
+			return fmt.Errorf("failed to auth tailscale: %w", err)
 		}
 	}
 
@@ -286,7 +316,7 @@ func main() {
 	for {
 		n, err := w.Next()
 		if err != nil {
-			log.Fatalf("failed to read from tailscaled: %v", err)
+			return fmt.Errorf("failed to read from tailscaled: %w", err)
 		}
 
 		if n.State != nil {
@@ -295,10 +325,10 @@ func main() {
 			if isOneStepConfig(cfg) {
 				// This could happen if this is the first time tailscaled was run for this
 				// device and the auth key was not passed via the configfile.
- log.Fatalf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") + return fmt.Errorf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") } if err := authTailscale(); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } case ipn.NeedsMachineAuth: log.Printf("machine authorization required, please visit the admin panel") @@ -318,14 +348,11 @@ func main() { w.Close() - ctx, cancel := contextWithExitSignalWatch() - defer cancel() - if isTwoStepConfigAuthOnce(cfg) { // Now that we are authenticated, we can set/reset any of the // settings that we need to. if err := tailscaleSet(ctx, cfg); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } } @@ -334,11 +361,11 @@ func main() { if cfg.ServeConfigPath != "" { log.Printf("serve proxy: unsetting previous config") if err := client.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { - log.Fatalf("failed to unset serve config: %v", err) + return fmt.Errorf("failed to unset serve config: %w", err) } if hasKubeStateStore(cfg) { if err := kc.storeHTTPSEndpoint(ctx, ""); err != nil { - log.Fatalf("failed to update HTTPS endpoint in tailscale state: %v", err) + return fmt.Errorf("failed to update HTTPS endpoint in tailscale state: %w", err) } } } @@ -349,19 +376,19 @@ func main() { // wipe it, but it's good hygiene. log.Printf("Deleting authkey from kube secret") if err := kc.deleteAuthKey(ctx); err != nil { - log.Fatalf("deleting authkey from kube secret: %v", err) + return fmt.Errorf("deleting authkey from kube secret: %w", err) } } if hasKubeStateStore(cfg) { if err := kc.storeCapVerUID(ctx, cfg.PodUID); err != nil { - log.Fatalf("storing capability version and UID: %v", err) + return fmt.Errorf("storing capability version and UID: %w", err) } } w, err = client.WatchIPNBus(ctx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { - log.Fatalf("rewatching tailscaled for updates after auth: %v", err) + return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err) } // If tailscaled config was read from a mounted file, watch the file for updates and reload. @@ -391,7 +418,7 @@ func main() { if isL3Proxy(cfg) { nfr, err = newNetfilterRunner(log.Printf) if err != nil { - log.Fatalf("error creating new netfilter runner: %v", err) + return fmt.Errorf("error creating new netfilter runner: %w", err) } } @@ -462,9 +489,9 @@ func main() { killTailscaled() break runLoop case err := <-errChan: - log.Fatalf("failed to read from tailscaled: %v", err) + return fmt.Errorf("failed to read from tailscaled: %w", err) case err := <-cfgWatchErrChan: - log.Fatalf("failed to watch tailscaled config: %v", err) + return fmt.Errorf("failed to watch tailscaled config: %w", err) case n := <-notifyChan: if n.State != nil && *n.State != ipn.Running { // Something's gone wrong and we've left the authenticated state. @@ -472,7 +499,7 @@ func main() { // control flow required to make it work now is hard. So, just crash // the container and rely on the container runtime to restart us, // whereupon we'll go through initial auth again. 
- log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State) + return fmt.Errorf("tailscaled left running state (now in state %q), exiting", *n.State) } if n.NetMap != nil { addrs = n.NetMap.SelfNode.Addresses().AsSlice() @@ -490,7 +517,7 @@ func main() { deviceID := n.NetMap.SelfNode.StableID() if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { if err := kc.storeDeviceID(ctx, n.NetMap.SelfNode.StableID()); err != nil { - log.Fatalf("storing device ID in Kubernetes Secret: %v", err) + return fmt.Errorf("storing device ID in Kubernetes Secret: %w", err) } } if cfg.TailnetTargetFQDN != "" { @@ -527,12 +554,12 @@ func main() { rulesInstalled = true log.Printf("Installing forwarding rules for destination %v", ea.String()) if err := installEgressForwardingRule(ctx, ea.String(), addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules for destination %s: %v", ea.String(), err) + return fmt.Errorf("installing egress proxy rules for destination %s: %v", ea.String(), err) } } } if !rulesInstalled { - log.Fatalf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) + return fmt.Errorf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) } } currentEgressIPs = newCurentEgressIPs @@ -540,7 +567,7 @@ func main() { if cfg.ProxyTargetIP != "" && len(addrs) != 0 && ipsHaveChanged { log.Printf("Installing proxy rules") if err := installIngressForwardingRule(ctx, cfg.ProxyTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules: %v", err) + return fmt.Errorf("installing ingress proxy rules: %w", err) } } if cfg.ProxyTargetDNSName != "" && len(addrs) != 0 && ipsHaveChanged { @@ -556,7 +583,7 @@ func main() { if backendsHaveChanged { log.Printf("installing ingress proxy rules for backends %v", newBackendAddrs) if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - log.Fatalf("error installing ingress proxy rules: %v", err) + return fmt.Errorf("error installing ingress proxy rules: %w", err) } } resetTimer(false) @@ -578,7 +605,7 @@ func main() { if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("Installing forwarding rules for destination %v", cfg.TailnetTargetIP) if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules: %v", err) + return fmt.Errorf("installing egress proxy rules: %w", err) } } // If this is a L7 cluster ingress proxy (set up @@ -590,7 +617,7 @@ func main() { if cfg.AllowProxyingClusterTrafficViaIngress && cfg.ServeConfigPath != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("installing rules to forward traffic for %s to node's tailnet IP", cfg.PodIP) if err := installTSForwardingRuleForDestination(ctx, cfg.PodIP, addrs, nfr); err != nil { - log.Fatalf("installing rules to forward traffic to node's tailnet IP: %v", err) + return fmt.Errorf("installing rules to forward traffic to node's tailnet IP: %w", err) } } currentIPs = newCurrentIPs @@ -609,7 +636,7 @@ func main() { deviceEndpoints := []any{n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses()} if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { if err := kc.storeDeviceEndpoints(ctx, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { - log.Fatalf("storing device IPs and FQDN in Kubernetes Secret: %v", err) + return fmt.Errorf("storing device IPs 
 					}
 				}
@@ -700,16 +727,18 @@ func main() {
 				if backendsHaveChanged && len(addrs) != 0 {
 					log.Printf("Backend address change detected, installing proxy rules for backends %v", newBackendAddrs)
 					if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil {
-						log.Fatalf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err)
+						return fmt.Errorf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err)
 					}
 				}
 				backendAddrs = newBackendAddrs
 				resetTimer(false)
 			case e := <-egressSvcsErrorChan:
-				log.Fatalf("egress proxy failed: %v", e)
+				return fmt.Errorf("egress proxy failed: %v", e)
 			}
 		}
 	wg.Wait()
+
+	return nil
 }
 
 // ensureTunFile checks that /dev/net/tun exists, creating it if
@@ -738,13 +767,13 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) {
 	ip4s, err := net.DefaultResolver.LookupIP(ctx, "ip4", name)
 	if err != nil {
 		if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) {
-			return nil, fmt.Errorf("error looking up IPv4 addresses: %v", err)
+			return nil, fmt.Errorf("error looking up IPv4 addresses: %w", err)
 		}
 	}
 	ip6s, err := net.DefaultResolver.LookupIP(ctx, "ip6", name)
 	if err != nil {
 		if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) {
-			return nil, fmt.Errorf("error looking up IPv6 addresses: %v", err)
+			return nil, fmt.Errorf("error looking up IPv6 addresses: %w", err)
 		}
 	}
 	if len(ip4s) == 0 && len(ip6s) == 0 {
@@ -757,7 +786,7 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) {
 // context that gets cancelled when a signal is received and a cancel function
 // that can be called to free the resources when the watch should be stopped.
 func contextWithExitSignalWatch() (context.Context, func()) {
-	closeChan := make(chan string)
+	closeChan := make(chan struct{})
 	ctx, cancel := context.WithCancel(context.Background())
 	signalChan := make(chan os.Signal, 1)
 	signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
@@ -769,8 +798,11 @@ func contextWithExitSignalWatch() (context.Context, func()) {
 			return
 		}
 	}()
+	closeOnce := sync.Once{}
 	f := func() {
-		closeChan <- "goodbye"
+		closeOnce.Do(func() {
+			close(closeChan)
+		})
 	}
 	return ctx, f
 }
@@ -823,7 +855,11 @@ func runHTTPServer(mux *http.ServeMux, addr string) (close func() error) {
 
 	go func() {
 		if err := srv.Serve(ln); err != nil {
-			log.Fatalf("failed running server: %v", err)
+			if err != http.ErrServerClosed {
+				log.Fatalf("failed running server: %v", err)
+			} else {
+				log.Printf("HTTP server at %s closed", addr)
+			}
 		}
 	}()
diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go
index c8066f2c1..bc158dac5 100644
--- a/cmd/containerboot/main_test.go
+++ b/cmd/containerboot/main_test.go
@@ -25,6 +25,7 @@
 	"strconv"
 	"strings"
 	"sync"
+	"syscall"
 	"testing"
 	"time"
 
@@ -50,9 +51,7 @@ func TestContainerBoot(t *testing.T) {
 	defer lapi.Close()
 
 	kube := kubeServer{FSRoot: d}
-	if err := kube.Start(); err != nil {
-		t.Fatal(err)
-	}
+	kube.Start(t)
 	defer kube.Close()
 
 	tailscaledConf := &ipn.ConfigVAlpha{AuthKey: ptr.To("foo"), Version: "alpha0"}
@@ -138,15 +137,29 @@ type phase struct {
 	// WantCmds is the commands that containerboot should run in this phase.
 	WantCmds []string
+
 	// WantKubeSecret is the secret keys/values that should exist in the
 	// kube secret.
 	WantKubeSecret map[string]string
+
+	// Update the kube secret with these keys/values at the beginning of the
+	// phase (simulates our fake tailscaled doing it).
+	UpdateKubeSecret map[string]string
+
 	// WantFiles files that should exist in the container and their
 	// contents.
 	WantFiles map[string]string
-	// WantFatalLog is the fatal log message we expect from containerboot.
-	// If set for a phase, the test will finish on that phase.
-	WantFatalLog string
+
+	// WantLog is a log message we expect from containerboot.
+	WantLog string
+
+	// If set for a phase, the test will expect containerboot to exit with
+	// this error code, and the test will finish on that phase without
+	// waiting for the successful startup log message.
+	WantExitCode *int
+
+	// The signal to send to containerboot at the start of the phase.
+	Signal *syscall.Signal
 
 	EndpointStatuses map[string]int
 }
@@ -434,7 +447,8 @@ type phase struct {
 					},
 				},
 			},
-			WantFatalLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false",
+			WantLog:      "no forwarding rules for egress addresses [::1/128], host supports IPv6: false",
+			WantExitCode: ptr.To(1),
 		},
 	},
 },
@@ -936,7 +950,64 @@ type phase struct {
 			},
 			Phases: []phase{
 				{
-					WantFatalLog: "TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes",
+					WantLog:      "TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes",
+					WantExitCode: ptr.To(1),
+				},
+			},
+		},
+		{
+			Name: "kube_shutdown_during_state_write",
+			Env: map[string]string{
+				"KUBERNETES_SERVICE_HOST":       kube.Host,
+				"KUBERNETES_SERVICE_PORT_HTTPS": kube.Port,
+				"TS_ENABLE_HEALTH_CHECK":        "true",
+			},
+			KubeSecret: map[string]string{
+				"authkey": "tskey-key",
+			},
+			Phases: []phase{
+				{
+					// Normal startup.
+					WantCmds: []string{
+						"/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking",
+						"/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key",
+					},
+					WantKubeSecret: map[string]string{
+						"authkey": "tskey-key",
+					},
+				},
+				{
+					// SIGTERM before state is finished writing, should wait for
+					// consistent state before propagating SIGTERM to tailscaled.
+					Signal: ptr.To(unix.SIGTERM),
+					UpdateKubeSecret: map[string]string{
+						"_machinekey":  "foo",
+						"_profiles":    "foo",
+						"profile-baff": "foo",
+						// Missing "_current-profile" key.
+					},
+					WantKubeSecret: map[string]string{
+						"authkey":      "tskey-key",
+						"_machinekey":  "foo",
+						"_profiles":    "foo",
+						"profile-baff": "foo",
+					},
+					WantLog: "Waiting for tailscaled to finish writing state to Secret \"tailscale\"",
+				},
+				{
+					// tailscaled has finished writing state, should propagate SIGTERM.
+					UpdateKubeSecret: map[string]string{
+						"_current-profile": "foo",
+					},
+					WantKubeSecret: map[string]string{
+						"authkey":          "tskey-key",
+						"_machinekey":      "foo",
+						"_profiles":        "foo",
+						"profile-baff":     "foo",
+						"_current-profile": "foo",
+					},
+					WantLog:      "HTTP server at [::]:9002 closed",
+					WantExitCode: ptr.To(0),
 				},
 			},
 		},
@@ -984,26 +1055,36 @@ type phase struct {
 	var wantCmds []string
 	for i, p := range test.Phases {
+		for k, v := range p.UpdateKubeSecret {
+			kube.SetSecret(k, v)
+		}
 		lapi.Notify(p.Notify)
-		if p.WantFatalLog != "" {
+		if p.Signal != nil {
+			cmd.Process.Signal(*p.Signal)
+		}
+		if p.WantLog != "" {
 			err := tstest.WaitFor(2*time.Second, func() error {
-				state, err := cmd.Process.Wait()
-				if err != nil {
-					return err
-				}
-				if state.ExitCode() != 1 {
-					return fmt.Errorf("process exited with code %d but wanted %d", state.ExitCode(), 1)
-				}
-				waitLogLine(t, time.Second, cbOut, p.WantFatalLog)
+				waitLogLine(t, time.Second, cbOut, p.WantLog)
 				return nil
 			})
 			if err != nil {
 				t.Fatal(err)
 			}
+		}
+
+		if p.WantExitCode != nil {
+			state, err := cmd.Process.Wait()
+			if err != nil {
+				t.Fatal(err)
+			}
+			if state.ExitCode() != *p.WantExitCode {
+				t.Fatalf("phase %d: want exit code %d, got %d", i, *p.WantExitCode, state.ExitCode())
+			}
 			// Early test return, we don't expect the successful startup log message.
 			return
 		}
+
 		wantCmds = append(wantCmds, p.WantCmds...)
 		waitArgs(t, 2*time.Second, d, argFile, strings.Join(wantCmds, "\n"))
 
 		err := tstest.WaitFor(2*time.Second, func() error {
@@ -1059,6 +1140,9 @@ type phase struct {
 			}
 		}
 		waitLogLine(t, 2*time.Second, cbOut, "Startup complete, waiting for shutdown signal")
+		if cmd.ProcessState != nil {
+			t.Fatalf("containerboot should be running but exited with exit code %d", cmd.ProcessState.ExitCode())
+		}
 	})
 	}
 }
@@ -1290,18 +1374,18 @@ func (k *kubeServer) Reset() {
 	k.secret = map[string]string{}
 }
 
-func (k *kubeServer) Start() error {
+func (k *kubeServer) Start(t *testing.T) {
 	root := filepath.Join(k.FSRoot, "var/run/secrets/kubernetes.io/serviceaccount")
 	if err := os.MkdirAll(root, 0700); err != nil {
-		return err
+		t.Fatal(err)
 	}
 	if err := os.WriteFile(filepath.Join(root, "namespace"), []byte("default"), 0600); err != nil {
-		return err
+		t.Fatal(err)
 	}
 	if err := os.WriteFile(filepath.Join(root, "token"), []byte("bearer_token"), 0600); err != nil {
-		return err
+		t.Fatal(err)
 	}
 
 	k.srv = httptest.NewTLSServer(k)
@@ -1310,13 +1394,11 @@ func (k *kubeServer) Start() error {
 
 	var cert bytes.Buffer
 	if err := pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: k.srv.Certificate().Raw}); err != nil {
-		return err
+		t.Fatal(err)
 	}
 	if err := os.WriteFile(filepath.Join(root, "ca.crt"), cert.Bytes(), 0600); err != nil {
-		return err
+		t.Fatal(err)
 	}
-
-	return nil
 }
 
 func (k *kubeServer) Close() {
@@ -1365,6 +1447,7 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, fmt.Sprintf("reading request body: %v", err), http.StatusInternalServerError)
 		return
 	}
+	defer r.Body.Close()
 
 	switch r.Method {
 	case "GET":
@@ -1397,12 +1480,13 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) {
 			panic(fmt.Sprintf("json decode failed: %v. Body:\n\n%s", err, string(bs)))
 		}
 		for _, op := range req {
-			if op.Op == "remove" {
+			switch op.Op {
+			case "remove":
 				if !strings.HasPrefix(op.Path, "/data/") {
 					panic(fmt.Sprintf("unsupported json-patch path %q", op.Path))
 				}
 				delete(k.secret, strings.TrimPrefix(op.Path, "/data/"))
-			} else if op.Op == "replace" {
+			case "replace":
 				path, ok := strings.CutPrefix(op.Path, "/data/")
 				if !ok {
 					panic(fmt.Sprintf("unsupported json-patch path %q", op.Path))
 				}
@@ -1419,7 +1503,7 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) {
 					}
 					k.secret[path] = val
 				}
-			} else {
+			default:
 				panic(fmt.Sprintf("unsupported json-patch op %q", op.Op))
 			}
 		}
@@ -1437,7 +1521,7 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) {
 		panic(fmt.Sprintf("unknown content type %q", r.Header.Get("Content-Type")))
 	}
 	default:
-		panic(fmt.Sprintf("unhandled HTTP method %q", r.Method))
+		panic(fmt.Sprintf("unhandled HTTP request %s %s", r.Method, r.URL))
 	}
 }
diff --git a/cmd/containerboot/tailscaled.go b/cmd/containerboot/tailscaled.go
index fc2092477..1ff068b97 100644
--- a/cmd/containerboot/tailscaled.go
+++ b/cmd/containerboot/tailscaled.go
@@ -42,14 +42,14 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
 	log.Printf("Waiting for tailscaled socket")
 	for {
 		if ctx.Err() != nil {
-			log.Fatalf("Timed out waiting for tailscaled socket")
+			return nil, nil, errors.New("timed out waiting for tailscaled socket")
		}
 		_, err := os.Stat(cfg.Socket)
 		if errors.Is(err, fs.ErrNotExist) {
 			time.Sleep(100 * time.Millisecond)
 			continue
 		} else if err != nil {
-			log.Fatalf("Waiting for tailscaled socket: %v", err)
+			return nil, nil, fmt.Errorf("error waiting for tailscaled socket: %w", err)
 		}
 		break
 	}
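The consistency check that the patch adds in kube.go boils down to an all-or-nothing rule over the state Secret's keys: either tailscaled has written every expected key, or it has written none of them, and anything in between means a write is still in flight. Below is a minimal, standalone sketch of that rule, using the literal key names that appear in the test data (`_machinekey`, `_profiles`, `_current-profile`, `profile-*`) rather than the `ipn` package constants; it is an illustration of the shutdown check, not the implementation used in the patch.

```go
package main

import (
	"fmt"
	"strings"
)

// hasConsistentState mirrors the rule from the patch: the full set of state
// keys is present, or none of them are.
func hasConsistentState(d map[string][]byte) bool {
	_, hasCurrent := d["_current-profile"]
	_, hasKnown := d["_profiles"]
	_, hasMachine := d["_machinekey"]

	hasProfile := false
	for k := range d {
		if strings.HasPrefix(k, "profile-") {
			if hasProfile {
				return false // more than one profile is unexpected
			}
			hasProfile = true
		}
	}

	return (hasCurrent && hasKnown && hasMachine && hasProfile) ||
		(!hasCurrent && !hasKnown && !hasMachine && !hasProfile)
}

func main() {
	partial := map[string][]byte{
		"_machinekey":  []byte("foo"),
		"_profiles":    []byte("foo"),
		"profile-baff": []byte("foo"),
		// "_current-profile" is still missing: tailscaled is mid-write.
	}
	fmt.Println(hasConsistentState(partial)) // false: keep waiting

	partial["_current-profile"] = []byte("foo")
	fmt.Println(hasConsistentState(partial)) // true: safe to send SIGTERM
}
```

On shutdown, containerboot polls this predicate against the Secret (with a 25s budget out of the Pod's default 30s grace period) before forwarding SIGTERM to tailscaled, so a device that has just authenticated is not killed with half-written state.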