github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/runsc/cmd/util/tpu.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"fmt"
    19  	"os"
    20  	"path/filepath"
    21  	"regexp"
    22  	"strconv"
    23  	"strings"
    24  )
    25  
    26  const googleVendorID = 0x1AE0
    27  
    28  var tpuV4DeviceIDs = map[uint64]any{0x005E: nil, 0x0056: nil}
    29  
    30  // TODO(b/288456802): Add support for /dev/vfio controlled accelerators.
    31  // This is required for v5+ TPUs.
    32  
    33  // EnumerateHostTPUDevices returns the accelerator device minor numbers of all
    34  // TPUs on the machine.
    35  func EnumerateHostTPUDevices() ([]uint32, error) {
    36  	paths, err := filepath.Glob("/dev/accel*")
    37  	if err != nil {
    38  		return nil, fmt.Errorf("enumerating TPU device files: %w", err)
    39  	}
    40  
    41  	accelDeviceRegex := regexp.MustCompile(`^/dev/accel(\d+)$`)
    42  	var devMinors []uint32
    43  	for _, path := range paths {
    44  		if ms := accelDeviceRegex.FindStringSubmatch(path); ms != nil {
    45  			index, err := strconv.ParseUint(ms[1], 10, 32)
    46  			if err != nil {
    47  				return nil, fmt.Errorf("invalid host device file %q: %w", path, err)
    48  			}
    49  
    50  			vendor, err := readHexInt(fmt.Sprintf("/sys/class/accel/accel%d/device/vendor", index))
    51  			if err != nil {
    52  				return nil, err
    53  			}
    54  			if vendor != googleVendorID {
    55  				continue
    56  			}
    57  			deviceID, err := readHexInt(fmt.Sprintf("/sys/class/accel/accel%d/device/device", index))
    58  			if err != nil {
    59  				return nil, err
    60  			}
    61  			if _, ok := tpuV4DeviceIDs[deviceID]; !ok {
    62  				continue
    63  			}
    64  
    65  			devMinors = append(devMinors, uint32(index))
    66  		}
    67  	}
    68  	return devMinors, nil
    69  }
    70  
    71  func readHexInt(path string) (uint64, error) {
    72  	data, err := os.ReadFile(path)
    73  	if err != nil {
    74  		return 0, err
    75  	}
    76  	numStr := strings.Trim(strings.TrimSpace(strings.TrimPrefix(string(data), "0x")), "\x00")
    77  	return strconv.ParseUint(numStr, 16, 64)
    78  }