gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/accel/tpu_v4.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package accel implements proxying for hardware accelerators.
    16  package accel
    17  
    18  import (
    19  	"fmt"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"gvisor.dev/gvisor/pkg/abi/gasket"
    23  	"gvisor.dev/gvisor/pkg/abi/linux"
    24  	"gvisor.dev/gvisor/pkg/context"
    25  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    26  	"gvisor.dev/gvisor/pkg/fdnotifier"
    27  	"gvisor.dev/gvisor/pkg/hostarch"
    28  	"gvisor.dev/gvisor/pkg/log"
    29  	"gvisor.dev/gvisor/pkg/sentry/arch"
    30  	"gvisor.dev/gvisor/pkg/sentry/devices/tpuproxy"
    31  	"gvisor.dev/gvisor/pkg/sentry/kernel"
    32  	"gvisor.dev/gvisor/pkg/sentry/mm"
    33  	"gvisor.dev/gvisor/pkg/sentry/vfs"
    34  	"gvisor.dev/gvisor/pkg/usermem"
    35  	"gvisor.dev/gvisor/pkg/waiter"
    36  )
    37  
    38  // tpuV4FD implements vfs.FileDescriptionImpl for /dev/accel[0-9]+.
    39  //
    40  // accelFD is not savable; we do not implement save/restore of accelerator
    41  // state.
    42  type tpuV4FD struct {
    43  	vfsfd vfs.FileDescription
    44  	vfs.FileDescriptionDefaultImpl
    45  	vfs.DentryMetadataFileDescriptionImpl
    46  	vfs.NoLockFD
    47  
    48  	hostFD     int32
    49  	device     *tpuV4Device
    50  	queue      waiter.Queue
    51  	memmapFile accelFDMemmapFile
    52  }
    53  
    54  // Release implements vfs.FileDescriptionImpl.Release.
    55  func (fd *tpuV4FD) Release(context.Context) {
    56  	fd.device.mu.Lock()
    57  	defer fd.device.mu.Unlock()
    58  	fd.device.openWriteFDs--
    59  	if fd.device.openWriteFDs == 0 {
    60  		log.Infof("openWriteFDs is zero, unpinning all sentry memory mappings")
    61  		s := &fd.device.devAddrSet
    62  		seg := s.FirstSegment()
    63  		for seg.Ok() {
    64  			r, v := seg.Range(), seg.Value()
    65  			gpti := gasket.GasketPageTableIoctl{
    66  				PageTableIndex: v.pageTableIndex,
    67  				DeviceAddress:  r.Start,
    68  				Size:           r.End - r.Start,
    69  				HostAddress:    0,
    70  			}
    71  			_, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](fd.hostFD, gasket.GASKET_IOCTL_UNMAP_BUFFER, &gpti)
    72  			if err != nil {
    73  				log.Warningf("could not unmap range [%#x, %#x) (index %d) on device: %v", r.Start, r.End, v.pageTableIndex, err)
    74  			}
    75  			mm.Unpin([]mm.PinnedRange{v.pinnedRange})
    76  			gap := s.Remove(seg)
    77  			seg = gap.NextSegment()
    78  		}
    79  		fd.device.owner = nil
    80  	}
    81  	fdnotifier.RemoveFD(fd.hostFD)
    82  	unix.Close(int(fd.hostFD))
    83  }
    84  
    85  // EventRegister implements waiter.Waitable.EventRegister.
    86  func (fd *tpuV4FD) EventRegister(e *waiter.Entry) error {
    87  	fd.queue.EventRegister(e)
    88  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    89  		fd.queue.EventUnregister(e)
    90  		return err
    91  	}
    92  	return nil
    93  }
    94  
    95  // EventUnregister implements waiter.Waitable.EventUnregister.
    96  func (fd *tpuV4FD) EventUnregister(e *waiter.Entry) {
    97  	fd.queue.EventUnregister(e)
    98  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    99  		panic(fmt.Sprint("UpdateFD:", err))
   100  	}
   101  }
   102  
   103  // Readiness implements waiter.Waitable.Readiness.
   104  func (fd *tpuV4FD) Readiness(mask waiter.EventMask) waiter.EventMask {
   105  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
   106  }
   107  
   108  // Epollable implements vfs.FileDescriptionImpl.Epollable.
   109  func (fd *tpuV4FD) Epollable() bool {
   110  	return true
   111  }
   112  
   113  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
   114  func (fd *tpuV4FD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   115  	cmd := args[1].Uint()
   116  	argPtr := args[2].Pointer()
   117  	argSize := linux.IOC_SIZE(cmd)
   118  
   119  	t := kernel.TaskFromContext(ctx)
   120  	if t == nil {
   121  		panic("Ioctl should be called from a task context")
   122  	}
   123  	if err := fd.checkPermission(t); err != nil {
   124  		return 0, err
   125  	}
   126  
   127  	log.Infof("Accel ioctl %s called on fd %d with arg %v of size %d.", gasket.Ioctl(cmd), fd.hostFD, argPtr, argSize)
   128  	switch gasket.Ioctl(cmd) {
   129  	// Not yet implemented gasket ioctls.
   130  	case gasket.GASKET_IOCTL_SET_EVENTFD, gasket.GASKET_IOCTL_CLEAR_EVENTFD,
   131  		gasket.GASKET_IOCTL_NUMBER_PAGE_TABLES, gasket.GASKET_IOCTL_PAGE_TABLE_SIZE,
   132  		gasket.GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE, gasket.GASKET_IOCTL_PARTITION_PAGE_TABLE,
   133  		gasket.GASKET_IOCTL_MAP_DMA_BUF:
   134  		return 0, linuxerr.ENOSYS
   135  	case gasket.GASKET_IOCTL_RESET:
   136  		return tpuproxy.IOCTLInvoke[gasket.Ioctl, uint64](fd.hostFD, gasket.GASKET_IOCTL_RESET, args[2].Uint64())
   137  	case gasket.GASKET_IOCTL_MAP_BUFFER:
   138  		return gasketMapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   139  	case gasket.GASKET_IOCTL_UNMAP_BUFFER:
   140  		return gasketUnmapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   141  	case gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS:
   142  		return tpuproxy.IOCTLInvoke[gasket.Ioctl](fd.hostFD, gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS, 0)
   143  	case gasket.GASKET_IOCTL_REGISTER_INTERRUPT:
   144  		return gasketInterruptMappingIoctl(ctx, t, fd.hostFD, argPtr, fd.device.lite)
   145  	case gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT:
   146  		return tpuproxy.IOCTLInvoke[gasket.Ioctl, uint64](fd.hostFD, gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT, args[2].Uint64())
   147  	default:
   148  		return 0, linuxerr.EINVAL
   149  	}
   150  }
   151  
   152  // checkPermission checks that the thread that owns this device is the only
   153  // one that can issue commands to the TPU. Other threads with access to
   154  // /dev/accel will not be able to issue commands to the device.
   155  func (fd *tpuV4FD) checkPermission(t *kernel.Task) error {
   156  	fd.device.mu.Lock()
   157  	defer fd.device.mu.Unlock()
   158  	owner := fd.device.owner
   159  	if t.ThreadGroup() != owner {
   160  		return linuxerr.EPERM
   161  	}
   162  	return nil
   163  }
   164  
   165  type pinnedAccelMem struct {
   166  	pinnedRange    mm.PinnedRange
   167  	pageTableIndex uint64
   168  }
   169  
   170  // DevAddrSet tracks device address ranges that have been mapped.
   171  type devAddrSetFuncs struct{}
   172  
   173  func (devAddrSetFuncs) MinKey() uint64 {
   174  	return 0
   175  }
   176  
   177  func (devAddrSetFuncs) MaxKey() uint64 {
   178  	return ^uint64(0)
   179  }
   180  
   181  func (devAddrSetFuncs) ClearValue(val *pinnedAccelMem) {
   182  	*val = pinnedAccelMem{}
   183  }
   184  
   185  func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 pinnedAccelMem, r2 DevAddrRange, v2 pinnedAccelMem) (pinnedAccelMem, bool) {
   186  	// Do we have the same backing file?
   187  	if v1.pinnedRange.File != v2.pinnedRange.File {
   188  		return pinnedAccelMem{}, false
   189  	}
   190  
   191  	// Do we have contiguous offsets in the backing file?
   192  	if v1.pinnedRange.Offset+uint64(v1.pinnedRange.Source.Length()) != v2.pinnedRange.Offset {
   193  		return pinnedAccelMem{}, false
   194  	}
   195  
   196  	// Are the virtual addresses contiguous?
   197  	//
   198  	// This check isn't strictly needed because 'mm.PinnedRange.Source'
   199  	// is only used to track the size of the pinned region (this is
   200  	// because the virtual address range can be unmapped or remapped
   201  	// elsewhere). Regardless we require this for simplicity.
   202  	if v1.pinnedRange.Source.End != v2.pinnedRange.Source.Start {
   203  		return pinnedAccelMem{}, false
   204  	}
   205  
   206  	// Extend v1 to account for the adjacent PinnedRange.
   207  	v1.pinnedRange.Source.End = v2.pinnedRange.Source.End
   208  	return v1, true
   209  }
   210  
   211  func (devAddrSetFuncs) Split(r DevAddrRange, val pinnedAccelMem, split uint64) (pinnedAccelMem, pinnedAccelMem) {
   212  	n := split - r.Start
   213  
   214  	left := val
   215  	left.pinnedRange.Source.End = left.pinnedRange.Source.Start + hostarch.Addr(n)
   216  
   217  	right := val
   218  	right.pinnedRange.Source.Start += hostarch.Addr(n)
   219  	right.pinnedRange.Offset += n
   220  
   221  	return left, right
   222  }