github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/runsc/cmd/util/tpu.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package util 16 17 import ( 18 "fmt" 19 "os" 20 "path/filepath" 21 "regexp" 22 "strconv" 23 "strings" 24 ) 25 26 const googleVendorID = 0x1AE0 27 28 var tpuV4DeviceIDs = map[uint64]any{0x005E: nil, 0x0056: nil} 29 30 // TODO(b/288456802): Add support for /dev/vfio controlled accelerators. 31 // This is required for v5+ TPUs. 32 33 // EnumerateHostTPUDevices returns the accelerator device minor numbers of all 34 // TPUs on the machine. 35 func EnumerateHostTPUDevices() ([]uint32, error) { 36 paths, err := filepath.Glob("/dev/accel*") 37 if err != nil { 38 return nil, fmt.Errorf("enumerating TPU device files: %w", err) 39 } 40 41 accelDeviceRegex := regexp.MustCompile(`^/dev/accel(\d+)$`) 42 var devMinors []uint32 43 for _, path := range paths { 44 if ms := accelDeviceRegex.FindStringSubmatch(path); ms != nil { 45 index, err := strconv.ParseUint(ms[1], 10, 32) 46 if err != nil { 47 return nil, fmt.Errorf("invalid host device file %q: %w", path, err) 48 } 49 50 vendor, err := readHexInt(fmt.Sprintf("/sys/class/accel/accel%d/device/vendor", index)) 51 if err != nil { 52 return nil, err 53 } 54 if vendor != googleVendorID { 55 continue 56 } 57 deviceID, err := readHexInt(fmt.Sprintf("/sys/class/accel/accel%d/device/device", index)) 58 if err != nil { 59 return nil, err 60 } 61 if _, ok := tpuV4DeviceIDs[deviceID]; !ok { 62 continue 63 } 64 65 devMinors = append(devMinors, uint32(index)) 66 } 67 } 68 return devMinors, nil 69 } 70 71 func readHexInt(path string) (uint64, error) { 72 data, err := os.ReadFile(path) 73 if err != nil { 74 return 0, err 75 } 76 numStr := strings.Trim(strings.TrimSpace(strings.TrimPrefix(string(data), "0x")), "\x00") 77 return strconv.ParseUint(numStr, 16, 64) 78 }