gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/accel/gasket.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package accel
    16  
    17  import (
    18  	"golang.org/x/sys/unix"
    19  	"gvisor.dev/gvisor/pkg/abi/gasket"
    20  	"gvisor.dev/gvisor/pkg/abi/linux"
    21  	"gvisor.dev/gvisor/pkg/abi/tpu"
    22  	"gvisor.dev/gvisor/pkg/cleanup"
    23  	"gvisor.dev/gvisor/pkg/context"
    24  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    25  	"gvisor.dev/gvisor/pkg/hostarch"
    26  	"gvisor.dev/gvisor/pkg/sentry/devices/tpuproxy"
    27  	"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
    28  	"gvisor.dev/gvisor/pkg/sentry/kernel"
    29  	"gvisor.dev/gvisor/pkg/sentry/memmap"
    30  	"gvisor.dev/gvisor/pkg/sentry/mm"
    31  )
    32  
    33  func gasketMapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
    34  	var userIoctlParams gasket.GasketPageTableIoctl
    35  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
    36  		return 0, err
    37  	}
    38  
    39  	numberOfPageTables := tpu.NumberOfTPUV4PageTables
    40  	if fd.device.lite {
    41  		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
    42  	}
    43  	if userIoctlParams.PageTableIndex >= numberOfPageTables {
    44  		return 0, linuxerr.EFAULT
    45  	}
    46  
    47  	tmm := t.MemoryManager()
    48  	ar, ok := tmm.CheckIORange(hostarch.Addr(userIoctlParams.HostAddress), int64(userIoctlParams.Size))
    49  	if !ok {
    50  		return 0, linuxerr.EFAULT
    51  	}
    52  
    53  	if !ar.IsPageAligned() || (userIoctlParams.Size/hostarch.PageSize) == 0 {
    54  		return 0, linuxerr.EINVAL
    55  	}
    56  
    57  	devAddr := userIoctlParams.DeviceAddress
    58  	// The kernel driver does not enforce page alignment on the device
    59  	// address although it will be implicitly rounded down to a page
    60  	// boundary. We do it explicitly because it simplifies tracking
    61  	// of allocated ranges in 'devAddrSet'.
    62  	devAddr &^= (hostarch.PageSize - 1)
    63  
    64  	// Make sure that the device address range can be mapped.
    65  	devar := DevAddrRange{
    66  		devAddr,
    67  		devAddr + userIoctlParams.Size,
    68  	}
    69  	if !devar.WellFormed() {
    70  		return 0, linuxerr.EINVAL
    71  	}
    72  
    73  	// Reserve a range in our address space.
    74  	m, _, errno := unix.RawSyscall6(unix.SYS_MMAP, 0 /* addr */, uintptr(ar.Length()), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANONYMOUS, ^uintptr(0) /* fd */, 0 /* offset */)
    75  	if errno != 0 {
    76  		return 0, errno
    77  	}
    78  	cu := cleanup.Make(func() {
    79  		unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
    80  	})
    81  	defer cu.Clean()
    82  	// Mirror application mappings into the reserved range.
    83  	prs, err := t.MemoryManager().Pin(ctx, ar, hostarch.ReadWrite, false /* ignorePermissions */)
    84  	cu.Add(func() {
    85  		mm.Unpin(prs)
    86  	})
    87  	if err != nil {
    88  		return 0, err
    89  	}
    90  	sentryAddr := uintptr(m)
    91  	for _, pr := range prs {
    92  		ims, err := pr.File.MapInternal(memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}, hostarch.ReadWrite)
    93  		if err != nil {
    94  			return 0, err
    95  		}
    96  		for !ims.IsEmpty() {
    97  			im := ims.Head()
    98  			if _, _, errno := unix.RawSyscall6(unix.SYS_MREMAP, im.Addr(), 0 /* old_size */, uintptr(im.Len()), linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED, sentryAddr, 0); errno != 0 {
    99  				return 0, errno
   100  			}
   101  			sentryAddr += uintptr(im.Len())
   102  			ims = ims.Tail()
   103  		}
   104  	}
   105  	sentryIoctlParams := userIoctlParams
   106  	sentryIoctlParams.HostAddress = uint64(m)
   107  	n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_MAP_BUFFER, &sentryIoctlParams)
   108  	if err != nil {
   109  		return n, err
   110  	}
   111  	cu.Release()
   112  	// Unmap the reserved range, which is no longer required.
   113  	unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
   114  
   115  	fd.device.mu.Lock()
   116  	defer fd.device.mu.Unlock()
   117  	for _, pr := range prs {
   118  		rlen := uint64(pr.Source.Length())
   119  		fd.device.devAddrSet.InsertRange(DevAddrRange{
   120  			devAddr,
   121  			devAddr + rlen,
   122  		}, pinnedAccelMem{pinnedRange: pr, pageTableIndex: userIoctlParams.PageTableIndex})
   123  		devAddr += rlen
   124  	}
   125  	return n, nil
   126  }
   127  
   128  func gasketUnmapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
   129  	var userIoctlParams gasket.GasketPageTableIoctl
   130  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
   131  		return 0, err
   132  	}
   133  
   134  	numberOfPageTables := tpu.NumberOfTPUV4PageTables
   135  	if fd.device.lite {
   136  		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
   137  	}
   138  	if userIoctlParams.PageTableIndex >= numberOfPageTables {
   139  		return 0, linuxerr.EFAULT
   140  	}
   141  
   142  	devAddr := userIoctlParams.DeviceAddress
   143  	devAddr &^= (hostarch.PageSize - 1)
   144  	devar := DevAddrRange{
   145  		devAddr,
   146  		devAddr + userIoctlParams.Size,
   147  	}
   148  	if !devar.WellFormed() {
   149  		return 0, linuxerr.EINVAL
   150  	}
   151  
   152  	sentryIoctlParams := userIoctlParams
   153  	sentryIoctlParams.HostAddress = 0 // clobber this value, it's unused.
   154  	n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_UNMAP_BUFFER, &sentryIoctlParams)
   155  	if err != nil {
   156  		return n, err
   157  	}
   158  	fd.device.mu.Lock()
   159  	defer fd.device.mu.Unlock()
   160  	s := &fd.device.devAddrSet
   161  	r := DevAddrRange{userIoctlParams.DeviceAddress, userIoctlParams.DeviceAddress + userIoctlParams.Size}
   162  	seg := s.LowerBoundSegment(r.Start)
   163  	for seg.Ok() && seg.Start() < r.End {
   164  		seg = s.Isolate(seg, r)
   165  		v := seg.Value()
   166  		mm.Unpin([]mm.PinnedRange{v.pinnedRange})
   167  		gap := s.Remove(seg)
   168  		seg = gap.NextSegment()
   169  	}
   170  	return n, nil
   171  }
   172  
   173  func gasketInterruptMappingIoctl(ctx context.Context, t *kernel.Task, hostFd int32, paramsAddr hostarch.Addr, lite bool) (uintptr, error) {
   174  	var userIoctlParams gasket.GasketInterruptMapping
   175  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
   176  		return 0, err
   177  	}
   178  
   179  	sizeOfInterruptList := tpu.SizeOfTPUV4InterruptList
   180  	interruptMap := tpu.TPUV4InterruptsMap
   181  	if lite {
   182  		sizeOfInterruptList = tpu.SizeOfTPUV4liteInterruptList
   183  		interruptMap = tpu.TPUV4liteInterruptsMap
   184  	}
   185  	if userIoctlParams.Interrupt >= sizeOfInterruptList {
   186  		return 0, linuxerr.EINVAL
   187  	}
   188  	barRegMap, ok := interruptMap[userIoctlParams.BarIndex]
   189  	if !ok {
   190  		return 0, linuxerr.EINVAL
   191  	}
   192  	if _, ok := barRegMap[userIoctlParams.RegOffset]; !ok {
   193  		return 0, linuxerr.EINVAL
   194  	}
   195  
   196  	// Check that 'userEventFD.Eventfd' is an eventfd.
   197  	eventFileGeneric, _ := t.FDTable().Get(int32(userIoctlParams.EventFD))
   198  	if eventFileGeneric == nil {
   199  		return 0, linuxerr.EBADF
   200  	}
   201  	defer eventFileGeneric.DecRef(ctx)
   202  	eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription)
   203  	if !ok {
   204  		return 0, linuxerr.EINVAL
   205  	}
   206  
   207  	eventfd, err := eventFile.HostFD()
   208  	if err != nil {
   209  		return 0, err
   210  	}
   211  
   212  	sentryIoctlParams := userIoctlParams
   213  	sentryIoctlParams.EventFD = uint64(eventfd)
   214  	n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_REGISTER_INTERRUPT, &sentryIoctlParams)
   215  	if err != nil {
   216  		return n, err
   217  	}
   218  
   219  	outIoctlParams := sentryIoctlParams
   220  	outIoctlParams.EventFD = userIoctlParams.EventFD
   221  	if _, err := outIoctlParams.CopyOut(t, paramsAddr); err != nil {
   222  		return n, err
   223  	}
   224  	return n, nil
   225  }