gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/tpuproxy/tpu.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package tpuproxy implements proxying for TPU devices. 16 package tpuproxy 17 18 import ( 19 "fmt" 20 21 "golang.org/x/sys/unix" 22 "gvisor.dev/gvisor/pkg/abi/linux" 23 "gvisor.dev/gvisor/pkg/context" 24 "gvisor.dev/gvisor/pkg/errors/linuxerr" 25 "gvisor.dev/gvisor/pkg/fdnotifier" 26 "gvisor.dev/gvisor/pkg/hostarch" 27 "gvisor.dev/gvisor/pkg/marshal/primitive" 28 "gvisor.dev/gvisor/pkg/sentry/arch" 29 "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" 30 "gvisor.dev/gvisor/pkg/sentry/kernel" 31 "gvisor.dev/gvisor/pkg/sentry/mm" 32 "gvisor.dev/gvisor/pkg/sentry/vfs" 33 "gvisor.dev/gvisor/pkg/usermem" 34 "gvisor.dev/gvisor/pkg/waiter" 35 ) 36 37 const ( 38 // A value of -1 can be used to either de-assign interrupts if already 39 // assigned or skip un-assigned interrupts. 40 disableInterrupt = -1 41 ) 42 43 var ( 44 // vfioDeviceInfoFlags contains all available flags for 45 // IOCTL command VFIO_DEVICE_GET_INFO. 46 vfioDeviceInfoFlags uint32 = linux.VFIO_DEVICE_FLAGS_RESET | linux.VFIO_DEVICE_FLAGS_PCI | 47 linux.VFIO_DEVICE_FLAGS_PLATFORM | linux.VFIO_DEVICE_FLAGS_AMBA | 48 linux.VFIO_DEVICE_FLAGS_CCW | linux.VFIO_DEVICE_FLAGS_AP | linux.VFIO_DEVICE_FLAGS_FSL_MC | 49 linux.VFIO_DEVICE_FLAGS_CAPS | linux.VFIO_DEVICE_FLAGS_CDX 50 // vfioIrqSetFlags includes all available flags for IOCTL comamnd VFIO_DEVICE_SET_IRQS 51 vfioIrqSetFlags uint32 = linux.VFIO_IRQ_SET_DATA_TYPE_MASK | linux.VFIO_IRQ_SET_ACTION_TYPE_MASK 52 ) 53 54 // tpuFD implements vfs.FileDescriptionImpl for /dev/vfio/[0-9]+ 55 // 56 // tpuFD is not savable until TPU save/restore is needed. 57 type tpuFD struct { 58 vfsfd vfs.FileDescription 59 vfs.FileDescriptionDefaultImpl 60 vfs.DentryMetadataFileDescriptionImpl 61 vfs.NoLockFD 62 63 hostFD int32 64 device *tpuDevice 65 queue waiter.Queue 66 memmapFile tpuFDMemmapFile 67 } 68 69 // Release implements vfs.FileDescriptionImpl.Release. 70 func (fd *tpuFD) Release(context.Context) { 71 fdnotifier.RemoveFD(fd.hostFD) 72 fd.queue.Notify(waiter.EventHUp) 73 unix.Close(int(fd.hostFD)) 74 } 75 76 // EventRegister implements waiter.Waitable.EventRegister. 77 func (fd *tpuFD) EventRegister(e *waiter.Entry) error { 78 fd.queue.EventRegister(e) 79 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 80 fd.queue.EventUnregister(e) 81 return err 82 } 83 return nil 84 } 85 86 // EventUnregister implements waiter.Waitable.EventUnregister. 87 func (fd *tpuFD) EventUnregister(e *waiter.Entry) { 88 fd.queue.EventUnregister(e) 89 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 90 panic(fmt.Sprint("UpdateFD:", err)) 91 } 92 } 93 94 // Readiness implements waiter.Waitable.Readiness. 95 func (fd *tpuFD) Readiness(mask waiter.EventMask) waiter.EventMask { 96 return fdnotifier.NonBlockingPoll(fd.hostFD, mask) 97 } 98 99 // Epollable implements vfs.FileDescriptionImpl.Epollable. 100 func (fd *tpuFD) Epollable() bool { 101 return true 102 } 103 104 // Ioctl implements vfs.FileDescriptionImpl.Ioctl. 105 func (fd *tpuFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { 106 cmd := args[1].Uint() 107 108 t := kernel.TaskFromContext(ctx) 109 if t == nil { 110 panic("Ioctl should be called from a task context") 111 } 112 switch cmd { 113 case linux.VFIO_GROUP_SET_CONTAINER: 114 return fd.setContainer(ctx, t, args[2].Pointer()) 115 case linux.VFIO_GROUP_GET_DEVICE_FD: 116 ret, cleanup, err := fd.getPciDeviceFd(t, args[2].Pointer()) 117 defer cleanup() 118 return ret, err 119 } 120 return 0, linuxerr.ENOSYS 121 } 122 123 func (fd *tpuFD) setContainer(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) { 124 var vfioContainerFD int32 125 if _, err := primitive.CopyInt32In(t, arg, &vfioContainerFD); err != nil { 126 return 0, err 127 } 128 vfioContainerFile, _ := t.FDTable().Get(vfioContainerFD) 129 if vfioContainerFile == nil { 130 return 0, linuxerr.EBADF 131 } 132 defer vfioContainerFile.DecRef(ctx) 133 vfioContainer, ok := vfioContainerFile.Impl().(*vfioFD) 134 if !ok { 135 return 0, linuxerr.EINVAL 136 } 137 return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_GROUP_SET_CONTAINER, &vfioContainer.hostFD) 138 } 139 140 // It will be the caller's responsibility to call the returned cleanup function. 141 func (fd *tpuFD) getPciDeviceFd(t *kernel.Task, arg hostarch.Addr) (uintptr, func(), error) { 142 pciAddress, err := t.CopyInString(arg, hostarch.PageSize) 143 if err != nil { 144 return 0, func() {}, err 145 } 146 // Build a NUL-terminated slice of bytes containing the PCI address. 147 pciAddressBytes, err := unix.ByteSliceFromString(pciAddress) 148 if err != nil { 149 return 0, func() {}, err 150 } 151 // Pass the address of the PCI address' first byte which can be 152 // recognized by the IOCTL syscall. 153 hostFD, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_GROUP_GET_DEVICE_FD, &pciAddressBytes[0]) 154 if err != nil { 155 return 0, func() {}, err 156 } 157 pciDevFD := &pciDeviceFD{ 158 hostFD: int32(hostFD), 159 } 160 cleanup := func() { 161 unix.Close(int(hostFD)) 162 } 163 // See drivers/vfio/group.c:vfio_device_open_file(), the PCI device 164 // is accessed for both reads and writes. 165 vd := t.Kernel().VFS().NewAnonVirtualDentry("[vfio-device]") 166 if err := pciDevFD.vfsfd.Init(pciDevFD, linux.O_RDWR, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{ 167 UseDentryMetadata: true, 168 }); err != nil { 169 return 0, cleanup, err 170 } 171 if err := fdnotifier.AddFD(int32(hostFD), &fd.queue); err != nil { 172 return 0, cleanup, err 173 } 174 newFD, err := t.NewFDFrom(0, &pciDevFD.vfsfd, kernel.FDFlags{}) 175 if err != nil { 176 return 0, cleanup, err 177 } 178 // Initialize a mapping that is backed by a host FD. 179 pciDevFD.memmapFile.fd = pciDevFD 180 return uintptr(newFD), func() {}, nil 181 } 182 183 // pciDeviceFD implements vfs.FileDescriptionImpl for TPU's PCI device. 184 type pciDeviceFD struct { 185 vfsfd vfs.FileDescription 186 vfs.FileDescriptionDefaultImpl 187 vfs.DentryMetadataFileDescriptionImpl 188 vfs.NoLockFD 189 190 hostFD int32 191 queue waiter.Queue 192 memmapFile pciDeviceFdMemmapFile 193 } 194 195 // Release implements vfs.FileDescriptionImpl.Release. 196 func (fd *pciDeviceFD) Release(context.Context) { 197 fdnotifier.RemoveFD(fd.hostFD) 198 fd.queue.Notify(waiter.EventHUp) 199 unix.Close(int(fd.hostFD)) 200 } 201 202 // EventRegister implements waiter.Waitable.EventRegister. 203 func (fd *pciDeviceFD) EventRegister(e *waiter.Entry) error { 204 fd.queue.EventRegister(e) 205 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 206 fd.queue.EventUnregister(e) 207 return err 208 } 209 return nil 210 } 211 212 // EventUnregister implements waiter.Waitable.EventUnregister. 213 func (fd *pciDeviceFD) EventUnregister(e *waiter.Entry) { 214 fd.queue.EventUnregister(e) 215 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 216 panic(fmt.Sprint("UpdateFD:", err)) 217 } 218 } 219 220 // Readiness implements waiter.Waitable.Readiness. 221 func (fd *pciDeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { 222 return fdnotifier.NonBlockingPoll(fd.hostFD, mask) 223 } 224 225 // Epollable implements vfs.FileDescriptionImpl.Epollable. 226 func (fd *pciDeviceFD) Epollable() bool { 227 return true 228 } 229 230 // Ioctl implements vfs.FileDescriptionImpl.Ioctl. 231 func (fd *pciDeviceFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { 232 cmd := args[1].Uint() 233 234 t := kernel.TaskFromContext(ctx) 235 if t == nil { 236 panic("Ioctl should be called from a task context") 237 } 238 switch cmd { 239 // TODO(b/299303493): consider making VFIO's GET_INFO commands more generic. 240 case linux.VFIO_DEVICE_GET_INFO: 241 return fd.vfioDeviceInfo(ctx, t, args[2].Pointer()) 242 case linux.VFIO_DEVICE_GET_REGION_INFO: 243 return fd.vfioRegionInfo(ctx, t, args[2].Pointer()) 244 case linux.VFIO_DEVICE_GET_IRQ_INFO: 245 return fd.vfioIrqInfo(ctx, t, args[2].Pointer()) 246 case linux.VFIO_DEVICE_SET_IRQS: 247 return fd.vfioSetIrqs(ctx, t, args[2].Pointer()) 248 case linux.VFIO_DEVICE_RESET: 249 // VFIO_DEVICE_RESET is just a simple IOCTL command that carries no data. 250 return IOCTLInvoke[uint32, uintptr](fd.hostFD, linux.VFIO_DEVICE_RESET, 0) 251 } 252 return 0, linuxerr.ENOSYS 253 } 254 255 // Retrieve the host TPU device's region information, which could be used by 256 // vfio driver to setup mappings. 257 func (fd *pciDeviceFD) vfioRegionInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) { 258 var regionInfo linux.VFIORegionInfo 259 if _, err := regionInfo.CopyIn(t, arg); err != nil { 260 return 0, err 261 } 262 if regionInfo.Argsz == 0 { 263 return 0, linuxerr.EINVAL 264 } 265 ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_REGION_INFO, ®ionInfo) 266 if err != nil { 267 return 0, err 268 } 269 if _, err := regionInfo.CopyOut(t, arg); err != nil { 270 return 0, err 271 } 272 return ret, nil 273 } 274 275 // Retrieve the host TPU device's information. 276 func (fd *pciDeviceFD) vfioDeviceInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) { 277 var deviceInfo linux.VFIODeviceInfo 278 if _, err := deviceInfo.CopyIn(t, arg); err != nil { 279 return 0, err 280 } 281 // Callers must set VFIODeviceInfo.Argsz. 282 if deviceInfo.Argsz == 0 { 283 return 0, linuxerr.EINVAL 284 } 285 if deviceInfo.Flags&^vfioDeviceInfoFlags != 0 { 286 return 0, linuxerr.EINVAL 287 } 288 ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_INFO, &deviceInfo) 289 if err != nil { 290 return 0, err 291 } 292 // gVisor is not supposed to change any device information that is 293 // returned from the host since gVisor doesn't own the device. 294 // Passing the device info back to the caller will be just fine. 295 if _, err := deviceInfo.CopyOut(t, arg); err != nil { 296 return 0, err 297 } 298 return ret, nil 299 } 300 301 // Retrieve the device's interrupt information. 302 func (fd *pciDeviceFD) vfioIrqInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) { 303 var irqInfo linux.VFIOIrqInfo 304 if _, err := irqInfo.CopyIn(t, arg); err != nil { 305 return 0, err 306 } 307 // Callers must set the payload's size. 308 if irqInfo.Argsz == 0 { 309 return 0, linuxerr.EINVAL 310 } 311 ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_IRQ_INFO, &irqInfo) 312 if err != nil { 313 return 0, err 314 } 315 if _, err := irqInfo.CopyOut(t, arg); err != nil { 316 return 0, err 317 } 318 return ret, nil 319 } 320 321 func (fd *pciDeviceFD) vfioSetIrqs(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) { 322 var irqSet linux.VFIOIrqSet 323 if _, err := irqSet.CopyIn(t, arg); err != nil { 324 return 0, err 325 } 326 // Callers must set the payload's size. 327 if irqSet.Argsz == 0 { 328 return 0, linuxerr.EINVAL 329 } 330 // Invalidate unknown flags. 331 if irqSet.Flags&^vfioIrqSetFlags != 0 { 332 return 0, linuxerr.EINVAL 333 } 334 // See drivers/vfio/vfio_main.c:vfio_set_irqs_validate_and_prepare, 335 // VFIO uses the data type at the request's flags to determine 336 // the memory layout of data field. 337 // 338 // The struct vfio_irq_set includes a flexible array member, it 339 // allocates an array for a continuous trunk of memory to back 340 // a vfio_irq_set object. In order to mirror that behavior, gVisor 341 // would allocate a slice to store the underlying bytes 342 // and pass that through to its host. 343 switch irqSet.Flags & linux.VFIO_IRQ_SET_DATA_TYPE_MASK { 344 // VFIO_IRQ_SET_DATA_NONE indicates there is no data field for 345 // the IOCTL command. 346 // It works with VFIO_IRQ_SET_ACTION_MASK, VFIO_IRQ_SET_ACTION_UNMASK, 347 // or VFIO_IRQ_SET_ACTION_TRIGGER to mask an interrupt, unmask an 348 // interrupt, and trigger an interrupt unconditionally. 349 case linux.VFIO_IRQ_SET_DATA_NONE: 350 // When there is no data, passing through the given payload 351 // works just fine. 352 return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &irqSet) 353 // VFIO_IRQ_SET_DATA_BOOL indicates that the data field is an array of uint8. 354 // The action will be performed if the corresponding boolean is true. 355 case linux.VFIO_IRQ_SET_DATA_BOOL: 356 payloadSize := uint32(irqSet.Size()) + irqSet.Count 357 payload := make([]uint8, payloadSize) 358 if _, err := primitive.CopyUint8SliceIn(t, arg, payload); err != nil { 359 return 0, err 360 } 361 return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &payload[0]) 362 // VFIO_IRQ_SET_DATA_EVENTFD indicates that the data field is an array 363 // of int32 (or event file descriptors). These descriptors will be 364 // signalled when an action in the flags happens. 365 case linux.VFIO_IRQ_SET_DATA_EVENTFD: 366 payloadSize := uint32(irqSet.Size())/4 + irqSet.Count 367 payload := make([]int32, payloadSize) 368 if _, err := primitive.CopyInt32SliceIn(t, arg, payload); err != nil { 369 return 0, err 370 } 371 // Transform the input FDs to host FDs. 372 for i := 0; i < int(irqSet.Count); i++ { 373 index := len(payload) - 1 - i 374 fd := payload[index] 375 // Skip non-event FD. 376 if fd == disableInterrupt { 377 continue 378 } 379 eventFileGeneric, _ := t.FDTable().Get(fd) 380 if eventFileGeneric == nil { 381 return 0, linuxerr.EBADF 382 } 383 defer eventFileGeneric.DecRef(ctx) 384 eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription) 385 if !ok { 386 return 0, linuxerr.EINVAL 387 } 388 eventfd, err := eventFile.HostFD() 389 if err != nil { 390 return 0, err 391 } 392 payload[index] = int32(eventfd) 393 } 394 return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &payload[0]) 395 } 396 // No data type is specified or multiple data types are specified. 397 return 0, linuxerr.EINVAL 398 } 399 400 // PRead implements vfs.FileDescriptionImpl.PRead. 401 func (fd *pciDeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { 402 if offset < 0 { 403 return 0, linuxerr.EINVAL 404 } 405 buf := make([]byte, dst.NumBytes()) 406 _, err := unix.Pread(int(fd.hostFD), buf, offset) 407 if err != nil { 408 return 0, err 409 } 410 n, err := dst.CopyOut(ctx, buf) 411 return int64(n), err 412 } 413 414 // PWrite implements vfs.FileDescriptionImpl.PWrite. 415 func (fd *pciDeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { 416 if offset < 0 { 417 return 0, linuxerr.EINVAL 418 } 419 buf := make([]byte, src.NumBytes()) 420 _, err := src.CopyIn(ctx, buf) 421 if err != nil { 422 return 0, err 423 } 424 n, err := unix.Pwrite(int(fd.hostFD), buf, offset) 425 return int64(n), err 426 } 427 428 // DevAddrSet tracks device address ranges that have been mapped. 429 type devAddrSetFuncs struct{} 430 431 func (devAddrSetFuncs) MinKey() uint64 { 432 return 0 433 } 434 435 func (devAddrSetFuncs) MaxKey() uint64 { 436 return ^uint64(0) 437 } 438 439 func (devAddrSetFuncs) ClearValue(val *mm.PinnedRange) { 440 *val = mm.PinnedRange{} 441 } 442 443 func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 mm.PinnedRange, r2 DevAddrRange, v2 mm.PinnedRange) (mm.PinnedRange, bool) { 444 // Do we have the same backing file? 445 if v1.File != v2.File { 446 return mm.PinnedRange{}, false 447 } 448 449 // Do we have contiguous offsets in the backing file? 450 if v1.Offset+uint64(v1.Source.Length()) != v2.Offset { 451 return mm.PinnedRange{}, false 452 } 453 454 // Are the virtual addresses contiguous? 455 // 456 // This check isn't strictly needed because 'mm.PinnedRange.Source' 457 // is only used to track the size of the pinned region (this is 458 // because the virtual address range can be unmapped or remapped 459 // elsewhere). Regardless we require this for simplicity. 460 if v1.Source.End != v2.Source.Start { 461 return mm.PinnedRange{}, false 462 } 463 464 // Extend v1 to account for the adjacent PinnedRange. 465 v1.Source.End = v2.Source.End 466 return v1, true 467 } 468 469 func (devAddrSetFuncs) Split(r DevAddrRange, val mm.PinnedRange, split uint64) (mm.PinnedRange, mm.PinnedRange) { 470 n := split - r.Start 471 472 left := val 473 left.Source.End = left.Source.Start + hostarch.Addr(n) 474 475 right := val 476 right.Source.Start += hostarch.Addr(n) 477 right.Offset += n 478 479 return left, right 480 }