gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/devices/accel/gasket.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package accel 16 17 import ( 18 "golang.org/x/sys/unix" 19 "gvisor.dev/gvisor/pkg/abi/gasket" 20 "gvisor.dev/gvisor/pkg/abi/linux" 21 "gvisor.dev/gvisor/pkg/abi/tpu" 22 "gvisor.dev/gvisor/pkg/cleanup" 23 "gvisor.dev/gvisor/pkg/context" 24 "gvisor.dev/gvisor/pkg/errors/linuxerr" 25 "gvisor.dev/gvisor/pkg/hostarch" 26 "gvisor.dev/gvisor/pkg/sentry/devices/tpuproxy" 27 "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" 28 "gvisor.dev/gvisor/pkg/sentry/kernel" 29 "gvisor.dev/gvisor/pkg/sentry/memmap" 30 "gvisor.dev/gvisor/pkg/sentry/mm" 31 ) 32 33 func gasketMapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) { 34 var userIoctlParams gasket.GasketPageTableIoctl 35 if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil { 36 return 0, err 37 } 38 39 numberOfPageTables := tpu.NumberOfTPUV4PageTables 40 if fd.device.lite { 41 numberOfPageTables = tpu.NumberOfTPUV4litePageTables 42 } 43 if userIoctlParams.PageTableIndex >= numberOfPageTables { 44 return 0, linuxerr.EFAULT 45 } 46 47 tmm := t.MemoryManager() 48 ar, ok := tmm.CheckIORange(hostarch.Addr(userIoctlParams.HostAddress), int64(userIoctlParams.Size)) 49 if !ok { 50 return 0, linuxerr.EFAULT 51 } 52 53 if !ar.IsPageAligned() || (userIoctlParams.Size/hostarch.PageSize) == 0 { 54 return 0, linuxerr.EINVAL 55 } 56 57 devAddr := userIoctlParams.DeviceAddress 58 // The kernel driver does not enforce page alignment on the device 59 // address although it will be implicitly rounded down to a page 60 // boundary. We do it explicitly because it simplifies tracking 61 // of allocated ranges in 'devAddrSet'. 62 devAddr &^= (hostarch.PageSize - 1) 63 64 // Make sure that the device address range can be mapped. 65 devar := DevAddrRange{ 66 devAddr, 67 devAddr + userIoctlParams.Size, 68 } 69 if !devar.WellFormed() { 70 return 0, linuxerr.EINVAL 71 } 72 73 // Reserve a range in our address space. 74 m, _, errno := unix.RawSyscall6(unix.SYS_MMAP, 0 /* addr */, uintptr(ar.Length()), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANONYMOUS, ^uintptr(0) /* fd */, 0 /* offset */) 75 if errno != 0 { 76 return 0, errno 77 } 78 cu := cleanup.Make(func() { 79 unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0) 80 }) 81 defer cu.Clean() 82 // Mirror application mappings into the reserved range. 83 prs, err := t.MemoryManager().Pin(ctx, ar, hostarch.ReadWrite, false /* ignorePermissions */) 84 cu.Add(func() { 85 mm.Unpin(prs) 86 }) 87 if err != nil { 88 return 0, err 89 } 90 sentryAddr := uintptr(m) 91 for _, pr := range prs { 92 ims, err := pr.File.MapInternal(memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}, hostarch.ReadWrite) 93 if err != nil { 94 return 0, err 95 } 96 for !ims.IsEmpty() { 97 im := ims.Head() 98 if _, _, errno := unix.RawSyscall6(unix.SYS_MREMAP, im.Addr(), 0 /* old_size */, uintptr(im.Len()), linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED, sentryAddr, 0); errno != 0 { 99 return 0, errno 100 } 101 sentryAddr += uintptr(im.Len()) 102 ims = ims.Tail() 103 } 104 } 105 sentryIoctlParams := userIoctlParams 106 sentryIoctlParams.HostAddress = uint64(m) 107 n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_MAP_BUFFER, &sentryIoctlParams) 108 if err != nil { 109 return n, err 110 } 111 cu.Release() 112 // Unmap the reserved range, which is no longer required. 113 unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0) 114 115 fd.device.mu.Lock() 116 defer fd.device.mu.Unlock() 117 for _, pr := range prs { 118 rlen := uint64(pr.Source.Length()) 119 fd.device.devAddrSet.InsertRange(DevAddrRange{ 120 devAddr, 121 devAddr + rlen, 122 }, pinnedAccelMem{pinnedRange: pr, pageTableIndex: userIoctlParams.PageTableIndex}) 123 devAddr += rlen 124 } 125 return n, nil 126 } 127 128 func gasketUnmapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) { 129 var userIoctlParams gasket.GasketPageTableIoctl 130 if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil { 131 return 0, err 132 } 133 134 numberOfPageTables := tpu.NumberOfTPUV4PageTables 135 if fd.device.lite { 136 numberOfPageTables = tpu.NumberOfTPUV4litePageTables 137 } 138 if userIoctlParams.PageTableIndex >= numberOfPageTables { 139 return 0, linuxerr.EFAULT 140 } 141 142 devAddr := userIoctlParams.DeviceAddress 143 devAddr &^= (hostarch.PageSize - 1) 144 devar := DevAddrRange{ 145 devAddr, 146 devAddr + userIoctlParams.Size, 147 } 148 if !devar.WellFormed() { 149 return 0, linuxerr.EINVAL 150 } 151 152 sentryIoctlParams := userIoctlParams 153 sentryIoctlParams.HostAddress = 0 // clobber this value, it's unused. 154 n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_UNMAP_BUFFER, &sentryIoctlParams) 155 if err != nil { 156 return n, err 157 } 158 fd.device.mu.Lock() 159 defer fd.device.mu.Unlock() 160 s := &fd.device.devAddrSet 161 r := DevAddrRange{userIoctlParams.DeviceAddress, userIoctlParams.DeviceAddress + userIoctlParams.Size} 162 seg := s.LowerBoundSegment(r.Start) 163 for seg.Ok() && seg.Start() < r.End { 164 seg = s.Isolate(seg, r) 165 v := seg.Value() 166 mm.Unpin([]mm.PinnedRange{v.pinnedRange}) 167 gap := s.Remove(seg) 168 seg = gap.NextSegment() 169 } 170 return n, nil 171 } 172 173 func gasketInterruptMappingIoctl(ctx context.Context, t *kernel.Task, hostFd int32, paramsAddr hostarch.Addr, lite bool) (uintptr, error) { 174 var userIoctlParams gasket.GasketInterruptMapping 175 if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil { 176 return 0, err 177 } 178 179 sizeOfInterruptList := tpu.SizeOfTPUV4InterruptList 180 interruptMap := tpu.TPUV4InterruptsMap 181 if lite { 182 sizeOfInterruptList = tpu.SizeOfTPUV4liteInterruptList 183 interruptMap = tpu.TPUV4liteInterruptsMap 184 } 185 if userIoctlParams.Interrupt >= sizeOfInterruptList { 186 return 0, linuxerr.EINVAL 187 } 188 barRegMap, ok := interruptMap[userIoctlParams.BarIndex] 189 if !ok { 190 return 0, linuxerr.EINVAL 191 } 192 if _, ok := barRegMap[userIoctlParams.RegOffset]; !ok { 193 return 0, linuxerr.EINVAL 194 } 195 196 // Check that 'userEventFD.Eventfd' is an eventfd. 197 eventFileGeneric, _ := t.FDTable().Get(int32(userIoctlParams.EventFD)) 198 if eventFileGeneric == nil { 199 return 0, linuxerr.EBADF 200 } 201 defer eventFileGeneric.DecRef(ctx) 202 eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription) 203 if !ok { 204 return 0, linuxerr.EINVAL 205 } 206 207 eventfd, err := eventFile.HostFD() 208 if err != nil { 209 return 0, err 210 } 211 212 sentryIoctlParams := userIoctlParams 213 sentryIoctlParams.EventFD = uint64(eventfd) 214 n, err := tpuproxy.IOCTLInvokePtrArg[gasket.Ioctl](hostFd, gasket.GASKET_IOCTL_REGISTER_INTERRUPT, &sentryIoctlParams) 215 if err != nil { 216 return n, err 217 } 218 219 outIoctlParams := sentryIoctlParams 220 outIoctlParams.EventFD = userIoctlParams.EventFD 221 if _, err := outIoctlParams.CopyOut(t, paramsAddr); err != nil { 222 return n, err 223 } 224 return n, nil 225 }