github.com/mem/u-root@v2.0.1-0.20181004165302-9b18b4636a33+incompatible/pkg/boot/measurement_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package boot
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"crypto/sha1"
    12  	"fmt"
    13  	"testing"
    14  
    15  	"github.com/u-root/u-root/pkg/cpio"
    16  	"github.com/u-root/u-root/pkg/uio"
    17  )
    18  
    19  func TestSigningWriterWriteFile(t *testing.T) {
    20  	m := cpio.InMemArchive()
    21  	s := NewSigningWriter(m)
    22  	digest := &bytes.Buffer{}
    23  
    24  	for _, tt := range []struct {
    25  		r       cpio.Record
    26  		err     error
    27  		measure bool
    28  	}{
    29  		{
    30  			r:   cpio.Directory("foobar", 0777),
    31  			err: nil,
    32  		},
    33  		{
    34  			r:   cpio.Directory("signature", 0777),
    35  			err: fmt.Errorf("cannot write signature or signature_algo files"),
    36  		},
    37  		{
    38  			r:   cpio.Directory("signature_algo", 0777),
    39  			err: fmt.Errorf("cannot write signature or signature_algo files"),
    40  		},
    41  		{
    42  			r:   cpio.StaticFile("signature", "foobar", 0700),
    43  			err: fmt.Errorf("cannot write signature or signature_algo files"),
    44  		},
    45  		{
    46  			r:   cpio.StaticFile("signature_algo", "foobar", 0700),
    47  			err: fmt.Errorf("cannot write signature or signature_algo files"),
    48  		},
    49  		{
    50  			r:       cpio.StaticFile("modules/foo/kernel", "barfoo", 0700),
    51  			err:     nil,
    52  			measure: true,
    53  		},
    54  	} {
    55  		if err := s.WriteRecord(tt.r); err != tt.err && err.Error() != tt.err.Error() {
    56  			t.Errorf("WriteFile(%v) = %v, want %v", tt.r.Name, err, tt.err)
    57  		} else if err == nil {
    58  			if !m.Contains(tt.r) {
    59  				t.Errorf("Archive should contain %q but doesn't", tt.r.Name)
    60  			}
    61  			if tt.measure {
    62  				digest.WriteString(tt.r.Name)
    63  				digest.ReadFrom(uio.Reader(tt.r))
    64  			}
    65  		} else if err != nil && m.Contains(tt.r) {
    66  			t.Errorf("Archive contains file %q but shouldn't", tt.r.Name)
    67  		}
    68  	}
    69  
    70  	if len(digest.Bytes()) == 0 {
    71  		t.Errorf("digest should contain something")
    72  	}
    73  	if sha1.Sum(digest.Bytes()) != s.SHA1Sum() {
    74  		t.Errorf("sha1 differs")
    75  	}
    76  }
    77  
    78  func TestWriterAndReader(t *testing.T) {
    79  	m := cpio.InMemArchive()
    80  	s := NewSigningWriter(m)
    81  	digest := &bytes.Buffer{}
    82  
    83  	records := []cpio.Record{
    84  		cpio.Directory("modules", 0700),
    85  		cpio.Directory("modules/foo", 0700),
    86  		cpio.Directory("metadata", 0700),
    87  		cpio.StaticFile("modules/foo/kernel", "foobar", 0700),
    88  		cpio.StaticFile("metadata/hahaha", "arrgh", 0700),
    89  	}
    90  	if err := cpio.WriteRecords(s, records); err != nil {
    91  		t.Errorf("WriteRecords() = %v, want nil", err)
    92  	}
    93  	digest.WriteString("modules/foo/kernel")
    94  	digest.WriteString("foobar")
    95  	digest.WriteString("metadata/hahaha")
    96  	digest.WriteString("arrgh")
    97  
    98  	if len(digest.Bytes()) == 0 {
    99  		t.Errorf("digest should contain something")
   100  	}
   101  	if sha1.Sum(digest.Bytes()) != s.SHA1Sum() {
   102  		t.Errorf("sha1 differs")
   103  	}
   104  
   105  	privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
   106  	if err != nil {
   107  		t.Errorf("rsa GenerateKey() = %v", err)
   108  	}
   109  
   110  	if err := s.WriteSignature(privateKey); err != nil {
   111  		t.Errorf("WriteSignature() = %v, want nil", err)
   112  	}
   113  	if err := cpio.WriteTrailer(s); err != nil {
   114  		t.Errorf("WriteTrailer() = %v, want nil", err)
   115  	}
   116  
   117  	want := []cpio.Record{
   118  		cpio.Directory("modules", 0700),
   119  		cpio.Directory("modules/foo", 0700),
   120  		cpio.Directory("metadata", 0700),
   121  		cpio.StaticFile("modules/foo/kernel", "foobar", 0700),
   122  		cpio.StaticFile("metadata/hahaha", "arrgh", 0700),
   123  	}
   124  
   125  	r := NewMeasuringReader(m.Reader())
   126  	got, err := cpio.ReadAllRecords(r)
   127  	if err != nil {
   128  		t.Errorf("ReadAllRecords() = %v, want nil", err)
   129  	}
   130  	if !cpio.AllEqual(got, want) {
   131  		t.Errorf("ReadAllRecords() = \n%v, want \n%v", got, want)
   132  	}
   133  
   134  	if err := r.Verify(&privateKey.PublicKey); err != nil {
   135  		t.Errorf("Verify() = %v, want nil", err)
   136  	}
   137  }