github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/devices/accel/tpu_v4.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package accel implements proxying for hardware accelerators.
    16  package accel
    17  
import (
	"fmt"

	"github.com/metacubex/gvisor/pkg/abi/gasket"
	"github.com/metacubex/gvisor/pkg/abi/linux"
	"github.com/metacubex/gvisor/pkg/context"
	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
	"github.com/metacubex/gvisor/pkg/fdnotifier"
	"github.com/metacubex/gvisor/pkg/hostarch"
	"github.com/metacubex/gvisor/pkg/log"
	"github.com/metacubex/gvisor/pkg/sentry/arch"
	"github.com/metacubex/gvisor/pkg/sentry/kernel"
	"github.com/metacubex/gvisor/pkg/sentry/mm"
	"github.com/metacubex/gvisor/pkg/sentry/vfs"
	"github.com/metacubex/gvisor/pkg/usermem"
	"github.com/metacubex/gvisor/pkg/waiter"
	"golang.org/x/sys/unix"
)
    36  
// tpuV4FD implements vfs.FileDescriptionImpl for /dev/accel[0-9]+.
//
// tpuV4FD is not savable; we do not implement save/restore of accelerator
// state.
type tpuV4FD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	hostFD     int32        // host file descriptor for the underlying device; closed in Release.
	device     *tpuV4Device // backing device; its mu guards owner/openWriteFDs/devAddrSet.
	queue      waiter.Queue
	memmapFile accelFDMemmapFile
}
    52  
// Release implements vfs.FileDescriptionImpl.Release.
//
// When the last writable FD for the device is released, every device address
// range still tracked in devAddrSet is unmapped on the device, its sentry
// memory pins are dropped, and device ownership is cleared. The host FD is
// always deregistered from fdnotifier and closed.
func (fd *tpuV4FD) Release(context.Context) {
	fd.device.mu.Lock()
	defer fd.device.mu.Unlock()
	fd.device.openWriteFDs--
	if fd.device.openWriteFDs == 0 {
		log.Infof("openWriteFDs is zero, unpinning all sentry memory mappings")
		s := &fd.device.devAddrSet
		seg := s.FirstSegment()
		for seg.Ok() {
			r, v := seg.Range(), seg.Value()
			// Ask the device to tear down this mapping. HostAddress is
			// ignored for unmap, so it is left as 0.
			gpti := gasket.GasketPageTableIoctl{
				PageTableIndex: v.pageTableIndex,
				DeviceAddress:  r.Start,
				Size:           r.End - r.Start,
				HostAddress:    0,
			}
			_, err := ioctlInvokePtrArg(fd.hostFD, gasket.GASKET_IOCTL_UNMAP_BUFFER, &gpti)
			if err != nil {
				// Best-effort: keep going so remaining ranges are still
				// unpinned and removed even if the device unmap fails.
				log.Warningf("could not unmap range [%#x, %#x) (index %d) on device: %v", r.Start, r.End, v.pageTableIndex, err)
			}
			// Unpin regardless of the ioctl result; the sentry must not
			// leak pinned memory.
			mm.Unpin([]mm.PinnedRange{v.pinnedRange})
			// Remove returns the gap left behind; advance from there.
			gap := s.Remove(seg)
			seg = gap.NextSegment()
		}
		fd.device.owner = nil
	}
	fdnotifier.RemoveFD(fd.hostFD)
	unix.Close(int(fd.hostFD))
}
    83  
    84  // EventRegister implements waiter.Waitable.EventRegister.
    85  func (fd *tpuV4FD) EventRegister(e *waiter.Entry) error {
    86  	fd.queue.EventRegister(e)
    87  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    88  		fd.queue.EventUnregister(e)
    89  		return err
    90  	}
    91  	return nil
    92  }
    93  
    94  // EventUnregister implements waiter.Waitable.EventUnregister.
    95  func (fd *tpuV4FD) EventUnregister(e *waiter.Entry) {
    96  	fd.queue.EventUnregister(e)
    97  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    98  		panic(fmt.Sprint("UpdateFD:", err))
    99  	}
   100  }
   101  
   102  // Readiness implements waiter.Waitable.Readiness.
   103  func (fd *tpuV4FD) Readiness(mask waiter.EventMask) waiter.EventMask {
   104  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
   105  }
   106  
   107  // Epollable implements vfs.FileDescriptionImpl.Epollable.
   108  func (fd *tpuV4FD) Epollable() bool {
   109  	return true
   110  }
   111  
   112  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
   113  func (fd *tpuV4FD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   114  	cmd := args[1].Uint()
   115  	argPtr := args[2].Pointer()
   116  	argSize := linux.IOC_SIZE(cmd)
   117  
   118  	t := kernel.TaskFromContext(ctx)
   119  	if t == nil {
   120  		panic("Ioctl should be called from a task context")
   121  	}
   122  	if err := fd.checkPermission(t); err != nil {
   123  		return 0, err
   124  	}
   125  
   126  	log.Infof("Accel ioctl %s called on fd %d with arg %v of size %d.", gasket.Ioctl(cmd), fd.hostFD, argPtr, argSize)
   127  	switch gasket.Ioctl(cmd) {
   128  	// Not yet implemented gasket ioctls.
   129  	case gasket.GASKET_IOCTL_SET_EVENTFD, gasket.GASKET_IOCTL_CLEAR_EVENTFD,
   130  		gasket.GASKET_IOCTL_NUMBER_PAGE_TABLES, gasket.GASKET_IOCTL_PAGE_TABLE_SIZE,
   131  		gasket.GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE, gasket.GASKET_IOCTL_PARTITION_PAGE_TABLE,
   132  		gasket.GASKET_IOCTL_MAP_DMA_BUF:
   133  		return 0, linuxerr.ENOSYS
   134  	case gasket.GASKET_IOCTL_RESET:
   135  		return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_RESET, args[2].Uint64())
   136  	case gasket.GASKET_IOCTL_MAP_BUFFER:
   137  		return gasketMapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   138  	case gasket.GASKET_IOCTL_UNMAP_BUFFER:
   139  		return gasketUnmapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   140  	case gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS:
   141  		return ioctlInvoke(fd.hostFD, gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS, 0)
   142  	case gasket.GASKET_IOCTL_REGISTER_INTERRUPT:
   143  		return gasketInterruptMappingIoctl(ctx, t, fd.hostFD, argPtr, fd.device.lite)
   144  	case gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT:
   145  		return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT, args[2].Uint64())
   146  	default:
   147  		return 0, linuxerr.EINVAL
   148  	}
   149  }
   150  
   151  // checkPermission checks that the thread that owns this device is the only
   152  // one that can issue commands to the TPU. Other threads with access to
   153  // /dev/accel will not be able to issue commands to the device.
   154  func (fd *tpuV4FD) checkPermission(t *kernel.Task) error {
   155  	fd.device.mu.Lock()
   156  	defer fd.device.mu.Unlock()
   157  	owner := fd.device.owner
   158  	if t.ThreadGroup() != owner {
   159  		return linuxerr.EPERM
   160  	}
   161  	return nil
   162  }
   163  
// pinnedAccelMem is the value stored for each mapped device address range:
// the sentry memory pinned for it and the device page table it was mapped
// into.
type pinnedAccelMem struct {
	pinnedRange    mm.PinnedRange // pinned sentry memory backing the range
	pageTableIndex uint64         // device page table the range is mapped in
}
   168  
// devAddrSetFuncs provides the segment-set callbacks for DevAddrSet, which
// tracks device address ranges that have been mapped.
type devAddrSetFuncs struct{}
   171  
   172  func (devAddrSetFuncs) MinKey() uint64 {
   173  	return 0
   174  }
   175  
   176  func (devAddrSetFuncs) MaxKey() uint64 {
   177  	return ^uint64(0)
   178  }
   179  
   180  func (devAddrSetFuncs) ClearValue(val *pinnedAccelMem) {
   181  	*val = pinnedAccelMem{}
   182  }
   183  
   184  func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 pinnedAccelMem, r2 DevAddrRange, v2 pinnedAccelMem) (pinnedAccelMem, bool) {
   185  	// Do we have the same backing file?
   186  	if v1.pinnedRange.File != v2.pinnedRange.File {
   187  		return pinnedAccelMem{}, false
   188  	}
   189  
   190  	// Do we have contiguous offsets in the backing file?
   191  	if v1.pinnedRange.Offset+uint64(v1.pinnedRange.Source.Length()) != v2.pinnedRange.Offset {
   192  		return pinnedAccelMem{}, false
   193  	}
   194  
   195  	// Are the virtual addresses contiguous?
   196  	//
   197  	// This check isn't strictly needed because 'mm.PinnedRange.Source'
   198  	// is only used to track the size of the pinned region (this is
   199  	// because the virtual address range can be unmapped or remapped
   200  	// elsewhere). Regardless we require this for simplicity.
   201  	if v1.pinnedRange.Source.End != v2.pinnedRange.Source.Start {
   202  		return pinnedAccelMem{}, false
   203  	}
   204  
   205  	// Extend v1 to account for the adjacent PinnedRange.
   206  	v1.pinnedRange.Source.End = v2.pinnedRange.Source.End
   207  	return v1, true
   208  }
   209  
   210  func (devAddrSetFuncs) Split(r DevAddrRange, val pinnedAccelMem, split uint64) (pinnedAccelMem, pinnedAccelMem) {
   211  	n := split - r.Start
   212  
   213  	left := val
   214  	left.pinnedRange.Source.End = left.pinnedRange.Source.Start + hostarch.Addr(n)
   215  
   216  	right := val
   217  	right.pinnedRange.Source.Start += hostarch.Addr(n)
   218  	right.pinnedRange.Offset += n
   219  
   220  	return left, right
   221  }