// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package accel implements proxying for hardware accelerators.
package accel

import (
	"fmt"

	"golang.org/x/sys/unix"
	"github.com/metacubex/gvisor/pkg/abi/gasket"
	"github.com/metacubex/gvisor/pkg/abi/linux"
	"github.com/metacubex/gvisor/pkg/context"
	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
	"github.com/metacubex/gvisor/pkg/fdnotifier"
	"github.com/metacubex/gvisor/pkg/hostarch"
	"github.com/metacubex/gvisor/pkg/log"
	"github.com/metacubex/gvisor/pkg/sentry/arch"
	"github.com/metacubex/gvisor/pkg/sentry/kernel"
	"github.com/metacubex/gvisor/pkg/sentry/mm"
	"github.com/metacubex/gvisor/pkg/sentry/vfs"
	"github.com/metacubex/gvisor/pkg/usermem"
	"github.com/metacubex/gvisor/pkg/waiter"
)

// tpuV4FD implements vfs.FileDescriptionImpl for /dev/accel[0-9]+.
//
// accelFD is not savable; we do not implement save/restore of accelerator
// state.
type tpuV4FD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.NoLockFD

	// hostFD is the host file descriptor backing this device file; ioctls,
	// readiness polling, and close are all proxied through it.
	hostFD int32
	// device is the shared per-device state (lock, owner thread group,
	// open-FD count, and the set of mapped device address ranges).
	device *tpuV4Device
	// queue is the waiter queue used for event registration on this FD.
	queue waiter.Queue
	// memmapFile provides the memmap.File implementation for this FD.
	memmapFile accelFDMemmapFile
}

// Release implements vfs.FileDescriptionImpl.Release.
func (fd *tpuV4FD) Release(context.Context) {
	fd.device.mu.Lock()
	defer fd.device.mu.Unlock()
	fd.device.openWriteFDs--
	if fd.device.openWriteFDs == 0 {
		log.Infof("openWriteFDs is zero, unpinning all sentry memory mappings")
		// This was the last writable FD on the device: for every mapped
		// device address range, ask the device to unmap it, unpin the
		// backing sentry memory, and remove it from the set.
		s := &fd.device.devAddrSet
		seg := s.FirstSegment()
		for seg.Ok() {
			r, v := seg.Range(), seg.Value()
			gpti := gasket.GasketPageTableIoctl{
				PageTableIndex: v.pageTableIndex,
				DeviceAddress:  r.Start,
				Size:           r.End - r.Start,
				HostAddress:    0,
			}
			_, err := ioctlInvokePtrArg(fd.hostFD, gasket.GASKET_IOCTL_UNMAP_BUFFER, &gpti)
			if err != nil {
				// Best-effort: log and keep going so that the remaining
				// ranges are still unpinned and removed.
				log.Warningf("could not unmap range [%#x, %#x) (index %d) on device: %v", r.Start, r.End, v.pageTableIndex, err)
			}
			mm.Unpin([]mm.PinnedRange{v.pinnedRange})
			gap := s.Remove(seg)
			seg = gap.NextSegment()
		}
		// With no writers left, the device no longer has an owner.
		fd.device.owner = nil
	}
	fdnotifier.RemoveFD(fd.hostFD)
	unix.Close(int(fd.hostFD))
}

// EventRegister implements waiter.Waitable.EventRegister.
func (fd *tpuV4FD) EventRegister(e *waiter.Entry) error {
	fd.queue.EventRegister(e)
	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
		// Roll back the registration if the host FD cannot be watched.
		fd.queue.EventUnregister(e)
		return err
	}
	return nil
}

// EventUnregister implements waiter.Waitable.EventUnregister.
func (fd *tpuV4FD) EventUnregister(e *waiter.Entry) {
	fd.queue.EventUnregister(e)
	if err := fdnotifier.UpdateFD(fd.hostFD); err != nil {
		panic(fmt.Sprint("UpdateFD:", err))
	}
}

// Readiness implements waiter.Waitable.Readiness.
func (fd *tpuV4FD) Readiness(mask waiter.EventMask) waiter.EventMask {
	// Readiness is delegated directly to the host FD.
	return fdnotifier.NonBlockingPoll(fd.hostFD, mask)
}

// Epollable implements vfs.FileDescriptionImpl.Epollable.
func (fd *tpuV4FD) Epollable() bool {
	return true
}

// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
113 func (fd *tpuV4FD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { 114 cmd := args[1].Uint() 115 argPtr := args[2].Pointer() 116 argSize := linux.IOC_SIZE(cmd) 117 118 t := kernel.TaskFromContext(ctx) 119 if t == nil { 120 panic("Ioctl should be called from a task context") 121 } 122 if err := fd.checkPermission(t); err != nil { 123 return 0, err 124 } 125 126 log.Infof("Accel ioctl %s called on fd %d with arg %v of size %d.", gasket.Ioctl(cmd), fd.hostFD, argPtr, argSize) 127 switch gasket.Ioctl(cmd) { 128 // Not yet implemented gasket ioctls. 129 case gasket.GASKET_IOCTL_SET_EVENTFD, gasket.GASKET_IOCTL_CLEAR_EVENTFD, 130 gasket.GASKET_IOCTL_NUMBER_PAGE_TABLES, gasket.GASKET_IOCTL_PAGE_TABLE_SIZE, 131 gasket.GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE, gasket.GASKET_IOCTL_PARTITION_PAGE_TABLE, 132 gasket.GASKET_IOCTL_MAP_DMA_BUF: 133 return 0, linuxerr.ENOSYS 134 case gasket.GASKET_IOCTL_RESET: 135 return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_RESET, args[2].Uint64()) 136 case gasket.GASKET_IOCTL_MAP_BUFFER: 137 return gasketMapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr) 138 case gasket.GASKET_IOCTL_UNMAP_BUFFER: 139 return gasketUnmapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr) 140 case gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS: 141 return ioctlInvoke(fd.hostFD, gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS, 0) 142 case gasket.GASKET_IOCTL_REGISTER_INTERRUPT: 143 return gasketInterruptMappingIoctl(ctx, t, fd.hostFD, argPtr, fd.device.lite) 144 case gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT: 145 return ioctlInvoke[uint64](fd.hostFD, gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT, args[2].Uint64()) 146 default: 147 return 0, linuxerr.EINVAL 148 } 149 } 150 151 // checkPermission checks that the thread that owns this device is the only 152 // one that can issue commands to the TPU. Other threads with access to 153 // /dev/accel will not be able to issue commands to the device. 
154 func (fd *tpuV4FD) checkPermission(t *kernel.Task) error { 155 fd.device.mu.Lock() 156 defer fd.device.mu.Unlock() 157 owner := fd.device.owner 158 if t.ThreadGroup() != owner { 159 return linuxerr.EPERM 160 } 161 return nil 162 } 163 164 type pinnedAccelMem struct { 165 pinnedRange mm.PinnedRange 166 pageTableIndex uint64 167 } 168 169 // DevAddrSet tracks device address ranges that have been mapped. 170 type devAddrSetFuncs struct{} 171 172 func (devAddrSetFuncs) MinKey() uint64 { 173 return 0 174 } 175 176 func (devAddrSetFuncs) MaxKey() uint64 { 177 return ^uint64(0) 178 } 179 180 func (devAddrSetFuncs) ClearValue(val *pinnedAccelMem) { 181 *val = pinnedAccelMem{} 182 } 183 184 func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 pinnedAccelMem, r2 DevAddrRange, v2 pinnedAccelMem) (pinnedAccelMem, bool) { 185 // Do we have the same backing file? 186 if v1.pinnedRange.File != v2.pinnedRange.File { 187 return pinnedAccelMem{}, false 188 } 189 190 // Do we have contiguous offsets in the backing file? 191 if v1.pinnedRange.Offset+uint64(v1.pinnedRange.Source.Length()) != v2.pinnedRange.Offset { 192 return pinnedAccelMem{}, false 193 } 194 195 // Are the virtual addresses contiguous? 196 // 197 // This check isn't strictly needed because 'mm.PinnedRange.Source' 198 // is only used to track the size of the pinned region (this is 199 // because the virtual address range can be unmapped or remapped 200 // elsewhere). Regardless we require this for simplicity. 201 if v1.pinnedRange.Source.End != v2.pinnedRange.Source.Start { 202 return pinnedAccelMem{}, false 203 } 204 205 // Extend v1 to account for the adjacent PinnedRange. 
206 v1.pinnedRange.Source.End = v2.pinnedRange.Source.End 207 return v1, true 208 } 209 210 func (devAddrSetFuncs) Split(r DevAddrRange, val pinnedAccelMem, split uint64) (pinnedAccelMem, pinnedAccelMem) { 211 n := split - r.Start 212 213 left := val 214 left.pinnedRange.Source.End = left.pinnedRange.Source.Start + hostarch.Addr(n) 215 216 right := val 217 right.pinnedRange.Source.Start += hostarch.Addr(n) 218 right.pinnedRange.Offset += n 219 220 return left, right 221 }