gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/runsc/cmd/util/tpu.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"fmt"
    19  	"os"
    20  	"regexp"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"gvisor.dev/gvisor/pkg/abi/tpu"
    25  )
    26  
    27  const (
    28  	googleVendorID       = 0x1AE0
    29  	accelDevicePathRegex = `^/dev/accel(\d+)$`
    30  	accelSysfsFormat     = "/sys/class/accel/accel%d/device/%s"
    31  	vfioDevicePathRegex  = `^/dev/vfio/(\d+)$`
    32  	vfioSysfsFormat      = "/sys/class/vfio-dev/vfio%d/device/%s"
    33  	vendorFile           = "vendor"
    34  	deviceFile           = "device"
    35  )
    36  
    37  var tpuV4DeviceIDs = map[uint64]any{tpu.TPUV4DeviceID: nil, tpu.TPUV4liteDeviceID: nil}
    38  var tpuV5DeviceIDs = map[uint64]any{tpu.TPUV5eDeviceID: nil}
    39  
    40  // ExtractTpuDeviceMinor returns the accelerator device minor number for that
    41  // the passed device path. If the passed device is not a valid TPU device, then
    42  // it returns false.
    43  func ExtractTpuDeviceMinor(path string) (uint32, bool, error) {
    44  	devNum, valid, err := tpuV4DeviceMinor(path)
    45  	if err != nil {
    46  		return 0, false, err
    47  	}
    48  	if valid {
    49  		return devNum, valid, err
    50  	}
    51  	return tpuV5DeviceMinor(path)
    52  }
    53  
    54  // tpuDeviceMinor returns the accelerator device minor number for that
    55  // the passed device path. If the passed device is not a valid TPU device, then
    56  // it returns false.
    57  func tpuDeviceMinor(devicePath, devicePathRegex, sysfsFormat string, allowedDeviceIDs map[uint64]any) (uint32, bool, error) {
    58  	deviceRegex := regexp.MustCompile(devicePathRegex)
    59  	matches := deviceRegex.FindStringSubmatch(devicePath)
    60  	if matches == nil {
    61  		return 0, false, nil
    62  	}
    63  	minor, err := strconv.ParseUint(matches[1], 10, 32)
    64  	if err != nil {
    65  		return 0, false, fmt.Errorf("invalid host device file %q: %w", devicePath, err)
    66  	}
    67  	vendor, err := readHexInt(fmt.Sprintf(sysfsFormat, minor, vendorFile))
    68  	if err != nil {
    69  		return 0, false, err
    70  	}
    71  	if vendor != googleVendorID {
    72  		return 0, false, nil
    73  	}
    74  	deviceID, err := readHexInt(fmt.Sprintf(sysfsFormat, minor, deviceFile))
    75  	if err != nil {
    76  		return 0, false, err
    77  	}
    78  	if _, ok := allowedDeviceIDs[deviceID]; !ok {
    79  		return 0, false, nil
    80  	}
    81  	return uint32(minor), true, nil
    82  }
    83  
    84  // tpuv4DeviceMinor returns v4 and v4lite TPU device minor number for the given path.
    85  // A valid v4 TPU device is defined as:
    86  // * Path is /dev/accel#.
    87  // * Vendor is googleVendorID.
    88  // * Device ID is one of tpuV4DeviceIDs.
    89  func tpuV4DeviceMinor(path string) (uint32, bool, error) {
    90  	return tpuDeviceMinor(path, accelDevicePathRegex, accelSysfsFormat, tpuV4DeviceIDs)
    91  }
    92  
    93  // tpuV5DeviceMinor returns the v5e TPU device minor number for te given path.
    94  // A valid v5 TPU device is defined as:
    95  // * Path is /dev/vfio/#.
    96  // * Vendor is googleVendorID.
    97  // * Device ID is one of tpuV5DeviceIDs.
    98  func tpuV5DeviceMinor(path string) (uint32, bool, error) {
    99  	return tpuDeviceMinor(path, vfioDevicePathRegex, vfioSysfsFormat, tpuV5DeviceIDs)
   100  }
   101  
   102  func readHexInt(path string) (uint64, error) {
   103  	data, err := os.ReadFile(path)
   104  	if err != nil {
   105  		return 0, err
   106  	}
   107  	numStr := strings.Trim(strings.TrimSpace(strings.TrimPrefix(string(data), "0x")), "\x00")
   108  	return strconv.ParseUint(numStr, 16, 64)
   109  }