github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/exp/nvme_unlock/nvme_unlock.go (about)

     1  // Copyright 2022 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// The nvme_unlock command unlocks an NVMe drive (or, with -lock, locks it)
// using a password derived from a Host Secret Seed (HSS), then rescans the
// drive so that the unlocked partitions are enumerated.
     8  package main
     9  
    10  import (
    11  	"flag"
    12  	"fmt"
    13  	"log"
    14  	"os"
    15  	"regexp"
    16  	"strings"
    17  	"syscall"
    18  	"unsafe"
    19  
    20  	"github.com/mvdan/u-root-coreutils/pkg/finddrive"
    21  	"github.com/mvdan/u-root-coreutils/pkg/hsskey"
    22  	"github.com/mvdan/u-root-coreutils/pkg/mount/block"
    23  )
    24  
    25  const (
    26  	opalLockUnlockIoctl = 1092120797
    27  )
    28  
// opalKey holds the key sent to the kernel inside the OPAL lock/unlock ioctl
// argument. It appears to mirror Linux's struct opal_key from
// <linux/sed-opal.h> (TODO confirm). Field order and sizes define the binary
// layout the ioctl consumes (1+1+6+256 = 264 bytes), so they must not be
// reordered or resized.
type opalKey struct {
	lr     byte      // presumably the locking-range index; run leaves it at 0
	keyLen byte      // length of the key in bytes; run sets this to 32
	align  [6]byte   // padding so that key starts at byte offset 8
	key    [256]byte // key buffer; run copies the password into key[0:32]
}
    35  
// opalSessionInfo describes the OPAL session in which the lock/unlock request
// is performed. It appears to mirror Linux's struct opal_session_info
// (TODO confirm). Its layout is 4+4+264 = 272 bytes and must match the
// kernel's exactly, so fields must not be reordered or resized.
type opalSessionInfo struct {
	sum     uint32  // run always sets this to 0
	who     uint32  // run always sets this to 0; presumably selects the authority
	opalKey opalKey // key used to authenticate the session
}
    41  
// opalLockUnlock is the argument passed by pointer to the opalLockUnlockIoctl
// ioctl. Its total size (272 + 4 + 4 = 280 bytes) must equal the size encoded
// in the ioctl request number, so fields must not be reordered or resized.
type opalLockUnlock struct {
	session opalSessionInfo // session and key information
	lState  uint32          // requested state: run uses 0x02 (OPAL_RW) or 0x04 (OPAL_LK)
	align   [4]byte         // trailing padding to round the struct up to 280 bytes
}
    47  
    48  var (
    49  	disk               = flag.String("disk", "", "The disk to be unlocked.  If left blank, will search for a boot disk.")
    50  	verbose            = flag.Bool("d", false, "print debug output")
    51  	verboseNoSanitize  = flag.Bool("dangerously-disable-sanitize", false, "Print sensitive information - this should only be used for testing!")
    52  	noRereadPartitions = flag.Bool("no-reread-partitions", false, "Only attempt to unlock the disk, don't re-read the partition table.")
    53  	lock               = flag.Bool("lock", false, "Lock instead of unlocking")
    54  	salt               = flag.String("salt", hsskey.DefaultPasswordSalt, "Salt for password generation")
    55  )
    56  
    57  func verboseLog(msg string) {
    58  	if *verbose {
    59  		log.Print(msg)
    60  	}
    61  }
    62  
    63  func getSysfsInfo(index string, field string) (string, error) {
    64  	path := fmt.Sprintf("/sys/class/nvme/nvme%s/%s", index, field)
    65  	data, err := os.ReadFile(path)
    66  	if err != nil {
    67  		return "", fmt.Errorf("error reading sysfs info at path %s: %v", path, err)
    68  	}
    69  	return strings.TrimSpace(string(data)), nil
    70  }
    71  
    72  func run(disk string, verbose bool, verboseNoSanitize bool, noRereadPartitions bool, lock bool) error {
    73  	if disk == "" {
    74  		disks, err := finddrive.FindSlotType(finddrive.M2MKeySlotType)
    75  		if err != nil {
    76  			return fmt.Errorf("error finding boot disk: %v", err)
    77  		}
    78  		if len(disks) == 0 {
    79  			return fmt.Errorf("no boot disk found")
    80  		}
    81  		disk = disks[0]
    82  		if len(disks) > 1 {
    83  			log.Printf("Multiple boot disk candidates found, using the first from the following list: %v", disks)
    84  		} else if verbose {
    85  			log.Printf("Found boot disk %s", disk)
    86  		}
    87  	}
    88  
    89  	commandName := "unlock"
    90  	if lock {
    91  		commandName = "lock"
    92  	}
    93  
    94  	diskIDRegexp := regexp.MustCompile(`/dev/nvme(\d+)n.*`)
    95  	diskIDMatches := diskIDRegexp.FindStringSubmatch(disk)
    96  	if diskIDMatches == nil {
    97  		return fmt.Errorf("unable to parse device path %s", disk)
    98  	}
    99  	diskID := diskIDMatches[1]
   100  
   101  	serial, err := getSysfsInfo(diskID, "serial")
   102  	if err != nil {
   103  		return err
   104  	}
   105  	model, err := getSysfsInfo(diskID, "model")
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	if verbose {
   111  		log.Printf("Serial %s", serial)
   112  		log.Printf("Model %s", model)
   113  	}
   114  
   115  	diskFd, err := os.Open(disk)
   116  	if err != nil {
   117  		return fmt.Errorf("error opening disk: %v", err)
   118  	}
   119  	defer diskFd.Close()
   120  
   121  	hssList, err := hsskey.GetAllHss(verbose, verboseNoSanitize)
   122  	if err != nil {
   123  		return fmt.Errorf("error getting HSS: %v", err)
   124  	}
   125  
   126  	if len(hssList) == 0 {
   127  		return fmt.Errorf("no HSS found - can't unlock disk")
   128  	}
   129  
   130  	if verbose {
   131  		log.Printf("Found %d Host Secret Seeds.", len(hssList))
   132  	}
   133  
   134  	succeeded := false
   135  	for i, hss := range hssList {
   136  		password, err := hsskey.GenPassword(hss, *salt, serial, model)
   137  		if err != nil {
   138  			log.Printf("Couldn't generate password with HSS %d: %v", i, err)
   139  			continue
   140  		}
   141  
   142  		var state uint32 = 0x02 // OPAL_RW
   143  		if lock {
   144  			state = 0x04 // OPAL_LK
   145  		}
   146  		arg := opalLockUnlock{
   147  			session: opalSessionInfo{
   148  				sum: 0,
   149  				who: 0,
   150  				opalKey: opalKey{
   151  					keyLen: 32,
   152  				},
   153  			},
   154  			lState: state,
   155  		}
   156  		copy(arg.session.opalKey.key[0:32], password)
   157  
   158  		r1, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, diskFd.Fd(),
   159  			uintptr(opalLockUnlockIoctl), uintptr(unsafe.Pointer(&arg)))
   160  		if errNo != 0 {
   161  			log.Printf("%s failed with errno: %v", commandName, errNo)
   162  		} else if r1 != 0 {
   163  			log.Printf("%s returned nonzero value %v, password may be incorrect", commandName, r1)
   164  		} else {
   165  			succeeded = true
   166  			break
   167  		}
   168  	}
   169  
   170  	if succeeded {
   171  		log.Printf("Successfully %sed disk %s.", commandName, disk)
   172  	} else {
   173  		log.Printf("Failed to %s disk %s with any HSS.", commandName, disk)
   174  		return fmt.Errorf("all HSS failed")
   175  	}
   176  
   177  	if noRereadPartitions {
   178  		return nil
   179  	}
   180  
   181  	// Update partitions on the on the disk.
   182  	if verbose {
   183  		log.Print("Reloading disk partitions...")
   184  	}
   185  	diskdev, err := block.Device(disk)
   186  	if err != nil {
   187  		return fmt.Errorf("Could not find %s: %v", disk, err)
   188  	}
   189  
   190  	if err := diskdev.ReadPartitionTable(); err != nil && !lock {
   191  		return fmt.Errorf("Could not re-read partition table: %v", err)
   192  	}
   193  	return nil
   194  }
   195  
   196  func main() {
   197  	flag.Parse()
   198  	if err := run(*disk, *verbose, *verboseNoSanitize, *noRereadPartitions, *lock); err != nil {
   199  		log.Fatalf("nvme_unlock: %v", err)
   200  	}
   201  }