github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/binary_transparency/firmware/cmd/flash_tool/impl/flash_tool.go (about)

     1  // Copyright 2020 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package impl is the implementation of a util to flash firmware update packages created by the publisher tool onto devices.
//
// Supported devices are a dummy device, which simply stores the firmware+metadata on local disk, and the USB armory.
    18  package impl
    19  
    20  import (
    21  	"bytes"
    22  	"context"
    23  	"crypto/sha512"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"net/url"
    28  	"os"
    29  
    30  	"github.com/golang/glog"
    31  	"github.com/google/trillian-examples/binary_transparency/firmware/api"
    32  	"github.com/google/trillian-examples/binary_transparency/firmware/cmd/flash_tool/devices"
    33  	"github.com/google/trillian-examples/binary_transparency/firmware/devices/dummy"
    34  	armory_flash "github.com/google/trillian-examples/binary_transparency/firmware/devices/usbarmory/flash"
    35  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/client"
    36  	"github.com/google/trillian-examples/binary_transparency/firmware/internal/verify"
    37  	"github.com/google/trillian/merkle/coniks"
    38  	"github.com/google/trillian/merkle/smt/node"
    39  	"golang.org/x/mod/sumdb/note"
    40  )
    41  
// FlashOpts encapsulates flash tool parameters.
type FlashOpts struct {
	// DeviceID selects the device implementation; getDevice accepts "armory" or "dummy".
	DeviceID       string
	// LogURL is the base URL of the firmware transparency log.
	LogURL         string
	// LogSigVerifier verifies signatures on log checkpoints (device, witness and update bundle).
	LogSigVerifier note.Verifier
	// MapURL, if non-empty, is the annotation map used to check the firmware is marked good.
	MapURL         string
	// WitnessURL, if non-empty, is a witness whose checkpoint the update must be consistent with.
	WitnessURL     string
	// UpdateFile is the path of the JSON-encoded update package to flash.
	UpdateFile     string
	// Force turns verification and device-open failures into logged warnings instead of errors.
	Force          bool
	// DeviceStorage is passed to the device constructor; its meaning depends on the device
	// implementation (e.g. a local directory for the dummy device — TODO confirm).
	DeviceStorage  string
}
    53  
    54  // Main flashes the device according to the options provided.
    55  func Main(ctx context.Context, opts FlashOpts) error {
    56  	logURL, err := url.Parse(opts.LogURL)
    57  	if err != nil {
    58  		return fmt.Errorf("log_url is invalid: %w", err)
    59  	}
    60  	c := &client.ReadonlyClient{LogURL: logURL}
    61  
    62  	up, err := readUpdateFile(opts.UpdateFile)
    63  	if err != nil {
    64  		return fmt.Errorf("failed to read update package file: %w", err)
    65  	}
    66  
    67  	dev, err := getDevice(opts)
    68  	if err != nil {
    69  		return fmt.Errorf("failed to get device: %w", err)
    70  	}
    71  
    72  	pb, fwMeta, err := verifyUpdate(c, opts.LogSigVerifier, up, dev)
    73  	if err != nil {
    74  		err := fmt.Errorf("failed to validate update: %w", err)
    75  		if !opts.Force {
    76  			return err
    77  		}
    78  		glog.Warning(err)
    79  	}
    80  
    81  	if len(opts.WitnessURL) > 0 {
    82  		err := verifyWitness(c, opts.LogSigVerifier, pb, opts.WitnessURL)
    83  		if err != nil {
    84  			if !opts.Force {
    85  				return err
    86  			}
    87  			glog.Warning(err)
    88  		}
    89  	}
    90  
    91  	if len(opts.MapURL) > 0 {
    92  		err := verifyAnnotations(ctx, c, opts.LogSigVerifier, pb, fwMeta, opts.MapURL)
    93  		if err != nil {
    94  			if !opts.Force {
    95  				return fmt.Errorf("verifyAnnotations: %w", err)
    96  			}
    97  			glog.Warning(err)
    98  		}
    99  	}
   100  	glog.Info("Update verified, about to apply to device...")
   101  
   102  	if err := dev.ApplyUpdate(up); err != nil {
   103  		return fmt.Errorf("failed to apply update to device: %w", err)
   104  	}
   105  
   106  	glog.Info("Update applied.")
   107  	return nil
   108  }
   109  
   110  func getDevice(opts FlashOpts) (devices.Device, error) {
   111  	var dev devices.Device
   112  	var err error
   113  	switch opts.DeviceID {
   114  	case "armory":
   115  		dev, err = armory_flash.New(opts.DeviceStorage)
   116  	case "dummy":
   117  		dev, err = dummy.New(opts.DeviceStorage)
   118  	default:
   119  		return dev, errors.New("device must be one of: 'dummy', 'armory'")
   120  	}
   121  	if err != nil {
   122  		switch t := err.(type) {
   123  		case devices.ErrNeedsInit:
   124  			err := fmt.Errorf("device needs to be force initialised: %w", err)
   125  			if !opts.Force {
   126  				return dev, err
   127  			}
   128  			glog.Warning(err)
   129  		default:
   130  			err := fmt.Errorf("failed to open device: %w", t)
   131  			if !opts.Force {
   132  				return dev, err
   133  			}
   134  			glog.Warning(err)
   135  		}
   136  	}
   137  	return dev, nil
   138  }
   139  
   140  func readUpdateFile(path string) (api.UpdatePackage, error) {
   141  	if len(path) == 0 {
   142  		return api.UpdatePackage{}, errors.New("must specify update_file")
   143  	}
   144  
   145  	f, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm)
   146  	if err != nil {
   147  		glog.Exitf("Failed to open update package file %q: %q", path, err)
   148  	}
   149  	defer func() {
   150  		if err := f.Close(); err != nil {
   151  			glog.Errorf("f.Close(): %v", err)
   152  		}
   153  	}()
   154  
   155  	var up api.UpdatePackage
   156  	if err := json.NewDecoder(f).Decode(&up); err != nil {
   157  		glog.Exitf("Failed to parse update package file: %q", err)
   158  	}
   159  	return up, nil
   160  }
   161  
   162  // getConsistencyFunc executes on a given client context and returns a
   163  // consistency function.
   164  func getConsistencyFunc(c *client.ReadonlyClient) func(from, to uint64) ([][]byte, error) {
   165  	cpFunc := func(from, to uint64) ([][]byte, error) {
   166  		var cp [][]byte
   167  		if from > 0 {
   168  			r, err := c.GetConsistencyProof(api.GetConsistencyRequest{From: from, To: to})
   169  			if err != nil {
   170  				return nil, fmt.Errorf("failed to fetch consistency proof: %w", err)
   171  			}
   172  			cp = r.Proof
   173  		}
   174  		return cp, nil
   175  	}
   176  	return cpFunc
   177  }
   178  
   179  // verifyUpdate checks that an update package is self-consistent and returns a verified proof bundle
   180  func verifyUpdate(c *client.ReadonlyClient, logSigVerifier note.Verifier, up api.UpdatePackage, dev devices.Device) (api.ProofBundle, api.FirmwareMetadata, error) {
   181  	var pb api.ProofBundle
   182  	var fwMeta api.FirmwareMetadata
   183  
   184  	// Get the consistency proof for the bundle
   185  	n, err := dev.DeviceCheckpoint()
   186  	if err != nil {
   187  		return pb, fwMeta, fmt.Errorf("failed to fetch the device checkpoint: %w", err)
   188  	}
   189  	dc, err := api.ParseCheckpoint([]byte(n), logSigVerifier)
   190  	if err != nil {
   191  		return pb, fwMeta, fmt.Errorf("failed to open the device checkpoint: %w", err)
   192  	}
   193  
   194  	cpFunc := getConsistencyFunc(c)
   195  	fwHash := sha512.Sum512(up.FirmwareImage)
   196  	pb, fwMeta, err = verify.BundleForUpdate(up.ProofBundle, fwHash[:], *dc, cpFunc, logSigVerifier)
   197  	if err != nil {
   198  		return pb, fwMeta, fmt.Errorf("failed to verify proof bundle: %w", err)
   199  	}
   200  	return pb, fwMeta, nil
   201  }
   202  
   203  func verifyWitness(c *client.ReadonlyClient, logSigVerifier note.Verifier, pb api.ProofBundle, witnessURL string) error {
   204  	wURL, err := url.Parse(witnessURL)
   205  	if err != nil {
   206  		return fmt.Errorf("witness_url is invalid: %w", err)
   207  	}
   208  	wc := client.WitnessClient{
   209  		URL:            wURL,
   210  		LogSigVerifier: logSigVerifier,
   211  	}
   212  
   213  	wcp, err := wc.GetWitnessCheckpoint()
   214  	if err != nil {
   215  		return fmt.Errorf("failed to fetch the witness checkpoint: %w", err)
   216  	}
   217  	if wcp.Size == 0 {
   218  		return fmt.Errorf("no witness checkpoint to verify")
   219  	}
   220  	if err := verify.BundleConsistency(pb, *wcp, getConsistencyFunc(c), logSigVerifier); err != nil {
   221  		return fmt.Errorf("failed to verify checkpoint consistency against witness: %w", err)
   222  	}
   223  	return nil
   224  }
   225  
   226  func verifyAnnotations(ctx context.Context, c *client.ReadonlyClient, logSigVerifier note.Verifier, pb api.ProofBundle, fwMeta api.FirmwareMetadata, mapURL string) error {
   227  	mc, err := client.NewMapClient(mapURL)
   228  	if err != nil {
   229  		return fmt.Errorf("failed to create map client: %w", err)
   230  	}
   231  	mcp, err := mc.MapCheckpoint()
   232  	if err != nil {
   233  		return fmt.Errorf("failed to get map checkpoint: %w", err)
   234  	}
   235  	// The map checkpoint should be stored in a log, and this client should follow
   236  	// the log, checking consistency proofs and maintaining a golden checkpoint.
   237  	// Without this, the client is at risk of being given a custom map root that
   238  	// nobody else in the world sees.
   239  	glog.V(1).Infof("Received map checkpoint: %s", mcp.LogCheckpoint)
   240  	var lcp api.LogCheckpoint
   241  	if err := json.Unmarshal(mcp.LogCheckpoint, &lcp); err != nil {
   242  		return fmt.Errorf("failed to unmarshal log checkpoint: %w", err)
   243  	}
   244  	// TODO(mhutchinson): check consistency with the largest checkpoint found thus far
   245  	// in order to detect a class of fork; it could be that the checkpoint in the update
   246  	// is consistent with the map and the witness, but the map and the witness aren't
   247  	// consistent with each other.
   248  	if err := verify.BundleConsistency(pb, lcp, getConsistencyFunc(c), logSigVerifier); err != nil {
   249  		return fmt.Errorf("failed to verify update with map checkpoint: %w", err)
   250  	}
   251  
   252  	// Get the aggregation and proof, and then check everything about it.
   253  	// This is a little paranoid as the inclusion proof is generated client-side.
   254  	// The pretense here is that there is a trust boundary between the code in this class,
   255  	// and everything else. In a production system, it is likely that the proof generation
   256  	// would live elsewhere (e.g. in the OTA packaging process), and the shorter proof
   257  	// bundle would be provided to the device.
   258  	preimage, ip, err := mc.Aggregation(ctx, mcp.Revision, pb.InclusionProof.LeafIndex)
   259  	if err != nil {
   260  		return fmt.Errorf("failed to get map value for %q: %w", pb.InclusionProof.LeafIndex, err)
   261  	}
   262  	// 1. Check: the proof is for the correct key.
   263  	kbs := sha512.Sum512_256([]byte(fmt.Sprintf("summary:%d", pb.InclusionProof.LeafIndex)))
   264  	if !bytes.Equal(ip.Key, kbs[:]) {
   265  		return fmt.Errorf("received inclusion proof for key %x but wanted %x", ip.Key, kbs[:])
   266  	}
   267  	// 2. Check: the proof is for the correct value.
   268  	leafID := node.NewID(string(kbs[:]), 256)
   269  	hasher := coniks.Default
   270  	expectedCommitment := hasher.HashLeaf(api.MapTreeID, leafID, preimage)
   271  	if !bytes.Equal(ip.Value, expectedCommitment) {
   272  		// This could happen if the JSON roundtripping was not stable. If we see that happen,
   273  		// then we'll need to pass out the raw bytes received from the server and parse into
   274  		// the struct at a higher level.
   275  		// It could also happen because the value returned is not actually committed to by the map.
   276  		return fmt.Errorf("received inclusion proof for value %x but wanted %x", ip.Value, expectedCommitment)
   277  	}
   278  	// 3. Check: the inclusion proof evaluates to the map root that we've obtained.
   279  	// The calculation starts from the leaf, and uses the siblings from the inclusion proof
   280  	// to generate the root, which is then compared with the map checkpoint.
   281  	calc := expectedCommitment
   282  	for pd := hasher.BitLen(); pd > 0; pd-- {
   283  		sib := ip.Proof[pd-1]
   284  		stem := leafID.Prefix(uint(pd))
   285  		if sib == nil {
   286  			sib = hasher.HashEmpty(api.MapTreeID, stem.Sibling())
   287  		}
   288  		left, right := calc, sib
   289  		if !isLeftChild(stem) {
   290  			left, right = right, left
   291  		}
   292  		calc = hasher.HashChildren(left, right)
   293  	}
   294  	if !bytes.Equal(calc, mcp.RootHash) {
   295  		return fmt.Errorf("inclusion proof calculated root %x but wanted %x", calc, mcp.RootHash)
   296  	}
   297  
   298  	var agg api.AggregatedFirmware
   299  	if err := json.Unmarshal(preimage, &agg); err != nil {
   300  		return fmt.Errorf("failed to decode aggregation: %w", err)
   301  	}
   302  	// Now we're certain that the aggregation is contained in the map, we can use the value.
   303  	if !agg.Good {
   304  		return errors.New("firmware is marked as bad")
   305  	}
   306  	return nil
   307  }
   308  
   309  // isLeftChild returns whether the given node is a left child.
   310  func isLeftChild(id node.ID) bool {
   311  	last, bits := id.LastByte()
   312  	return last&(1<<(8-bits)) == 0
   313  }