gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/tpuproxy/vfio.go (about)

     1  // Copyright 2024 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tpuproxy
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"golang.org/x/sys/unix"
    21  	"gvisor.dev/gvisor/pkg/abi/linux"
    22  	"gvisor.dev/gvisor/pkg/cleanup"
    23  	"gvisor.dev/gvisor/pkg/context"
    24  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    25  	"gvisor.dev/gvisor/pkg/fdnotifier"
    26  	"gvisor.dev/gvisor/pkg/hostarch"
    27  	"gvisor.dev/gvisor/pkg/log"
    28  	"gvisor.dev/gvisor/pkg/sentry/arch"
    29  	"gvisor.dev/gvisor/pkg/sentry/kernel"
    30  	"gvisor.dev/gvisor/pkg/sentry/memmap"
    31  	"gvisor.dev/gvisor/pkg/sentry/mm"
    32  	"gvisor.dev/gvisor/pkg/sentry/vfs"
    33  	"gvisor.dev/gvisor/pkg/usermem"
    34  	"gvisor.dev/gvisor/pkg/waiter"
    35  )
    36  
// vfioFD implements vfs.FileDescriptionImpl for /dev/vfio/vfio.
//
// Readiness polling and the supported VFIO ioctls are forwarded to the host
// file descriptor hostFD.
type vfioFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host file descriptor that backs this FD. It is
	// registered with fdnotifier and closed by Release.
	hostFD     int32
	// device provides the mutex (mu) and the device address set
	// (devAddrSet) in which DMA mappings are recorded.
	device     *vfioDevice
	// queue holds the waiter entries added through EventRegister.
	queue      waiter.Queue
	// memmapFile is not used by the methods visible here.
	// NOTE(review): presumably backs mmap support elsewhere; confirm.
	memmapFile vfioFDMemmapFile
}
    49  
    50  // Release implements vfs.FileDescriptionImpl.Release.
    51  func (fd *vfioFD) Release(context.Context) {
    52  	fdnotifier.RemoveFD(fd.hostFD)
    53  	fd.queue.Notify(waiter.EventHUp)
    54  	unix.Close(int(fd.hostFD))
    55  }
    56  
    57  // EventRegister implements waiter.Waitable.EventRegister.
    58  func (fd *vfioFD) EventRegister(e *waiter.Entry) error {
    59  	fd.queue.EventRegister(e)
    60  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    61  		fd.queue.EventUnregister(e)
    62  		return err
    63  	}
    64  	return nil
    65  }
    66  
    67  // EventUnregister implements waiter.Waitable.EventUnregister.
    68  func (fd *vfioFD) EventUnregister(e *waiter.Entry) {
    69  	fd.queue.EventUnregister(e)
    70  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    71  		panic(fmt.Sprint("UpdateFD:", err))
    72  	}
    73  }
    74  
    75  // Readiness implements waiter.Waitable.Readiness.
    76  func (fd *vfioFD) Readiness(mask waiter.EventMask) waiter.EventMask {
    77  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
    78  }
    79  
// Epollable implements vfs.FileDescriptionImpl.Epollable.
//
// It always reports true; readiness for this FD is served by Readiness,
// EventRegister and EventUnregister.
func (fd *vfioFD) Epollable() bool {
	return true
}
    84  
    85  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
    86  func (fd *vfioFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
    87  	cmd := args[1].Uint()
    88  	t := kernel.TaskFromContext(ctx)
    89  	if t == nil {
    90  		panic("Ioctl should be called from a task context")
    91  	}
    92  	switch cmd {
    93  	case linux.VFIO_CHECK_EXTENSION:
    94  		return fd.checkExtension(extension(args[2].Int()))
    95  	case linux.VFIO_SET_IOMMU:
    96  		return fd.setIOMMU(extension(args[2].Int()))
    97  	case linux.VFIO_IOMMU_MAP_DMA:
    98  		return fd.iommuMapDma(ctx, t, args[2].Pointer())
    99  	case linux.VFIO_IOMMU_UNMAP_DMA:
   100  		return fd.iommuUnmapDma(ctx, t, args[2].Pointer())
   101  	}
   102  	return 0, linuxerr.ENOSYS
   103  }
   104  
   105  // checkExtension returns a positive integer when the given VFIO extension
   106  // is supported, otherwise, it returns 0.
   107  func (fd *vfioFD) checkExtension(ext extension) (uintptr, error) {
   108  	switch ext {
   109  	case linux.VFIO_TYPE1_IOMMU, linux.VFIO_SPAPR_TCE_IOMMU, linux.VFIO_TYPE1v2_IOMMU:
   110  		ret, err := IOCTLInvoke[uint32, int32](fd.hostFD, linux.VFIO_CHECK_EXTENSION, int32(ext))
   111  		if err != nil {
   112  			log.Warningf("check VFIO extension %s: %v", ext, err)
   113  			return 0, err
   114  		}
   115  		return ret, nil
   116  	}
   117  	return 0, linuxerr.EINVAL
   118  }
   119  
   120  // Set the iommu to the given type.  The type must be supported by an iommu
   121  // driver as verified by calling VFIO_CHECK_EXTENSION using the same type.
   122  func (fd *vfioFD) setIOMMU(ext extension) (uintptr, error) {
   123  	switch ext {
   124  	case linux.VFIO_TYPE1_IOMMU, linux.VFIO_SPAPR_TCE_IOMMU, linux.VFIO_TYPE1v2_IOMMU:
   125  		ret, err := IOCTLInvoke[uint32, int32](fd.hostFD, linux.VFIO_SET_IOMMU, int32(ext))
   126  		if err != nil {
   127  			log.Warningf("set the IOMMU group to %s: %v", ext, err)
   128  			return 0, err
   129  		}
   130  		return ret, nil
   131  	}
   132  	return 0, linuxerr.EINVAL
   133  }
   134  
// iommuMapDma handles VFIO_IOMMU_MAP_DMA.
//
// It copies a VFIOIommuType1DmaMap in from arg, pins the application range
// [Vaddr, Vaddr+Size), mirrors the pinned memory into one contiguous range
// of the sentry's own address space, and issues the ioctl on the host FD
// with Vaddr rewritten to that range. On success the pinned ranges are
// recorded in fd.device.devAddrSet, keyed by device address, and the host
// ioctl's return value is returned.
//
// It fails with EFAULT if the application range is not a valid I/O range,
// and with EINVAL if that range is not page-aligned, is shorter than one
// page, or if the device address range is not well-formed. Errors from
// mmap, Pin, MapInternal, mremap and the host ioctl are returned as-is.
func (fd *vfioFD) iommuMapDma(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
	var dmaMap linux.VFIOIommuType1DmaMap
	if _, err := dmaMap.CopyIn(t, arg); err != nil {
		return 0, err
	}
	tmm := t.MemoryManager()
	ar, ok := tmm.CheckIORange(hostarch.Addr(dmaMap.Vaddr), int64(dmaMap.Size))
	if !ok {
		return 0, linuxerr.EFAULT
	}
	if !ar.IsPageAligned() || (dmaMap.Size/hostarch.PageSize) == 0 {
		return 0, linuxerr.EINVAL
	}
	// See comments at pkg/sentry/devices/accel/gasket.go, line 57-60.
	// The device address used for bookkeeping is the IOVA rounded down to a
	// page boundary; dmaMap.IOVa itself is passed to the host unchanged.
	devAddr := dmaMap.IOVa
	devAddr &^= (hostarch.PageSize - 1)

	// devar is only used to validate the device address range.
	devar := DevAddrRange{
		devAddr,
		devAddr + dmaMap.Size,
	}
	if !devar.WellFormed() {
		return 0, linuxerr.EINVAL
	}
	// Reserve a range in the address space.
	m, _, errno := unix.RawSyscall6(unix.SYS_MMAP, 0 /* addr */, uintptr(ar.Length()), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANONYMOUS, ^uintptr(0), 0)
	if errno != 0 {
		return 0, errno
	}
	// Until cu.Release is called below, every return path unmaps the
	// reserved range (and runs any cleanup added later).
	cu := cleanup.Make(func() {
		unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
	})
	defer cu.Clean()
	// Mirror application mappings into the reserved range.
	// NOTE(review): tmm above is the same t.MemoryManager(); it could be
	// reused here.
	prs, err := t.MemoryManager().Pin(ctx, ar, hostarch.ReadWrite, false)
	// The unpin is registered before err is checked, so whatever Pin
	// returned is unpinned on the error path as well.
	cu.Add(func() {
		mm.Unpin(prs)
	})
	if err != nil {
		return 0, err
	}
	// sentryAddr is the next free address inside the reserved range.
	sentryAddr := uintptr(m)
	for _, pr := range prs {
		ims, err := pr.File.MapInternal(memmap.FileRange{Start: pr.Offset, End: pr.Offset + uint64(pr.Source.Length())}, hostarch.ReadWrite)
		if err != nil {
			return 0, err
		}
		for !ims.IsEmpty() {
			im := ims.Head()
			// mremap is called with an old size of 0 and MREMAP_FIXED so that
			// the block is additionally mapped at sentryAddr (see mremap(2)
			// for the old_size == 0 behavior).
			if _, _, errno := unix.RawSyscall6(unix.SYS_MREMAP, im.Addr(), 0, uintptr(im.Len()), linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED, sentryAddr, 0); errno != 0 {
				return 0, errno
			}
			sentryAddr += uintptr(im.Len())
			ims = ims.Tail()
		}
	}
	// Replace Vaddr with the host's virtual address.
	dmaMap.Vaddr = uint64(m)
	n, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_IOMMU_MAP_DMA, &dmaMap)
	if err != nil {
		return n, err
	}
	// The mapping succeeded: cancel the deferred cleanup so the pinned
	// ranges stay pinned (they are stored in devAddrSet below), and unmap
	// the reserved range by hand instead.
	cu.Release()
	// Unmap the reserved range, which is no longer required.
	unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)

	fd.device.mu.Lock()
	defer fd.device.mu.Unlock()
	// Record each pinned range under the consecutive device addresses it
	// occupies, starting at the page-aligned IOVA. iommuUnmapDma uses these
	// entries to unpin the memory.
	for _, pr := range prs {
		rlen := uint64(pr.Source.Length())
		fd.device.devAddrSet.InsertRange(DevAddrRange{
			devAddr,
			devAddr + rlen,
		}, pr)
		devAddr += rlen
	}
	return n, nil
}
   213  
   214  func (fd *vfioFD) iommuUnmapDma(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   215  	var dmaUnmap linux.VFIOIommuType1DmaUnmap
   216  	if _, err := dmaUnmap.CopyIn(t, arg); err != nil {
   217  		return 0, err
   218  	}
   219  	if dmaUnmap.Flags&linux.VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP != 0 {
   220  		// VFIO_DMA_UNMAP_FALGS_GET_DIRTY_BITMAP is not used by libtpu for
   221  		// gVisor working with TPU.
   222  		return 0, linuxerr.ENOSYS
   223  	}
   224  	n, err := IOCTLInvokePtrArg[uint32](fd.hostFD, linux.VFIO_IOMMU_MAP_DMA, &dmaUnmap)
   225  	if err != nil {
   226  		return 0, nil
   227  	}
   228  	fd.device.mu.Lock()
   229  	defer fd.device.mu.Unlock()
   230  	s := &fd.device.devAddrSet
   231  	r := DevAddrRange{Start: dmaUnmap.IOVa, End: dmaUnmap.IOVa + dmaUnmap.Size}
   232  	seg := s.LowerBoundSegment(r.Start)
   233  	for seg.Ok() && seg.Start() < r.End {
   234  		seg = s.Isolate(seg, r)
   235  		mm.Unpin([]mm.PinnedRange{seg.Value()})
   236  		gap := s.Remove(seg)
   237  		seg = gap.NextSegment()
   238  	}
   239  	return n, nil
   240  }
   241  
// extension identifies a VFIO extension. It is the type of the argument
// given to the VFIO_CHECK_EXTENSION and VFIO_SET_IOMMU ioctls.
type extension int32
   244  
   245  // String implements fmt.Stringer for VFIO extension string representation.
   246  func (e extension) String() string {
   247  	switch e {
   248  	case linux.VFIO_TYPE1_IOMMU:
   249  		return "VFIO_TYPE1_IOMMU"
   250  	case linux.VFIO_SPAPR_TCE_IOMMU:
   251  		return "VFIO_SPAPR_TCE_IOMMU"
   252  	case linux.VFIO_TYPE1v2_IOMMU:
   253  		return "VFIO_TYPE1v2_IOMMU"
   254  	}
   255  	return ""
   256  }