diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 780e9f2ee..fd2de6e8c 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -175,6 +175,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + W golang.org/x/exp/constraints from tailscale.com/util/winutil L golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index eff138ac6..c0b626f13 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -182,7 +182,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe + W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 30602d1de..f464d01d4 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -7,14 +7,17 @@ "errors" "fmt" "log" + "math" "os/exec" "os/user" + "reflect" "runtime" "strings" "syscall" "time" "unsafe" + "golang.org/x/exp/constraints" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" ) @@ -643,3 +646,141 @@ func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error return origin.originatingLogonSession, nil } + +// BufUnit is a type constraint for buffers passed into AllocateContiguousBuffer. +type BufUnit interface { + byte | uint16 +} + +// AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where +// some structs contain pointers that are expected to refer to memory within the +// same buffer containing the struct itself. T is the type that contains +// the pointers. values must contain the actual data that is to be copied +// into the buffer after T. AllocateContiguousBuffer returns a pointer to the +// struct, the total length of the buffer in bytes, and a slice containing +// each value within the buffer. The caller may use slcs to populate any +// pointers in t as needed. Each element of slcs corresponds to the element of +// values in the same position. +// +// It is the responsibility of the caller to ensure that any values expected +// to contain null-terminated strings are in fact null-terminated! +// +// AllocateContiguousBuffer panics if no values are passed in, as there are +// better alternatives for allocating a struct in that case. +func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) { + if len(values) == 0 { + panic("len(values) must be > 0") + } + + // Get the sizes of T and BU, then compute a preferred alignment for T. + tT := reflect.TypeFor[T]() + szT := tT.Size() + szBU := int(unsafe.Sizeof(BU(0))) + alignment := max(tT.Align(), szBU) + + // Our buffers for values will start at the next szBU boundary. + tLenBytes = alignUp(uint32(szT), szBU) + firstValueOffset := tLenBytes + + // Accumulate the length of each value into tLenBytes + for _, v := range values { + tLenBytes += uint32(len(v) * szBU) + } + + // Now that we know the final length, align up to our preferred boundary. + tLenBytes = alignUp(tLenBytes, alignment) + + // Allocate the buffer. We choose a type for the slice that is appropriate + // for the desired alignment. Note that we do not have a strict requirement + // that T contain pointer fields; we could just be appending more data + // within the same buffer. + bufLen := tLenBytes / uint32(alignment) + var pt unsafe.Pointer + switch alignment { + case 1: + pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen))) + case 2: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen))) + case 4: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen))) + case 8: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen))) + default: + panic(fmt.Sprintf("bad alignment %d", alignment)) + } + + t = (*T)(pt) + slcs = make([][]BU, 0, len(values)) + + // Use the limits of the buffer area after t to construct a slice representing the remaining buffer. + firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset)) + buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU)) + + // Copy each value into the buffer and record a slice describing each value's limits into slcs. + var index int + for _, v := range values { + if len(v) == 0 { + // We allow zero-length values; we simply append a nil slice. + slcs = append(slcs, nil) + continue + } + valueSlice := buf[index : index+len(v)] + copy(valueSlice, v) + slcs = append(slcs, valueSlice) + index += len(v) + } + + return t, tLenBytes, slcs +} + +// alignment must be a power of 2 +func alignUp[V constraints.Integer](v V, alignment int) V { + return v + ((-v) & (V(alignment) - 1)) +} + +// NTStr is a type constraint requiring the type to be either a +// windows.NTString or a windows.NTUnicodeString. +type NTStr interface { + windows.NTString | windows.NTUnicodeString +} + +// SetNTString sets the value of nts in-place to point to the string contained +// within buf. A nul terminator is optional in buf. +func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) { + isEmpty := len(buf) == 0 + codeUnitSize := uint16(unsafe.Sizeof(BU(0))) + lenBytes := len(buf) * int(codeUnitSize) + if lenBytes > math.MaxUint16 { + panic("buffer length must fit into uint16") + } + lenBytes16 := uint16(lenBytes) + + switch p := any(nts).(type) { + case *windows.NTString: + if isEmpty { + *p = windows.NTString{} + break + } + p.Buffer = unsafe.SliceData(any(buf).([]byte)) + p.MaximumLength = lenBytes16 + p.Length = lenBytes16 + // account for nul terminator when present + if buf[len(buf)-1] == 0 { + p.Length -= codeUnitSize + } + case *windows.NTUnicodeString: + if isEmpty { + *p = windows.NTUnicodeString{} + break + } + p.Buffer = unsafe.SliceData(any(buf).([]uint16)) + p.MaximumLength = lenBytes16 + p.Length = lenBytes16 + // account for nul terminator when present + if buf[len(buf)-1] == 0 { + p.Length -= codeUnitSize + } + default: + panic("unknown type") + } +} diff --git a/util/winutil/winutil_windows_test.go b/util/winutil/winutil_windows_test.go index bf22d26ca..d437ffa38 100644 --- a/util/winutil/winutil_windows_test.go +++ b/util/winutil/winutil_windows_test.go @@ -4,9 +4,13 @@ package winutil import ( + "reflect" "testing" + "unsafe" ) +//lint:file-ignore U1000 Fields are unused but necessary for tests. + const ( localSystemSID = "S-1-5-18" networkSID = "S-1-5-2" @@ -28,3 +32,103 @@ func TestLookupPseudoUser(t *testing.T) { t.Errorf("LookupPseudoUser(%q) unexpectedly succeeded", networkSID) } } + +type testType interface { + byte | uint16 | uint32 | uint64 +} + +type noPointers[T testType] struct { + foo byte + bar T + baz bool +} + +type hasPointer struct { + foo byte + bar uint32 + s1 *struct{} + baz byte +} + +func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, ptLen uint32, slcs [][]BU) { + szBU := int(unsafe.Sizeof(BU(0))) + expectedAlign := max(reflect.TypeFor[T]().Align(), szBU) + // Check that pointer is aligned + if rem := uintptr(unsafe.Pointer(pt)) % uintptr(expectedAlign); rem != 0 { + t.Errorf("pointer alignment got %d, want 0", rem) + } + // Check that alloc length is aligned + if rem := int(ptLen) % expectedAlign; rem != 0 { + t.Errorf("allocation length alignment got %d, want 0", rem) + } + expectedLen := int(unsafe.Sizeof(*pt)) + expectedLen = alignUp(expectedLen, szBU) + expectedLen += len(extra) * szBU + expectedLen = alignUp(expectedLen, expectedAlign) + if gotLen := int(ptLen); gotLen != expectedLen { + t.Errorf("allocation length got %d, want %d", gotLen, expectedLen) + } + if l := len(slcs); l != 1 { + t.Errorf("len(slcs) got %d, want 1", l) + } + if len(extra) == 0 && slcs[0] != nil { + t.Error("slcs[0] got non-nil, want nil") + } + if len(extra) != len(slcs[0]) { + t.Errorf("len(slcs[0]) got %d, want %d", len(slcs[0]), len(extra)) + } else if rem := uintptr(unsafe.Pointer(unsafe.SliceData(slcs[0]))) % uintptr(szBU); rem != 0 { + t.Errorf("additional data alignment got %d, want 0", rem) + } +} + +func TestAllocateContiguousBuffer(t *testing.T) { + t.Run("NoValues", testNoValues) + t.Run("NoPointers", testNoPointers) + t.Run("HasPointer", testHasPointer) +} + +func testNoValues(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic but didn't get one") + } + }() + + AllocateContiguousBuffer[hasPointer, byte]() +} + +const maxTestBufLen = 8 + +func testNoPointers(t *testing.T) { + buf8 := make([]byte, maxTestBufLen) + buf16 := make([]uint16, maxTestBufLen) + for i := range maxTestBufLen { + s8, sl, slcs8 := AllocateContiguousBuffer[noPointers[byte]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s8, sl, slcs8) + s16, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint16]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s16, sl, slcs8) + s32, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint32]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s32, sl, slcs8) + s64, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint64]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s64, sl, slcs8) + s8, sl, slcs16 := AllocateContiguousBuffer[noPointers[byte]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s8, sl, slcs16) + s16, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint16]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s16, sl, slcs16) + s32, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint32]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s32, sl, slcs16) + s64, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint64]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s64, sl, slcs16) + } +} + +func testHasPointer(t *testing.T) { + buf8 := make([]byte, maxTestBufLen) + buf16 := make([]uint16, maxTestBufLen) + for i := range maxTestBufLen { + s, sl, slcs8 := AllocateContiguousBuffer[hasPointer](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s, sl, slcs8) + s, sl, slcs16 := AllocateContiguousBuffer[hasPointer](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s, sl, slcs16) + } +}