github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/devices/accel/gasket.go

// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package accel

import (
	"golang.org/x/sys/unix"
	"github.com/metacubex/gvisor/pkg/abi/gasket"
	"github.com/metacubex/gvisor/pkg/abi/linux"
	"github.com/metacubex/gvisor/pkg/abi/tpu"
	"github.com/metacubex/gvisor/pkg/cleanup"
	"github.com/metacubex/gvisor/pkg/context"
	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
	"github.com/metacubex/gvisor/pkg/hostarch"
	"github.com/metacubex/gvisor/pkg/sentry/fsimpl/eventfd"
	"github.com/metacubex/gvisor/pkg/sentry/kernel"
	"github.com/metacubex/gvisor/pkg/sentry/memmap"
	"github.com/metacubex/gvisor/pkg/sentry/mm"
)

func gasketMapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
	var userIoctlParams gasket.GasketPageTableIoctl
	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
		return 0, err
	}

	numberOfPageTables := tpu.NumberOfTPUV4PageTables
	if fd.device.lite {
		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
	}
	if userIoctlParams.PageTableIndex >= numberOfPageTables {
		return 0, linuxerr.EFAULT
	}

	tmm := t.MemoryManager()
	ar, ok := tmm.CheckIORange(hostarch.Addr(userIoctlParams.HostAddress), int64(userIoctlParams.Size))
	if !ok {
		return 0, linuxerr.EFAULT
	}

	if !ar.IsPageAligned() || (userIoctlParams.Size/hostarch.PageSize) == 0 {
		return 0, linuxerr.EINVAL
	}

	devAddr := userIoctlParams.DeviceAddress
	// The kernel driver does not enforce page alignment on the device
	// address although it will be implicitly rounded down to a page
	// boundary. We do it explicitly because it simplifies tracking
	// of allocated ranges in 'devAddrSet'.
	devAddr &^= (hostarch.PageSize - 1)

	// Make sure that the device address range can be mapped.
	devar := DevAddrRange{
		devAddr,
		devAddr + userIoctlParams.Size,
	}
	if !devar.WellFormed() {
		return 0, linuxerr.EINVAL
	}

	// Reserve a range in our address space.
	m, _, errno := unix.RawSyscall6(unix.SYS_MMAP, 0 /* addr */, uintptr(ar.Length()), unix.PROT_NONE, unix.MAP_PRIVATE|unix.MAP_ANONYMOUS, ^uintptr(0) /* fd */, 0 /* offset */)
	if errno != 0 {
		return 0, errno
	}
	cu := cleanup.Make(func() {
		unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)
	})
	defer cu.Clean()
	// Mirror application mappings into the reserved range.
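	// Pin the application's pages so that their host memory stays in place
	// for as long as the device mapping exists, then use MREMAP_FIXED to
	// mirror each pinned extent into the reserved region so the gasket
	// driver sees a single contiguous host address range starting at 'm'.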
	prs, err := t.MemoryManager().Pin(ctx, ar, hostarch.ReadWrite, false /* ignorePermissions */)
	cu.Add(func() {
		mm.Unpin(prs)
	})
	if err != nil {
		return 0, err
	}
	sentryAddr := uintptr(m)
	for _, pr := range prs {
		ims, err := pr.File.MapInternal(memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}, hostarch.ReadWrite)
		if err != nil {
			return 0, err
		}
		for !ims.IsEmpty() {
			im := ims.Head()
			if _, _, errno := unix.RawSyscall6(unix.SYS_MREMAP, im.Addr(), 0 /* old_size */, uintptr(im.Len()), linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED, sentryAddr, 0); errno != 0 {
				return 0, errno
			}
			sentryAddr += uintptr(im.Len())
			ims = ims.Tail()
		}
	}
	sentryIoctlParams := userIoctlParams
	sentryIoctlParams.HostAddress = uint64(m)
	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_MAP_BUFFER, &sentryIoctlParams)
	if err != nil {
		return n, err
	}
	cu.Release()
	// Unmap the reserved range, which is no longer required.
	unix.RawSyscall(unix.SYS_MUNMAP, m, uintptr(ar.Length()), 0)

	fd.device.mu.Lock()
	defer fd.device.mu.Unlock()
	for _, pr := range prs {
		rlen := uint64(pr.Source.Length())
		fd.device.devAddrSet.InsertRange(DevAddrRange{
			devAddr,
			devAddr + rlen,
		}, pinnedAccelMem{pinnedRange: pr, pageTableIndex: userIoctlParams.PageTableIndex})
		devAddr += rlen
	}
	return n, nil
}

func gasketUnmapBufferIoctl(ctx context.Context, t *kernel.Task, hostFd int32, fd *tpuV4FD, paramsAddr hostarch.Addr) (uintptr, error) {
	var userIoctlParams gasket.GasketPageTableIoctl
	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
		return 0, err
	}

	numberOfPageTables := tpu.NumberOfTPUV4PageTables
	if fd.device.lite {
		numberOfPageTables = tpu.NumberOfTPUV4litePageTables
	}
	if userIoctlParams.PageTableIndex >= numberOfPageTables {
		return 0, linuxerr.EFAULT
	}

	devAddr := userIoctlParams.DeviceAddress
	devAddr &^= (hostarch.PageSize - 1)
	devar := DevAddrRange{
		devAddr,
		devAddr + userIoctlParams.Size,
	}
	if !devar.WellFormed() {
		return 0, linuxerr.EINVAL
	}

	sentryIoctlParams := userIoctlParams
	sentryIoctlParams.HostAddress = 0 // clobber this value, it's unused.
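	// Forward the unmap request to the host driver; on success, unpin every
	// pinned range overlapping the requested device address range and drop
	// it from devAddrSet.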
	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_UNMAP_BUFFER, &sentryIoctlParams)
	if err != nil {
		return n, err
	}
	fd.device.mu.Lock()
	defer fd.device.mu.Unlock()
	s := &fd.device.devAddrSet
	r := DevAddrRange{userIoctlParams.DeviceAddress, userIoctlParams.DeviceAddress + userIoctlParams.Size}
	seg := s.LowerBoundSegment(r.Start)
	for seg.Ok() && seg.Start() < r.End {
		seg = s.Isolate(seg, r)
		v := seg.Value()
		mm.Unpin([]mm.PinnedRange{v.pinnedRange})
		gap := s.Remove(seg)
		seg = gap.NextSegment()
	}
	return n, nil
}

func gasketInterruptMappingIoctl(ctx context.Context, t *kernel.Task, hostFd int32, paramsAddr hostarch.Addr, lite bool) (uintptr, error) {
	var userIoctlParams gasket.GasketInterruptMapping
	if _, err := userIoctlParams.CopyIn(t, paramsAddr); err != nil {
		return 0, err
	}

	sizeOfInterruptList := tpu.SizeOfTPUV4InterruptList
	interruptMap := tpu.TPUV4InterruptsMap
	if lite {
		sizeOfInterruptList = tpu.SizeOfTPUV4liteInterruptList
		interruptMap = tpu.TPUV4liteInterruptsMap
	}
	if userIoctlParams.Interrupt >= sizeOfInterruptList {
		return 0, linuxerr.EINVAL
	}
	barRegMap, ok := interruptMap[userIoctlParams.BarIndex]
	if !ok {
		return 0, linuxerr.EINVAL
	}
	if _, ok := barRegMap[userIoctlParams.RegOffset]; !ok {
		return 0, linuxerr.EINVAL
	}

	// Check that 'userIoctlParams.EventFD' refers to an eventfd.
	eventFileGeneric, _ := t.FDTable().Get(int32(userIoctlParams.EventFD))
	if eventFileGeneric == nil {
		return 0, linuxerr.EBADF
	}
	defer eventFileGeneric.DecRef(ctx)
	eventFile, ok := eventFileGeneric.Impl().(*eventfd.EventFileDescription)
	if !ok {
		return 0, linuxerr.EINVAL
	}

	eventfd, err := eventFile.HostFD()
	if err != nil {
		return 0, err
	}

	sentryIoctlParams := userIoctlParams
	sentryIoctlParams.EventFD = uint64(eventfd)
	n, err := ioctlInvokePtrArg(hostFd, gasket.GASKET_IOCTL_REGISTER_INTERRUPT, &sentryIoctlParams)
	if err != nil {
		return n, err
	}

	outIoctlParams := sentryIoctlParams
	outIoctlParams.EventFD = userIoctlParams.EventFD
	if _, err := outIoctlParams.CopyOut(t, paramsAddr); err != nil {
		return n, err
	}
	return n, nil
}
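
// The sketch below is illustrative only and not part of the original driver
// code: it isolates the device-address rounding shared by the map and unmap
// handlers above. The function name is hypothetical; it rounds DeviceAddress
// down to a page boundary, as the handlers do before tracking
// [devAddr, devAddr+Size) in devAddrSet.
func exampleDevAddrRange(deviceAddress, size uint64) DevAddrRange {
	// Clear the low bits to round down to a page boundary.
	devAddr := deviceAddress &^ (hostarch.PageSize - 1)
	return DevAddrRange{devAddr, devAddr + size}
}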