github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/devices/tpuproxy/tpu.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tpuproxy implements proxying for TPU devices.
    16  package tpuproxy
    17  
    18  import (
    19  	"fmt"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/metacubex/gvisor/pkg/abi/linux"
    23  	"github.com/metacubex/gvisor/pkg/context"
    24  	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
    25  	"github.com/metacubex/gvisor/pkg/fdnotifier"
    26  	"github.com/metacubex/gvisor/pkg/hostarch"
    27  	"github.com/metacubex/gvisor/pkg/marshal/primitive"
    28  	"github.com/metacubex/gvisor/pkg/sentry/arch"
    29  	"github.com/metacubex/gvisor/pkg/sentry/kernel"
    30  	"github.com/metacubex/gvisor/pkg/sentry/vfs"
    31  	"github.com/metacubex/gvisor/pkg/usermem"
    32  	"github.com/metacubex/gvisor/pkg/waiter"
    33  )
    34  
// tpuFD implements vfs.FileDescriptionImpl for /dev/vfio/[0-9]+.
//
// tpuFD is not savable until TPU save/restore is needed.
type tpuFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host file descriptor backing this FD. It is registered
	// with fdnotifier for readiness, is the target of forwarded ioctls, and
	// is closed in Release.
	//
	// device is the owning tpuDevice, and memmapFile is the memory-mapping
	// helper for this FD. Neither is used in the methods visible here.
	// NOTE(review): confirm their roles against the rest of the package.
	//
	// queue holds the waiters registered through EventRegister and
	// EventUnregister. It is notified with EventHUp on Release.
	hostFD     int32
	device     *tpuDevice
	queue      waiter.Queue
	memmapFile tpuFDMemmapFile
}
    49  
    50  // Release implements vfs.FileDescriptionImpl.Release.
    51  func (fd *tpuFD) Release(context.Context) {
    52  	fdnotifier.RemoveFD(fd.hostFD)
    53  	fd.queue.Notify(waiter.EventHUp)
    54  	unix.Close(int(fd.hostFD))
    55  }
    56  
    57  // EventRegister implements waiter.Waitable.EventRegister.
    58  func (fd *tpuFD) EventRegister(e *waiter.Entry) error {
    59  	fd.queue.EventRegister(e)
    60  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    61  		fd.queue.EventUnregister(e)
    62  		return err
    63  	}
    64  	return nil
    65  }
    66  
    67  // EventUnregister implements waiter.Waitable.EventUnregister.
    68  func (fd *tpuFD) EventUnregister(e *waiter.Entry) {
    69  	fd.queue.EventUnregister(e)
    70  	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
    71  		panic(fmt.Sprint("UpdateFD:", err))
    72  	}
    73  }
    74  
    75  // Readiness implements waiter.Waitable.Readiness.
    76  func (fd *tpuFD) Readiness(mask waiter.EventMask) waiter.EventMask {
    77  	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
    78  }
    79  
    80  // Epollable implements vfs.FileDescriptionImpl.Epollable.
    81  func (fd *tpuFD) Epollable() bool {
    82  	return true
    83  }
    84  
    85  // Ioctl implements vfs.FileDescriptionImpl.Ioctl.
    86  func (fd *tpuFD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
    87  	cmd := args[1].Uint()
    88  
    89  	t := kernel.TaskFromContext(ctx)
    90  	if t == nil {
    91  		panic("Ioctl should be called from a task context")
    92  	}
    93  	switch cmd {
    94  	case linux.VFIO_GROUP_SET_CONTAINER:
    95  		return fd.setContainer(ctx, t, args[2].Pointer())
    96  	}
    97  	return 0, linuxerr.ENOSYS
    98  }
    99  
   100  func (fd *tpuFD) setContainer(ctx context.Context, t *kernel.Task, arg hostarch.Addr) (uintptr, error) {
   101  	var vfioContainerFd int32
   102  	if _, err := primitive.CopyInt32In(t, arg, &vfioContainerFd); err != nil {
   103  		return 0, err
   104  	}
   105  	vfioContainerFile, _ := t.FDTable().Get(vfioContainerFd)
   106  	if vfioContainerFile == nil {
   107  		return 0, linuxerr.EBADF
   108  	}
   109  	defer vfioContainerFile.DecRef(ctx)
   110  	vfioContainer, ok := vfioContainerFile.Impl().(*vfioFd)
   111  	if !ok {
   112  		return 0, linuxerr.EINVAL
   113  	}
   114  	return ioctlInvokePtrArg(fd.hostFD, linux.VFIO_GROUP_SET_CONTAINER, &vfioContainer.hostFd)
   115  }