gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/tpuproxy/tpu.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tpuproxy implements proxying for TPU devices.
    16  package tpuproxy
    17  
    18  import (
    19  	"fmt"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"gvisor.dev/gvisor/pkg/abi/linux"
    23  	"gvisor.dev/gvisor/pkg/context"
    24  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    25  	"gvisor.dev/gvisor/pkg/fdnotifier"
    26  	"gvisor.dev/gvisor/pkg/hostarch"
    27  	"gvisor.dev/gvisor/pkg/marshal/primitive"
    28  	"gvisor.dev/gvisor/pkg/sentry/arch"
    29  	"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
    30  	"gvisor.dev/gvisor/pkg/sentry/kernel"
    31  	"gvisor.dev/gvisor/pkg/sentry/mm"
    32  	"gvisor.dev/gvisor/pkg/sentry/vfs"
    33  	"gvisor.dev/gvisor/pkg/usermem"
    34  	"gvisor.dev/gvisor/pkg/waiter"
    35  )
    36  
    37  const (
    38  	// A value of -1 can be used to either de-assign interrupts if already
    39  	// assigned or skip un-assigned interrupts.
    40  	disableInterrupt = -1
    41  )
    42  
    43  var (
    44  	// vfioDeviceInfoFlags contains all available flags for
    45  	// IOCTL command VFIO_DEVICE_GET_INFO.
    46  	vfioDeviceInfoFlags uint32 = linux.VFIO_DEVICE_FLAGS_RESET | linux.VFIO_DEVICE_FLAGS_PCI |
    47  		linux.VFIO_DEVICE_FLAGS_PLATFORM | linux.VFIO_DEVICE_FLAGS_AMBA |
    48  		linux.VFIO_DEVICE_FLAGS_CCW | linux.VFIO_DEVICE_FLAGS_AP | linux.VFIO_DEVICE_FLAGS_FSL_MC |
    49  		linux.VFIO_DEVICE_FLAGS_CAPS | linux.VFIO_DEVICE_FLAGS_CDX
    50  	// vfioIrqSetFlags includes all available flags for IOCTL comamnd VFIO_DEVICE_SET_IRQS
    51  	vfioIrqSetFlags uint32 = linux.VFIO_IRQ_SET_DATA_TYPE_MASK | linux.VFIO_IRQ_SET_ACTION_TYPE_MASK
    52  )
    53  
// tpuFD implements vfs.FileDescriptionImpl for /dev/vfio/[0-9]+, forwarding
// the VFIO group ioctls it supports (VFIO_GROUP_SET_CONTAINER and
// VFIO_GROUP_GET_DEVICE_FD) to a host file descriptor.
//
// tpuFD is not savable until TPU save/restore is needed.
type tpuFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host file descriptor that ioctls are forwarded to and
	// whose readiness is polled; Release closes it.
	hostFD     int32
	// NOTE(review): device is not read in this file; presumably the
	// tpuDevice that opened this FD — confirm at the construction site.
	device     *tpuDevice
	// queue holds waiters registered through EventRegister; Release
	// notifies it with EventHUp.
	queue      waiter.Queue
	// NOTE(review): memmapFile is not used in this file; presumably it backs
	// mmap of this FD — confirm in the memmap implementation.
	memmapFile tpuFDMemmapFile
}
    68  
    69  // Release implements vfs.FileDescriptionImpl.Release.
    70  func (fd *tpuFD) Release(context.Context) {
    71  	fdnotifier.RemoveFD(fd.hostFD)
    72  	fd.queue.Notify(waiter.EventHUp)
    73  	unix.Close(int(fd.hostFD))
    74  }
    75  
    76  // EventRegister implements waiter.Waitable.EventRegister.
    77  func (fd *tpuFD) EventRegister(e *waiter.Entry) error {
    78  	fd.queue.EventRegister(e)
    79  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    80  		fd.queue.EventUnregister(e)
    81  		return err
    82  	}
    83  	return nil
    84  }
    85  
    86  // EventUnregister implements waiter.Waitable.EventUnregister.
    87  func (fd *tpuFD) EventUnregister(e *waiter.Entry) {
    88  	fd.queue.EventUnregister(e)
    89  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    90  		panic(fmt.Sprint("UpdateFD:", err))
    91  	}
    92  }
    93  
    94  // Readiness implements waiter.Waitable.Readiness.
    95  func (fd *tpuFD) Readiness(mask waiter.EventMask) waiter.EventMask {
    96  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
    97  }
    98  
    99  // Epollable implements vfs.FileDescriptionImpl.Epollable.
   100  func (fd *tpuFD) Epollable() bool {
   101  	return true
   102  }
   103  
   104  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
   105  func (fd *tpuFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   106  	cmd := args[1].Uint()
   107  
   108  	t := kernel.TaskFromContext(ctx)
   109  	if t == nil {
   110  		panic("Ioctl should be called from a task context")
   111  	}
   112  	switch cmd {
   113  	case linux.VFIO_GROUP_SET_CONTAINER:
   114  		return fd.setContainer(ctx, t, args[2].Pointer())
   115  	case linux.VFIO_GROUP_GET_DEVICE_FD:
   116  		ret, cleanup, err := fd.getPciDeviceFd(t, args[2].Pointer())
   117  		defer cleanup()
   118  		return ret, err
   119  	}
   120  	return 0, linuxerr.ENOSYS
   121  }
   122  
   123  func (fd *tpuFD) setContainer(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   124  	var vfioContainerFD int32
   125  	if _, err := primitive.CopyInt32In(t, arg, &vfioContainerFD); err != nil {
   126  		return 0, err
   127  	}
   128  	vfioContainerFile, _ := t.FDTable().Get(vfioContainerFD)
   129  	if vfioContainerFile == nil {
   130  		return 0, linuxerr.EBADF
   131  	}
   132  	defer vfioContainerFile.DecRef(ctx)
   133  	vfioContainer, ok := vfioContainerFile.Impl().(*vfioFD)
   134  	if !ok {
   135  		return 0, linuxerr.EINVAL
   136  	}
   137  	return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_GROUP_SET_CONTAINER, &vfioContainer.hostFD)
   138  }
   139  
   140  // It will be the caller's responsibility to call the returned cleanup function.
   141  func (fd *tpuFD) getPciDeviceFd(t *kernel.Task, arg hostarch.Addr) (uintptr, func(), error) {
   142  	pciAddress, err := t.CopyInString(arg, hostarch.PageSize)
   143  	if err != nil {
   144  		return 0, func() {}, err
   145  	}
   146  	// Build a NUL-terminated slice of bytes containing the PCI address.
   147  	pciAddressBytes, err := unix.ByteSliceFromString(pciAddress)
   148  	if err != nil {
   149  		return 0, func() {}, err
   150  	}
   151  	// Pass the address of the PCI address' first byte which can be
   152  	// recognized by the IOCTL syscall.
   153  	hostFD, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_GROUP_GET_DEVICE_FD, &pciAddressBytes[0])
   154  	if err != nil {
   155  		return 0, func() {}, err
   156  	}
   157  	pciDevFD := &pciDeviceFD{
   158  		hostFD: int32(hostFD),
   159  	}
   160  	cleanup := func() {
   161  		unix.Close(int(hostFD))
   162  	}
   163  	// See drivers/vfio/group.c:vfio_device_open_file(), the PCI device
   164  	// is accessed for both reads and writes.
   165  	vd := t.Kernel().VFS().NewAnonVirtualDentry("[vfio-device]")
   166  	if err := pciDevFD.vfsfd.Init(pciDevFD, linux.O_RDWR, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
   167  		UseDentryMetadata: true,
   168  	}); err != nil {
   169  		return 0, cleanup, err
   170  	}
   171  	if err := fdnotifier.AddFD(int32(hostFD), &fd.queue); err != nil {
   172  		return 0, cleanup, err
   173  	}
   174  	newFD, err := t.NewFDFrom(0, &pciDevFD.vfsfd, kernel.FDFlags{})
   175  	if err != nil {
   176  		return 0, cleanup, err
   177  	}
   178  	// Initialize a mapping that is backed by a host FD.
   179  	pciDevFD.memmapFile.fd = pciDevFD
   180  	return uintptr(newFD), func() {}, nil
   181  }
   182  
// pciDeviceFD implements vfs.FileDescriptionImpl for TPU's PCI device.
//
// Instances are created by tpuFD.getPciDeviceFd in response to
// VFIO_GROUP_GET_DEVICE_FD. Supported ioctls, reads, and writes are
// forwarded to the host device FD.
type pciDeviceFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host VFIO device FD that requests are forwarded to;
	// Release closes it.
	hostFD     int32
	// queue holds waiters registered through EventRegister; Release
	// notifies it with EventHUp.
	queue      waiter.Queue
	// memmapFile is pointed back at this FD by getPciDeviceFd.
	// NOTE(review): presumably it backs mmap of device regions — confirm in
	// the memmap implementation.
	memmapFile pciDeviceFdMemmapFile
}
   194  
   195  // Release implements vfs.FileDescriptionImpl.Release.
   196  func (fd *pciDeviceFD) Release(context.Context) {
   197  	fdnotifier.RemoveFD(fd.hostFD)
   198  	fd.queue.Notify(waiter.EventHUp)
   199  	unix.Close(int(fd.hostFD))
   200  }
   201  
   202  // EventRegister implements waiter.Waitable.EventRegister.
   203  func (fd *pciDeviceFD) EventRegister(e *waiter.Entry) error {
   204  	fd.queue.EventRegister(e)
   205  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
   206  		fd.queue.EventUnregister(e)
   207  		return err
   208  	}
   209  	return nil
   210  }
   211  
   212  // EventUnregister implements waiter.Waitable.EventUnregister.
   213  func (fd *pciDeviceFD) EventUnregister(e *waiter.Entry) {
   214  	fd.queue.EventUnregister(e)
   215  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
   216  		panic(fmt.Sprint("UpdateFD:", err))
   217  	}
   218  }
   219  
   220  // Readiness implements waiter.Waitable.Readiness.
   221  func (fd *pciDeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
   222  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
   223  }
   224  
   225  // Epollable implements vfs.FileDescriptionImpl.Epollable.
   226  func (fd *pciDeviceFD) Epollable() bool {
   227  	return true
   228  }
   229  
   230  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
   231  func (fd *pciDeviceFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   232  	cmd := args[1].Uint()
   233  
   234  	t := kernel.TaskFromContext(ctx)
   235  	if t == nil {
   236  		panic("Ioctl should be called from a task context")
   237  	}
   238  	switch cmd {
   239  	// TODO(b/299303493): consider making VFIO's GET_INFO commands more generic.
   240  	case linux.VFIO_DEVICE_GET_INFO:
   241  		return fd.vfioDeviceInfo(ctx, t, args[2].Pointer())
   242  	case linux.VFIO_DEVICE_GET_REGION_INFO:
   243  		return fd.vfioRegionInfo(ctx, t, args[2].Pointer())
   244  	case linux.VFIO_DEVICE_GET_IRQ_INFO:
   245  		return fd.vfioIrqInfo(ctx, t, args[2].Pointer())
   246  	case linux.VFIO_DEVICE_SET_IRQS:
   247  		return fd.vfioSetIrqs(ctx, t, args[2].Pointer())
   248  	case linux.VFIO_DEVICE_RESET:
   249  		// VFIO_DEVICE_RESET is just a simple IOCTL command that carries no data.
   250  		return IOCTLInvoke[uint32, uintptr](fd.hostFD, linux.VFIO_DEVICE_RESET, 0)
   251  	}
   252  	return 0, linuxerr.ENOSYS
   253  }
   254  
   255  // Retrieve the host TPU device's region information, which could be used by
   256  // vfio driver to setup mappings.
   257  func (fd *pciDeviceFD) vfioRegionInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   258  	var regionInfo linux.VFIORegionInfo
   259  	if _, err := regionInfo.CopyIn(t, arg); err != nil {
   260  		return 0, err
   261  	}
   262  	if regionInfo.Argsz == 0 {
   263  		return 0, linuxerr.EINVAL
   264  	}
   265  	ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_REGION_INFO, &regionInfo)
   266  	if err != nil {
   267  		return 0, err
   268  	}
   269  	if _, err := regionInfo.CopyOut(t, arg); err != nil {
   270  		return 0, err
   271  	}
   272  	return ret, nil
   273  }
   274  
   275  // Retrieve the host TPU device's information.
   276  func (fd *pciDeviceFD) vfioDeviceInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   277  	var deviceInfo linux.VFIODeviceInfo
   278  	if _, err := deviceInfo.CopyIn(t, arg); err != nil {
   279  		return 0, err
   280  	}
   281  	// Callers must set VFIODeviceInfo.Argsz.
   282  	if deviceInfo.Argsz == 0 {
   283  		return 0, linuxerr.EINVAL
   284  	}
   285  	if deviceInfo.Flags&^vfioDeviceInfoFlags != 0 {
   286  		return 0, linuxerr.EINVAL
   287  	}
   288  	ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_INFO, &deviceInfo)
   289  	if err != nil {
   290  		return 0, err
   291  	}
   292  	// gVisor is not supposed to change any device information that is
   293  	// returned from the host since gVisor doesn't own the device.
   294  	// Passing the device info back to the caller will be just fine.
   295  	if _, err := deviceInfo.CopyOut(t, arg); err != nil {
   296  		return 0, err
   297  	}
   298  	return ret, nil
   299  }
   300  
   301  // Retrieve the device's interrupt information.
   302  func (fd *pciDeviceFD) vfioIrqInfo(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   303  	var irqInfo linux.VFIOIrqInfo
   304  	if _, err := irqInfo.CopyIn(t, arg); err != nil {
   305  		return 0, err
   306  	}
   307  	// Callers must set the payload's size.
   308  	if irqInfo.Argsz == 0 {
   309  		return 0, linuxerr.EINVAL
   310  	}
   311  	ret, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_GET_IRQ_INFO, &irqInfo)
   312  	if err != nil {
   313  		return 0, err
   314  	}
   315  	if _, err := irqInfo.CopyOut(t, arg); err != nil {
   316  		return 0, err
   317  	}
   318  	return ret, nil
   319  }
   320  
   321  func (fd *pciDeviceFD) vfioSetIrqs(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   322  	var irqSet linux.VFIOIrqSet
   323  	if _, err := irqSet.CopyIn(t, arg); err != nil {
   324  		return 0, err
   325  	}
   326  	// Callers must set the payload's size.
   327  	if irqSet.Argsz == 0 {
   328  		return 0, linuxerr.EINVAL
   329  	}
   330  	// Invalidate unknown flags.
   331  	if irqSet.Flags&^vfioIrqSetFlags != 0 {
   332  		return 0, linuxerr.EINVAL
   333  	}
   334  	// See drivers/vfio/vfio_main.c:vfio_set_irqs_validate_and_prepare,
   335  	// VFIO uses the data type at the request's flags to determine
   336  	// the memory layout of data field.
   337  	//
   338  	// The struct vfio_irq_set includes a flexible array member, it
   339  	// allocates an array for a continuous trunk of memory to back
   340  	// a vfio_irq_set object. In order to mirror that behavior, gVisor
   341  	// would allocate a slice to store the underlying bytes
   342  	// and pass that through to its host.
   343  	switch irqSet.Flags & linux.VFIO_IRQ_SET_DATA_TYPE_MASK {
   344  	// VFIO_IRQ_SET_DATA_NONE indicates there is no data field for
   345  	// the IOCTL command.
   346  	// It works with VFIO_IRQ_SET_ACTION_MASK, VFIO_IRQ_SET_ACTION_UNMASK,
   347  	// or VFIO_IRQ_SET_ACTION_TRIGGER to mask an interrupt, unmask an
   348  	// interrupt,  and trigger an interrupt unconditionally.
   349  	case linux.VFIO_IRQ_SET_DATA_NONE:
   350  		// When there is no data, passing through the given payload
   351  		// works just fine.
   352  		return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &irqSet)
   353  	// VFIO_IRQ_SET_DATA_BOOL indicates that the data field is an array of uint8.
   354  	// The action will be performed if the corresponding boolean is true.
   355  	case linux.VFIO_IRQ_SET_DATA_BOOL:
   356  		payloadSize := uint32(irqSet.Size()) + irqSet.Count
   357  		payload := make([]uint8, payloadSize)
   358  		if _, err := primitive.CopyUint8SliceIn(t, arg, payload); err != nil {
   359  			return 0, err
   360  		}
   361  		return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &payload[0])
   362  	// VFIO_IRQ_SET_DATA_EVENTFD indicates that the data field is an array
   363  	// of int32 (or event file descriptors). These descriptors will be
   364  	// signalled when an action in the flags happens.
   365  	case linux.VFIO_IRQ_SET_DATA_EVENTFD:
   366  		payloadSize := uint32(irqSet.Size())/4 + irqSet.Count
   367  		payload := make([]int32, payloadSize)
   368  		if _, err := primitive.CopyInt32SliceIn(t, arg, payload); err != nil {
   369  			return 0, err
   370  		}
   371  		// Transform the input FDs to host FDs.
   372  		for i := 0; i < int(irqSet.Count); i++ {
   373  			index := len(payload) - 1 - i
   374  			fd := payload[index]
   375  			// Skip non-event FD.
   376  			if fd == disableInterrupt {
   377  				continue
   378  			}
   379  			eventFileGeneric, _ := t.FDTable().Get(fd)
   380  			if eventFileGeneric == nil {
   381  				return 0, linuxerr.EBADF
   382  			}
   383  			defer eventFileGeneric.DecRef(ctx)
   384  			eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription)
   385  			if !ok {
   386  				return 0, linuxerr.EINVAL
   387  			}
   388  			eventfd, err := eventFile.HostFD()
   389  			if err != nil {
   390  				return 0, err
   391  			}
   392  			payload[index] = int32(eventfd)
   393  		}
   394  		return IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_DEVICE_SET_IRQS, &payload[0])
   395  	}
   396  	// No data type is specified or multiple data types are specified.
   397  	return 0, linuxerr.EINVAL
   398  }
   399  
   400  // PRead implements vfs.FileDescriptionImpl.PRead.
   401  func (fd *pciDeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   402  	if offset < 0 {
   403  		return 0, linuxerr.EINVAL
   404  	}
   405  	buf := make([]byte, dst.NumBytes())
   406  	_, err := unix.Pread(int(fd.hostFD), buf, offset)
   407  	if err != nil {
   408  		return 0, err
   409  	}
   410  	n, err := dst.CopyOut(ctx, buf)
   411  	return int64(n), err
   412  }
   413  
   414  // PWrite implements vfs.FileDescriptionImpl.PWrite.
   415  func (fd *pciDeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   416  	if offset < 0 {
   417  		return 0, linuxerr.EINVAL
   418  	}
   419  	buf := make([]byte, src.NumBytes())
   420  	_, err := src.CopyIn(ctx, buf)
   421  	if err != nil {
   422  		return 0, err
   423  	}
   424  	n, err := unix.Pwrite(int(fd.hostFD), buf, offset)
   425  	return int64(n), err
   426  }
   427  
   428  // DevAddrSet tracks device address ranges that have been mapped.
   429  type devAddrSetFuncs struct{}
   430  
   431  func (devAddrSetFuncs) MinKey() uint64 {
   432  	return 0
   433  }
   434  
   435  func (devAddrSetFuncs) MaxKey() uint64 {
   436  	return ^uint64(0)
   437  }
   438  
   439  func (devAddrSetFuncs) ClearValue(val *mm.PinnedRange) {
   440  	*val = mm.PinnedRange{}
   441  }
   442  
   443  func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 mm.PinnedRange, r2 DevAddrRange, v2 mm.PinnedRange) (mm.PinnedRange, bool) {
   444  	// Do we have the same backing file?
   445  	if v1.File != v2.File {
   446  		return mm.PinnedRange{}, false
   447  	}
   448  
   449  	// Do we have contiguous offsets in the backing file?
   450  	if v1.Offset+uint64(v1.Source.Length()) != v2.Offset {
   451  		return mm.PinnedRange{}, false
   452  	}
   453  
   454  	// Are the virtual addresses contiguous?
   455  	//
   456  	// This check isn't strictly needed because 'mm.PinnedRange.Source'
   457  	// is only used to track the size of the pinned region (this is
   458  	// because the virtual address range can be unmapped or remapped
   459  	// elsewhere). Regardless we require this for simplicity.
   460  	if v1.Source.End != v2.Source.Start {
   461  		return mm.PinnedRange{}, false
   462  	}
   463  
   464  	// Extend v1 to account for the adjacent PinnedRange.
   465  	v1.Source.End = v2.Source.End
   466  	return v1, true
   467  }
   468  
   469  func (devAddrSetFuncs) Split(r DevAddrRange, val mm.PinnedRange, split uint64) (mm.PinnedRange, mm.PinnedRange) {
   470  	n := split - r.Start
   471  
   472  	left := val
   473  	left.Source.End = left.Source.Start + hostarch.Addr(n)
   474  
   475  	right := val
   476  	right.Source.Start += hostarch.Addr(n)
   477  	right.Offset += n
   478  
   479  	return left, right
   480  }