github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/devices/accel/gasket.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package accel
    16  
    17  import (
    18  	"golang.org/x/sys/unix"
    19  	"github.com/metacubex/gvisor/pkg/abi/gasket"
    20  	"github.com/metacubex/gvisor/pkg/abi/linux"
    21  	"github.com/metacubex/gvisor/pkg/abi/tpu"
    22  	"github.com/metacubex/gvisor/pkg/cleanup"
    23  	"github.com/metacubex/gvisor/pkg/context"
    24  	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
    25  	"github.com/metacubex/gvisor/pkg/hostarch"
    26  	"github.com/metacubex/gvisor/pkg/sentry/fsimpl/eventfd"
    27  	"github.com/metacubex/gvisor/pkg/sentry/kernel"
    28  	"github.com/metacubex/gvisor/pkg/sentry/memmap"
    29  	"github.com/metacubex/gvisor/pkg/sentry/mm"
    30  )
    31  
    32  func gasketMapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
    33  	var userIoctlParams gasket.GasketPageTableIoctl
    34  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
    35  		return 0, err
    36  	}
    37  
    38  	numberOfPageTables := tpu.NumberOfTPUV4PageTables
    39  	if fd.device.lite {
    40  		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
    41  	}
    42  	if userIoctlParams.PageTableIndex >= numberOfPageTables {
    43  		return 0, linuxerr.EFAULT
    44  	}
    45  
    46  	tmm := t.MemoryManager()
    47  	ar, ok := tmm.CheckIORange(hostarch.Addr(userIoctlParams.HostAddress), int64(userIoctlParams.Size))
    48  	if !ok {
    49  		return 0, linuxerr.EFAULT
    50  	}
    51  
    52  	if !ar.IsPageAligned() || (userIoctlParams.Size/hostarch.PageSize) == 0 {
    53  		return 0, linuxerr.EINVAL
    54  	}
    55  
    56  	devAddr := userIoctlParams.DeviceAddress
    57  	// The kernel driver does not enforce page alignment on the device
    58  	// address although it will be implicitly rounded down to a page
    59  	// boundary. We do it explicitly because it simplifies tracking
    60  	// of allocated ranges in 'devAddrSet'.
    61  	devAddr &^= (hostarch.PageSize - 1)
    62  
    63  	// Make sure that the device address range can be mapped.
    64  	devar := DevAddrRange{
    65  		devAddr,
    66  		devAddr + userIoctlParams.Size,
    67  	}
    68  	if !devar.WellFormed() {
    69  		return 0, linuxerr.EINVAL
    70  	}
    71  
    72  	// Reserve a range in our address space.
    73  	m, _, errno := unix.RawSyscall6(unix.SYS_MMAP, 0 /* addr */, uintptr(ar.Length()), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANONYMOUS, ^uintptr(0) /* fd */, 0 /* offset */)
    74  	if errno != 0 {
    75  		return 0, errno
    76  	}
    77  	cu := cleanup.Make(func() {
    78  		unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
    79  	})
    80  	defer cu.Clean()
    81  	// Mirror application mappings into the reserved range.
    82  	prs, err := t.MemoryManager().Pin(ctx, ar, hostarch.ReadWrite, false /* ignorePermissions */)
    83  	cu.Add(func() {
    84  		mm.Unpin(prs)
    85  	})
    86  	if err != nil {
    87  		return 0, err
    88  	}
    89  	sentryAddr := uintptr(m)
    90  	for _, pr := range prs {
    91  		ims, err := pr.File.MapInternal(memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}, hostarch.ReadWrite)
    92  		if err != nil {
    93  			return 0, err
    94  		}
    95  		for !ims.IsEmpty() {
    96  			im := ims.Head()
    97  			if _, _, errno := unix.RawSyscall6(unix.SYS_MREMAP, im.Addr(), 0 /* old_size */, uintptr(im.Len()), linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED, sentryAddr, 0); errno != 0 {
    98  				return 0, errno
    99  			}
   100  			sentryAddr += uintptr(im.Len())
   101  			ims = ims.Tail()
   102  		}
   103  	}
   104  	sentryIoctlParams := userIoctlParams
   105  	sentryIoctlParams.HostAddress = uint64(m)
   106  	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_MAP_BUFFER, &sentryIoctlParams)
   107  	if err != nil {
   108  		return n, err
   109  	}
   110  	cu.Release()
   111  	// Unmap the reserved range, which is no longer required.
   112  	unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
   113  
   114  	fd.device.mu.Lock()
   115  	defer fd.device.mu.Unlock()
   116  	for _, pr := range prs {
   117  		rlen := uint64(pr.Source.Length())
   118  		fd.device.devAddrSet.InsertRange(DevAddrRange{
   119  			devAddr,
   120  			devAddr + rlen,
   121  		}, pinnedAccelMem{pinnedRange: pr, pageTableIndex: userIoctlParams.PageTableIndex})
   122  		devAddr += rlen
   123  	}
   124  	return n, nil
   125  }
   126  
   127  func gasketUnmapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
   128  	var userIoctlParams gasket.GasketPageTableIoctl
   129  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
   130  		return 0, err
   131  	}
   132  
   133  	numberOfPageTables := tpu.NumberOfTPUV4PageTables
   134  	if fd.device.lite {
   135  		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
   136  	}
   137  	if userIoctlParams.PageTableIndex >= numberOfPageTables {
   138  		return 0, linuxerr.EFAULT
   139  	}
   140  
   141  	devAddr := userIoctlParams.DeviceAddress
   142  	devAddr &^= (hostarch.PageSize - 1)
   143  	devar := DevAddrRange{
   144  		devAddr,
   145  		devAddr + userIoctlParams.Size,
   146  	}
   147  	if !devar.WellFormed() {
   148  		return 0, linuxerr.EINVAL
   149  	}
   150  
   151  	sentryIoctlParams := userIoctlParams
   152  	sentryIoctlParams.HostAddress = 0 // clobber this value, it's unused.
   153  	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_UNMAP_BUFFER, &sentryIoctlParams)
   154  	if err != nil {
   155  		return n, err
   156  	}
   157  	fd.device.mu.Lock()
   158  	defer fd.device.mu.Unlock()
   159  	s := &fd.device.devAddrSet
   160  	r := DevAddrRange{userIoctlParams.DeviceAddress, userIoctlParams.DeviceAddress + userIoctlParams.Size}
   161  	seg := s.LowerBoundSegment(r.Start)
   162  	for seg.Ok() && seg.Start() < r.End {
   163  		seg = s.Isolate(seg, r)
   164  		v := seg.Value()
   165  		mm.Unpin([]mm.PinnedRange{v.pinnedRange})
   166  		gap := s.Remove(seg)
   167  		seg = gap.NextSegment()
   168  	}
   169  	return n, nil
   170  }
   171  
   172  func gasketInterruptMappingIoctl(ctx context.Context, t *kernel.Task, hostFd int32, paramsAddr hostarch.Addr, lite bool) (uintptr, error) {
   173  	var userIoctlParams gasket.GasketInterruptMapping
   174  	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
   175  		return 0, err
   176  	}
   177  
   178  	sizeOfInterruptList := tpu.SizeOfTPUV4InterruptList
   179  	interruptMap := tpu.TPUV4InterruptsMap
   180  	if lite {
   181  		sizeOfInterruptList = tpu.SizeOfTPUV4liteInterruptList
   182  		interruptMap = tpu.TPUV4liteInterruptsMap
   183  	}
   184  	if userIoctlParams.Interrupt >= sizeOfInterruptList {
   185  		return 0, linuxerr.EINVAL
   186  	}
   187  	barRegMap, ok := interruptMap[userIoctlParams.BarIndex]
   188  	if !ok {
   189  		return 0, linuxerr.EINVAL
   190  	}
   191  	if _, ok := barRegMap[userIoctlParams.RegOffset]; !ok {
   192  		return 0, linuxerr.EINVAL
   193  	}
   194  
   195  	// Check that 'userEventFD.Eventfd' is an eventfd.
   196  	eventFileGeneric, _ := t.FDTable().Get(int32(userIoctlParams.EventFD))
   197  	if eventFileGeneric == nil {
   198  		return 0, linuxerr.EBADF
   199  	}
   200  	defer eventFileGeneric.DecRef(ctx)
   201  	eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription)
   202  	if !ok {
   203  		return 0, linuxerr.EINVAL
   204  	}
   205  
   206  	eventfd, err := eventFile.HostFD()
   207  	if err != nil {
   208  		return 0, err
   209  	}
   210  
   211  	sentryIoctlParams := userIoctlParams
   212  	sentryIoctlParams.EventFD = uint64(eventfd)
   213  	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_REGISTER_INTERRUPT, &sentryIoctlParams)
   214  	if err != nil {
   215  		return n, err
   216  	}
   217  
   218  	outIoctlParams := sentryIoctlParams
   219  	outIoctlParams.EventFD = userIoctlParams.EventFD
   220  	if _, err := outIoctlParams.CopyOut(t, paramsAddr); err != nil {
   221  		return n, err
   222  	}
   223  	return n, nil
   224  }