gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/accel/tpu_v4.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package accel implements proxying for hardware accelerators. 16 package accel 17 18 import ( 19 "fmt" 20 21 "golang.org/x/sys/unix" 22 "gvisor.dev/gvisor/pkg/abi/gasket" 23 "gvisor.dev/gvisor/pkg/abi/linux" 24 "gvisor.dev/gvisor/pkg/context" 25 "gvisor.dev/gvisor/pkg/errors/linuxerr" 26 "gvisor.dev/gvisor/pkg/fdnotifier" 27 "gvisor.dev/gvisor/pkg/hostarch" 28 "gvisor.dev/gvisor/pkg/log" 29 "gvisor.dev/gvisor/pkg/sentry/arch" 30 "gvisor.dev/gvisor/pkg/sentry/devices/tpuproxy" 31 "gvisor.dev/gvisor/pkg/sentry/kernel" 32 "gvisor.dev/gvisor/pkg/sentry/mm" 33 "gvisor.dev/gvisor/pkg/sentry/vfs" 34 "gvisor.dev/gvisor/pkg/usermem" 35 "gvisor.dev/gvisor/pkg/waiter" 36 ) 37 38 // tpuV4FD implements vfs.FileDescriptionImpl for /dev/accel[0-9]+. 39 // 40 // accelFD is not savable; we do not implement save/restore of accelerator 41 // state. 42 type tpuV4FD struct { 43 vfsfd vfs.FileDescription 44 vfs.FileDescriptionDefaultImpl 45 vfs.DentryMetadataFileDescriptionImpl 46 vfs.NoLockFD 47 48 hostFD int32 49 device *tpuV4Device 50 queue waiter.Queue 51 memmapFile accelFDMemmapFile 52 } 53 54 // Release implements vfs.FileDescriptionImpl.Release. 55 func (fd *tpuV4FD) Release(context.Context) { 56 fd.device.mu.Lock() 57 defer fd.device.mu.Unlock() 58 fd.device.openWriteFDs-- 59 if fd.device.openWriteFDs == 0 { 60 log.Infof("openWriteFDs is zero, unpinning all sentry memory mappings") 61 s := &fd.device.devAddrSet 62 seg := s.FirstSegment() 63 for seg.Ok() { 64 r, v := seg.Range(), seg.Value() 65 gpti := gasket.GasketPageTableIoctl{ 66 PageTableIndex: v.pageTableIndex, 67 DeviceAddress: r.Start, 68 Size: r.End - r.Start, 69 HostAddress: 0, 70 } 71 _, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](fd.hostFD, gasket.GASKET_IOCTL_UNMAP_BUFFER, &gpti) 72 if err != nil { 73 log.Warningf("could not unmap range [%#x, %#x) (index %d) on device: %v", r.Start, r.End, v.pageTableIndex, err) 74 } 75 mm.Unpin([]mm.PinnedRange{v.pinnedRange}) 76 gap := s.Remove(seg) 77 seg = gap.NextSegment() 78 } 79 fd.device.owner = nil 80 } 81 fdnotifier.RemoveFD(fd.hostFD) 82 unix.Close(int(fd.hostFD)) 83 } 84 85 // EventRegister implements waiter.Waitable.EventRegister. 86 func (fd *tpuV4FD) EventRegister(e *waiter.Entry) error { 87 fd.queue.EventRegister(e) 88 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 89 fd.queue.EventUnregister(e) 90 return err 91 } 92 return nil 93 } 94 95 // EventUnregister implements waiter.Waitable.EventUnregister. 96 func (fd *tpuV4FD) EventUnregister(e *waiter.Entry) { 97 fd.queue.EventUnregister(e) 98 if err := fdnotifier.UpdateFD(fd.hostFD); err != nil { 99 panic(fmt.Sprint("UpdateFD:", err)) 100 } 101 } 102 103 // Readiness implements waiter.Waitable.Readiness. 104 func (fd *tpuV4FD) Readiness(mask waiter.EventMask) waiter.EventMask { 105 return fdnotifier.NonBlockingPoll(fd.hostFD, mask) 106 } 107 108 // Epollable implements vfs.FileDescriptionImpl.Epollable. 109 func (fd *tpuV4FD) Epollable() bool { 110 return true 111 } 112 113 // Ioctl implements vfs.FileDescriptionImpl.Ioctl. 114 func (fd *tpuV4FD) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { 115 cmd := args[1].Uint() 116 argPtr := args[2].Pointer() 117 argSize := linux.IOC_SIZE(cmd) 118 119 t := kernel.TaskFromContext(ctx) 120 if t == nil { 121 panic("Ioctl should be called from a task context") 122 } 123 if err := fd.checkPermission(t); err != nil { 124 return 0, err 125 } 126 127 log.Infof("Accel ioctl %s called on fd %d with arg %v of size %d.", gasket.Ioctl(cmd), fd.hostFD, argPtr, argSize) 128 switch gasket.Ioctl(cmd) { 129 // Not yet implemented gasket ioctls. 130 case gasket.GASKET_IOCTL_SET_EVENTFD, gasket.GASKET_IOCTL_CLEAR_EVENTFD, 131 gasket.GASKET_IOCTL_NUMBER_PAGE_TABLES, gasket.GASKET_IOCTL_PAGE_TABLE_SIZE, 132 gasket.GASKET_IOCTL_SIMPLE_PAGE_TABLE_SIZE, gasket.GASKET_IOCTL_PARTITION_PAGE_TABLE, 133 gasket.GASKET_IOCTL_MAP_DMA_BUF: 134 return 0, linuxerr.ENOSYS 135 case gasket.GASKET_IOCTL_RESET: 136 return tpuproxy.IOCTLInvoke[gasket.Ioctl, uint64](fd.hostFD, gasket.GASKET_IOCTL_RESET, args[2].Uint64()) 137 case gasket.GASKET_IOCTL_MAP_BUFFER: 138 return gasketMapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr) 139 case gasket.GASKET_IOCTL_UNMAP_BUFFER: 140 return gasketUnmapBufferIoctl(ctx, t, fd.hostFD, fd, argPtr) 141 case gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS: 142 return tpuproxy.IOCTLInvoke[gasket.Ioctl](fd.hostFD, gasket.GASKET_IOCTL_CLEAR_INTERRUPT_COUNTS, 0) 143 case gasket.GASKET_IOCTL_REGISTER_INTERRUPT: 144 return gasketInterruptMappingIoctl(ctx, t, fd.hostFD, argPtr, fd.device.lite) 145 case gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT: 146 return tpuproxy.IOCTLInvoke[gasket.Ioctl, uint64](fd.hostFD, gasket.GASKET_IOCTL_UNREGISTER_INTERRUPT, args[2].Uint64()) 147 default: 148 return 0, linuxerr.EINVAL 149 } 150 } 151 152 // checkPermission checks that the thread that owns this device is the only 153 // one that can issue commands to the TPU. Other threads with access to 154 // /dev/accel will not be able to issue commands to the device. 155 func (fd *tpuV4FD) checkPermission(t *kernel.Task) error { 156 fd.device.mu.Lock() 157 defer fd.device.mu.Unlock() 158 owner := fd.device.owner 159 if t.ThreadGroup() != owner { 160 return linuxerr.EPERM 161 } 162 return nil 163 } 164 165 type pinnedAccelMem struct { 166 pinnedRange mm.PinnedRange 167 pageTableIndex uint64 168 } 169 170 // DevAddrSet tracks device address ranges that have been mapped. 171 type devAddrSetFuncs struct{} 172 173 func (devAddrSetFuncs) MinKey() uint64 { 174 return 0 175 } 176 177 func (devAddrSetFuncs) MaxKey() uint64 { 178 return ^uint64(0) 179 } 180 181 func (devAddrSetFuncs) ClearValue(val *pinnedAccelMem) { 182 *val = pinnedAccelMem{} 183 } 184 185 func (devAddrSetFuncs) Merge(r1 DevAddrRange, v1 pinnedAccelMem, r2 DevAddrRange, v2 pinnedAccelMem) (pinnedAccelMem, bool) { 186 // Do we have the same backing file? 187 if v1.pinnedRange.File != v2.pinnedRange.File { 188 return pinnedAccelMem{}, false 189 } 190 191 // Do we have contiguous offsets in the backing file? 192 if v1.pinnedRange.Offset+uint64(v1.pinnedRange.Source.Length()) != v2.pinnedRange.Offset { 193 return pinnedAccelMem{}, false 194 } 195 196 // Are the virtual addresses contiguous? 197 // 198 // This check isn't strictly needed because 'mm.PinnedRange.Source' 199 // is only used to track the size of the pinned region (this is 200 // because the virtual address range can be unmapped or remapped 201 // elsewhere). Regardless we require this for simplicity. 202 if v1.pinnedRange.Source.End != v2.pinnedRange.Source.Start { 203 return pinnedAccelMem{}, false 204 } 205 206 // Extend v1 to account for the adjacent PinnedRange. 207 v1.pinnedRange.Source.End = v2.pinnedRange.Source.End 208 return v1, true 209 } 210 211 func (devAddrSetFuncs) Split(r DevAddrRange, val pinnedAccelMem, split uint64) (pinnedAccelMem, pinnedAccelMem) { 212 n := split - r.Start 213 214 left := val 215 left.pinnedRange.Source.End = left.pinnedRange.Source.Start + hostarch.Addr(n) 216 217 right := val 218 right.pinnedRange.Source.Start += hostarch.Addr(n) 219 right.pinnedRange.Offset += n 220 221 return left, right 222 }