github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/sentry/devices/accel/accel.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package accel implements proxying for hardware accelerators.
    16  package accel
    17  
    18  import (
    19  	"fmt"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/MerlinKodo/gvisor/pkg/abi/gasket"
    23  	"github.com/MerlinKodo/gvisor/pkg/abi/linux"
    24  	"github.com/MerlinKodo/gvisor/pkg/context"
    25  	"github.com/MerlinKodo/gvisor/pkg/errors/linuxerr"
    26  	"github.com/MerlinKodo/gvisor/pkg/fdnotifier"
    27  	"github.com/MerlinKodo/gvisor/pkg/hostarch"
    28  	"github.com/MerlinKodo/gvisor/pkg/log"
    29  	"github.com/MerlinKodo/gvisor/pkg/sentry/arch"
    30  	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel"
    31  	"github.com/MerlinKodo/gvisor/pkg/sentry/mm"
    32  	"github.com/MerlinKodo/gvisor/pkg/sentry/vfs"
    33  	"github.com/MerlinKodo/gvisor/pkg/usermem"
    34  	"github.com/MerlinKodo/gvisor/pkg/waiter"
    35  )
    36  
// accelFD implements vfs.FileDescriptionImpl for /dev/accel[0-9]+.
//
// accelFD is not savable; we do not implement save/restore of accelerator
// state.
//
// An accelFD proxies to a host file descriptor: ioctls, readiness polling
// and close are all forwarded to hostFD.
type accelFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host file descriptor that operations are forwarded to.
	// Release removes it from fdnotifier and closes it.
	hostFD     int32
	// device is the accelerator device this FD was opened on. Release reads
	// and updates the device's openWriteFDs and devAddrSet while holding
	// device.mu; that state is shared with other FDs of the same device.
	device     *accelDevice
	// queue holds waiter entries registered via EventRegister; changes to it
	// are followed by fdnotifier.UpdateFD(hostFD).
	queue      waiter.Queue
	// memmapFile is defined elsewhere; presumably it backs memory mappings
	// of this FD — TODO confirm against accelFDMemmapFile.
	memmapFile accelFDMemmapFile
}
    52  
    53  // Release implements vfs.FileDescriptionImpl.Release.
    54  func (fd *accelFD) Release(context.Context) {
    55  	fd.device.mu.Lock()
    56  	defer fd.device.mu.Unlock()
    57  	fd.device.openWriteFDs--
    58  	if fd.device.openWriteFDs == 0 {
    59  		log.Infof("openWriteFDs is zero, unpinning all sentry memory mappings")
    60  		s := &fd.device.devAddrSet
    61  		seg := s.FirstSegment()
    62  		for seg.Ok() {
    63  			r, v := seg.Range(), seg.Value()
    64  			gpti := gasket.GasketPageTableIoctl{
    65  				PageTableIndex: v.pageTableIndex,
    66  				DeviceAddress:  r.Start,
    67  				Size:           r.End - r.Start,
    68  				HostAddress:    0,
    69  			}
    70  			_, err := ioctlInvokePtrArg(fd.hostFD, gasket.GASKET_IOCTL_UNMAP_BUFFER, &gpti)
    71  			if err != nil {
    72  				log.Warningf("could not unmap range [%#x, %#x) (index %d) on device: %v", r.Start, r.End, v.pageTableIndex, err)
    73  			}
    74  			mm.Unpin([]mm.PinnedRange{v.pinnedRange})
    75  			gap := s.Remove(seg)
    76  			seg = gap.NextSegment()
    77  		}
    78  	}
    79  	fdnotifier.RemoveFD(fd.hostFD)
    80  	unix.Close(int(fd.hostFD))
    81  }
    82  
    83  // EventRegister implements waiter.Waitable.EventRegister.
    84  func (fd *accelFD) EventRegister(e *waiter.Entry) error {
    85  	fd.queue.EventRegister(e)
    86  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    87  		fd.queue.EventUnregister(e)
    88  		return err
    89  	}
    90  	return nil
    91  }
    92  
    93  // EventUnregister implements waiter.Waitable.EventUnregister.
    94  func (fd *accelFD) EventUnregister(e *waiter.Entry) {
    95  	fd.queue.EventUnregister(e)
    96  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    97  		panic(fmt.Sprint("UpdateFD:", err))
    98  	}
    99  }
   100  
   101  // Readiness implements waiter.Waitable.Readiness.
   102  func (fd *accelFD) Readiness(mask waiter.EventMask) waiter.EventMask {
   103  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
   104  }
   105  
   106  // Epollable implements vfs.FileDescriptionImpl.Epollable.
   107  func (fd *accelFD) Epollable() bool {
   108  	return true
   109  }
   110  
   111  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
   112  func (fd *accelFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   113  	cmd := args[1].Uint()
   114  	argPtr := args[2].Pointer()
   115  	argSize := linux.IOC_SIZE(cmd)
   116  
   117  	t := kernel.TaskFromContext(ctx)
   118  	if t == nil {
   119  		panic("Ioctl should be called from a task context")
   120  	}
   121  
   122  	log.Infof("Accel ioctl %s called on fd %d with arg %v of size %d.", gasket.Ioctl(cmd), fd.hostFD, argPtr, argSize)
   123  	switch gasket.Ioctl(cmd) {
   124  	// Not yet implemented gasket ioctls.
   125  	case gasket.GASKET_IOCTL_SET_EVENTFD, gasket.GASKET_IOCTL_CLEAR_EVENTFD,
   126  		gasket.GASKET_IOCTL_NUMBER_PAGE_TABLES, gasket.GASKET_IOCTL_PAGE_TABLE_SIZE,
   127  		gasket.GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE, gasket.GASKET_IOCTL_PARTITION_PAGE_TABLE,
   128  		gasket.GASKET_IOCTL_MAP_DMA_BUF:
   129  		return 0, linuxerr.ENOSYS
   130  	case gasket.GASKET_IOCTL_RESET:
   131  		return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_RESET, args[2].Uint64())
   132  	case gasket.GASKET_IOCTL_MAP_BUFFER:
   133  		return gasketMapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   134  	case gasket.GASKET_IOCTL_UNMAP_BUFFER:
   135  		return gasketUnmapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr)
   136  	case gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS:
   137  		return ioctlInvoke(fd.hostFD, gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS, 0)
   138  	case gasket.GASKET_IOCTL_REGISTER_INTERRUPT:
   139  		return gasketInterruptMappingIoctl(ctx, t, fd.hostFD, argPtr)
   140  	case gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT:
   141  		return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT, args[2].Uint64())
   142  	default:
   143  		return 0, linuxerr.EINVAL
   144  	}
   145  }
   146  
// pinnedAccelMem is the value stored for each device-address segment in the
// device's devAddrSet. It records the sentry memory pinned for the mapping
// and the gasket page table the mapping was made in.
type pinnedAccelMem struct {
	// pinnedRange is the pinned memory backing the mapping; Release passes
	// it to mm.Unpin when the mapping is torn down.
	pinnedRange    mm.PinnedRange
	// pageTableIndex is sent as PageTableIndex in GASKET_IOCTL_UNMAP_BUFFER
	// when Release unmaps the segment.
	pageTableIndex uint64
}
   151  
// devAddrSetFuncs provides the callbacks (key bounds, value clearing,
// merging and splitting) for the segment set that tracks device address
// ranges that have been mapped. Keys are uint64 device addresses and values
// are pinnedAccelMem.
// NOTE(review): presumably the generated set type is DevAddrSet, matching
// the DevAddrRange type used by these methods — confirm against the
// template instantiation.
type devAddrSetFuncs struct{}
   154  
   155  func (devAddrSetFuncs) MinKey() uint64 {
   156  	return 0
   157  }
   158  
   159  func (devAddrSetFuncs) MaxKey() uint64 {
   160  	return ^uint64(0)
   161  }
   162  
   163  func (devAddrSetFuncs) ClearValue(val *pinnedAccelMem) {
   164  	*val = pinnedAccelMem{}
   165  }
   166  
   167  func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 pinnedAccelMem, r2 DevAddrRange, v2 pinnedAccelMem) (pinnedAccelMem, bool) {
   168  	// Do we have the same backing file?
   169  	if v1.pinnedRange.File != v2.pinnedRange.File {
   170  		return pinnedAccelMem{}, false
   171  	}
   172  
   173  	// Do we have contiguous offsets in the backing file?
   174  	if v1.pinnedRange.Offset+uint64(v1.pinnedRange.Source.Length()) != v2.pinnedRange.Offset {
   175  		return pinnedAccelMem{}, false
   176  	}
   177  
   178  	// Are the virtual addresses contiguous?
   179  	//
   180  	// This check isn't strictly needed because 'mm.PinnedRange.Source'
   181  	// is only used to track the size of the pinned region (this is
   182  	// because the virtual address range can be unmapped or remapped
   183  	// elsewhere). Regardless we require this for simplicity.
   184  	if v1.pinnedRange.Source.End != v2.pinnedRange.Source.Start {
   185  		return pinnedAccelMem{}, false
   186  	}
   187  
   188  	// Extend v1 to account for the adjacent PinnedRange.
   189  	v1.pinnedRange.Source.End = v2.pinnedRange.Source.End
   190  	return v1, true
   191  }
   192  
   193  func (devAddrSetFuncs) Split(r DevAddrRange, val pinnedAccelMem, split uint64) (pinnedAccelMem, pinnedAccelMem) {
   194  	n := split - r.Start
   195  
   196  	left := val
   197  	left.pinnedRange.Source.End = left.pinnedRange.Source.Start + hostarch.Addr(n)
   198  
   199  	right := val
   200  	right.pinnedRange.Source.Start += hostarch.Addr(n)
   201  	right.pinnedRange.Offset += n
   202  
   203  	return left, right
   204  }