github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/binary_transparency/firmware/devices/usbarmory/bootloader/ft.go (about)

     1  //go:build armory
     2  // +build armory
     3  
     4  package main
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"time"
    11  
    12  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/crypto"
    13  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/verify"
    14  	"github.com/usbarmory/tamago/soc/nxp/imx6ul"
    15  	"golang.org/x/mod/sumdb/note"
    16  )
    17  
    18  const (
    19  	bundlePath                      = "/bundle.json"
    20  	firmwareMeasurementDomainPrefix = "armory_mkii"
    21  )
    22  
// init initialises imx6ul.DCP at program start-up. measureFirmware obtains
// its hasher from imx6ul.DCP (via New256), so this must have run before any
// firmware measurement is attempted.
func init() {
	imx6ul.DCP.Init()
}
    26  
    27  // verifyIntegrity checks the validity of the device state
    28  // against the stored proof bundle.
    29  //
    30  // This method will fail if:
    31  //   - there is no proof bundle stored in the proof partition
    32  //   - the proof bundle is not self-consistent
    33  //   - the measurement hash of the installed firmware does not
    34  //     match the value expected by the firmware manifest
    35  //   - TODO(al): check signatures.
    36  func verifyIntegrity(proof, firmware *Partition) error {
    37  	rawBundle, err := proof.ReadAll(bundlePath)
    38  	if err != nil {
    39  		return fmt.Errorf("failed to read bundle: %w", err)
    40  	}
    41  
    42  	h, err := measureFirmware(firmware)
    43  	if err != nil {
    44  		return fmt.Errorf("failed to hash firmware partition: %w\n", err)
    45  	}
    46  	fmt.Printf("firmware partition hash: 0x%x\n", h)
    47  	logSigVerifier, err := note.NewVerifier(crypto.TestFTPersonalityPub)
    48  
    49  	if err := verify.BundleForBoot(rawBundle, h, logSigVerifier); err != nil {
    50  		return fmt.Errorf("failed to verify bundle: %w", err)
    51  	}
    52  	return nil
    53  }
    54  
    55  // measureFirmware returns the firmware measurement hash for the firmware
    56  // stored on the given partition.
    57  func measureFirmware(p *Partition) ([]byte, error) {
    58  	log.Printf("Reading partition at offset %d...\n", p.Offset)
    59  	numBytes, err := p.GetExt4FilesystemSize()
    60  	if err != nil {
    61  		return nil, fmt.Errorf("failed to get partition size: %w", err)
    62  	}
    63  	log.Printf("Partition size %d bytes\n", numBytes)
    64  
    65  	if _, err := p.Seek(0, io.SeekStart); err != nil {
    66  		return nil, fmt.Errorf("failed to seek: %w", err)
    67  	}
    68  
    69  	bs := uint64(1 << 21)
    70  	rc := make(chan []byte, 5)
    71  	hc := make(chan []byte)
    72  
    73  	start := time.Now()
    74  	go func() {
    75  		h, err := imx6ul.DCP.New256()
    76  		if err != nil {
    77  			panic(fmt.Sprintf("Failed to created hasher: %q", err))
    78  		}
    79  
    80  		if _, err := h.Write([]byte(firmwareMeasurementDomainPrefix)); err != nil {
    81  			panic(fmt.Sprintf("Failed to write measurement domain prefix: %q", err))
    82  		}
    83  
    84  		for b := range rc {
    85  			if _, err := h.Write(b); err != nil {
    86  				panic(fmt.Errorf("failed to hash: %w", err))
    87  			}
    88  		}
    89  		hash, err := h.Sum(nil)
    90  		if err != nil {
    91  			panic(fmt.Sprintf("Failed to get final sum: %q", err))
    92  		}
    93  		hc <- hash
    94  	}()
    95  
    96  	if _, err := p.Seek(0, io.SeekStart); err != nil {
    97  		panic(fmt.Sprintf("Failed to seek to start of partition: %q", err))
    98  	}
    99  	for numBytes > 0 {
   100  		n := numBytes
   101  		if n > bs {
   102  			n = bs
   103  		}
   104  		b := make([]byte, n)
   105  		bc, err := p.Read(b)
   106  		if err != nil {
   107  			return nil, fmt.Errorf("failed to read: %w", err)
   108  		}
   109  		rc <- b[:bc]
   110  
   111  		numBytes -= uint64(bc)
   112  	}
   113  	close(rc)
   114  
   115  	hash := <-hc
   116  
   117  	log.Printf("Finished reading, hashing in %s\n", time.Now().Sub(start))
   118  	return hash[:], nil
   119  }