github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/vfile/vfile_test.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vfile
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/sha256"
    10  	"fmt"
    11  	"io"
    12  	"math/rand"
    13  	"os"
    14  	"path/filepath"
    15  	"reflect"
    16  	"strings"
    17  	"syscall"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/ProtonMail/go-crypto/openpgp"
    22  	"github.com/ProtonMail/go-crypto/openpgp/errors"
    23  	"github.com/ProtonMail/go-crypto/openpgp/packet"
    24  )
    25  
    26  type signedFile struct {
    27  	signers []*openpgp.Entity
    28  	content string
    29  }
    30  
    31  func (s signedFile) write(path string) error {
    32  	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
    33  	if err != nil {
    34  		return err
    35  	}
    36  	defer f.Close()
    37  
    38  	if _, err := f.Write([]byte(s.content)); err != nil {
    39  		return err
    40  	}
    41  
    42  	sigf, err := os.OpenFile(fmt.Sprintf("%s.sig", path), os.O_RDWR|os.O_CREATE, 0o600)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	defer sigf.Close()
    47  	for _, signer := range s.signers {
    48  		if err := openpgp.DetachSign(sigf, signer, strings.NewReader(s.content), nil); err != nil {
    49  			return err
    50  		}
    51  	}
    52  	return nil
    53  }
    54  
    55  type normalFile struct {
    56  	content string
    57  }
    58  
    59  func (n normalFile) write(path string) error {
    60  	return os.WriteFile(path, []byte(n.content), 0o600)
    61  }
    62  
    63  func writeHashedFile(path, content string) ([]byte, error) {
    64  	c := []byte(content)
    65  	if err := os.WriteFile(path, c, 0o600); err != nil {
    66  		return nil, err
    67  	}
    68  	hash := sha256.Sum256(c)
    69  	return hash[:], nil
    70  }
    71  
    72  func TestOpenSignedFile(t *testing.T) {
    73  	keyFiles := []string{"key0", "key1"}
    74  
    75  	// EntityGenerate generates the entities in testdata/. The entities are
    76  	// cached because they take 40+ seconds to generate in arm64 QEMU.
    77  	t.Run("EntityGenerate", func(t *testing.T) {
    78  		t.Skip("uncomment this to generate the entities")
    79  
    80  		if err := os.MkdirAll("testdata", 0o777); err != nil {
    81  			t.Fatal(err)
    82  		}
    83  
    84  		for i, k := range keyFiles {
    85  			// You would think this Config would be sufficient to
    86  			// generate the same each time for the test, but it is
    87  			// not the case (and I don't know why).
    88  			var s int64
    89  			conf := &packet.Config{
    90  				Rand: rand.New(rand.NewSource(int64(i))),
    91  				Time: func() time.Time {
    92  					s++
    93  					return time.Unix(s, 0)
    94  				},
    95  			}
    96  			key, err := openpgp.NewEntity("goog", "goog", "goog@goog", conf)
    97  			if err != nil {
    98  				t.Fatal(err)
    99  			}
   100  
   101  			f, err := os.Create(filepath.Join("testdata", k))
   102  			if err != nil {
   103  				t.Fatal(err)
   104  			}
   105  			if err := key.SerializePrivate(f, conf); err != nil {
   106  				f.Close()
   107  				t.Fatal(err)
   108  			}
   109  			if err := f.Close(); err != nil {
   110  				t.Fatal(err)
   111  			}
   112  		}
   113  	})
   114  
   115  	// This depends on the keys generated by EntityGenerate.
   116  	var keys []*openpgp.Entity
   117  	for _, k := range keyFiles {
   118  		b, err := os.ReadFile(filepath.Join("testdata", k))
   119  		if err != nil {
   120  			t.Fatal(err)
   121  		}
   122  		key, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(b)))
   123  		if err != nil {
   124  			t.Fatal(err)
   125  		}
   126  		keys = append(keys, key)
   127  	}
   128  
   129  	ring := openpgp.EntityList{keys[0]}
   130  
   131  	dir := t.TempDir()
   132  
   133  	signed := signedFile{
   134  		signers: openpgp.EntityList{keys[0]},
   135  		content: "foo",
   136  	}
   137  	signedPath := filepath.Join(dir, "signed_by_key0")
   138  	if err := signed.write(signedPath); err != nil {
   139  		t.Fatal(err)
   140  	}
   141  
   142  	signed2 := signedFile{
   143  		signers: openpgp.EntityList{keys[1]},
   144  		content: "foo",
   145  	}
   146  	signed2Path := filepath.Join(dir, "signed_by_key1")
   147  	if err := signed2.write(signed2Path); err != nil {
   148  		t.Fatal(err)
   149  	}
   150  
   151  	signed12 := signedFile{
   152  		signers: openpgp.EntityList{keys[0], keys[1]},
   153  		content: "foo",
   154  	}
   155  	signed12Path := filepath.Join(dir, "signed_by_both.sig")
   156  	if err := signed12.write(signed12Path); err != nil {
   157  		t.Fatal(err)
   158  	}
   159  
   160  	normalPath := filepath.Join(dir, "unsigned")
   161  	if err := os.WriteFile(normalPath, []byte("foo"), 0o777); err != nil {
   162  		t.Fatal(err)
   163  	}
   164  
   165  	for _, tt := range []struct {
   166  		desc             string
   167  		path             string
   168  		keyring          openpgp.KeyRing
   169  		want             error
   170  		isSignatureValid bool
   171  	}{
   172  		{
   173  			desc:             "signed file",
   174  			keyring:          ring,
   175  			path:             signedPath,
   176  			want:             nil,
   177  			isSignatureValid: true,
   178  		},
   179  		{
   180  			desc:             "signed file w/ two signatures (key0 ring)",
   181  			keyring:          ring,
   182  			path:             signed12Path,
   183  			want:             nil,
   184  			isSignatureValid: true,
   185  		},
   186  		{
   187  			desc:             "signed file w/ two signatures (key1 ring)",
   188  			keyring:          openpgp.EntityList{keys[1]},
   189  			path:             signed12Path,
   190  			want:             nil,
   191  			isSignatureValid: true,
   192  		},
   193  		{
   194  			desc:    "nil keyring",
   195  			keyring: nil,
   196  			path:    signed2Path,
   197  			want: ErrUnsigned{
   198  				Path: signed2Path,
   199  				Err:  ErrNoKeyRing,
   200  			},
   201  			isSignatureValid: false,
   202  		},
   203  		{
   204  			desc:    "non-nil empty keyring",
   205  			keyring: openpgp.EntityList{},
   206  			path:    signed2Path,
   207  			want: ErrUnsigned{
   208  				Path: signed2Path,
   209  				Err:  errors.ErrUnknownIssuer,
   210  			},
   211  			isSignatureValid: false,
   212  		},
   213  		{
   214  			desc:    "signed file does not match keyring",
   215  			keyring: openpgp.EntityList{keys[1]},
   216  			path:    signedPath,
   217  			want: ErrUnsigned{
   218  				Path: signedPath,
   219  				Err:  errors.ErrUnknownIssuer,
   220  			},
   221  			isSignatureValid: false,
   222  		},
   223  		{
   224  			desc:    "unsigned file",
   225  			keyring: ring,
   226  			path:    normalPath,
   227  			want: ErrUnsigned{
   228  				Path: normalPath,
   229  				Err: &os.PathError{
   230  					Op:   "open",
   231  					Path: fmt.Sprintf("%s.sig", normalPath),
   232  					Err:  syscall.ENOENT,
   233  				},
   234  			},
   235  			isSignatureValid: false,
   236  		},
   237  		{
   238  			desc:    "file does not exist",
   239  			keyring: ring,
   240  			path:    filepath.Join(dir, "foo"),
   241  			want: &os.PathError{
   242  				Op:   "open",
   243  				Path: filepath.Join(dir, "foo"),
   244  				Err:  syscall.ENOENT,
   245  			},
   246  			isSignatureValid: false,
   247  		},
   248  	} {
   249  		t.Run(tt.desc, func(t *testing.T) {
   250  			f, gotErr := OpenSignedSigFile(tt.keyring, tt.path)
   251  			if !reflect.DeepEqual(gotErr, tt.want) {
   252  				t.Errorf("openSignedFile(%v, %q) = %v, want %v", tt.keyring, tt.path, gotErr, tt.want)
   253  			}
   254  
   255  			if isSignatureValid := (gotErr == nil); isSignatureValid != tt.isSignatureValid {
   256  				t.Errorf("isSignatureValid(%v) = %v, want %v", gotErr, isSignatureValid, tt.isSignatureValid)
   257  			}
   258  
   259  			// Make sure that the file is readable from position 0.
   260  			if f != nil {
   261  				content, err := io.ReadAll(f)
   262  				if err != nil {
   263  					t.Errorf("Could not read content: %v", err)
   264  				}
   265  				if got := string(content); got != "foo" {
   266  					t.Errorf("ReadAll = %v, want \"foo\"", got)
   267  				}
   268  			}
   269  		})
   270  	}
   271  }
   272  
   273  func TestReadSignedImage(t *testing.T) {
   274  	for _, tt := range []struct {
   275  		desc       string
   276  		path       string
   277  		wantKeyCnt int
   278  		wantError  bool
   279  	}{
   280  		{
   281  			desc:       "Correct read key0",
   282  			path:       "testdata/key0",
   283  			wantError:  false,
   284  			wantKeyCnt: 2,
   285  		},
   286  		{
   287  			desc:       "Correct read key1",
   288  			path:       "testdata/key1",
   289  			wantError:  false,
   290  			wantKeyCnt: 2,
   291  		},
   292  		{
   293  			desc:       "Read nonRSA key",
   294  			path:       "testdata/dsakey",
   295  			wantError:  true,
   296  			wantKeyCnt: 0,
   297  		},
   298  		{
   299  			desc:       "Multikey ring",
   300  			path:       "testdata/keyring0+1+dsa",
   301  			wantError:  false,
   302  			wantKeyCnt: 4,
   303  		},
   304  	} {
   305  		t.Run(tt.desc, func(t *testing.T) {
   306  			ring, err := GetKeyRing(tt.path)
   307  			if err != nil {
   308  				t.Fatalf("GetKeyRing(%s) Failed with err: %v", tt.path, err)
   309  			}
   310  			gotKeys, gotErr := GetRSAKeysFromRing(ring)
   311  			if (gotErr == nil) == tt.wantError {
   312  				t.Errorf("GetRSAKeysFromRing(%s) = %v, want %v", tt.path, gotErr, tt.wantError)
   313  			}
   314  			var gotCnt int
   315  			if gotKeys == nil {
   316  				gotCnt = 0
   317  			} else {
   318  				gotCnt = len(gotKeys)
   319  			}
   320  
   321  			if tt.wantKeyCnt != gotCnt {
   322  				t.Errorf("GetRSAKeysFromRing(%s) returned %d keys, want %d", tt.path, gotCnt, tt.wantKeyCnt)
   323  			}
   324  		})
   325  	}
   326  }
   327  
   328  func TestOpenHashedFile(t *testing.T) {
   329  	dir := t.TempDir()
   330  
   331  	hashedPath := filepath.Join(dir, "hashed")
   332  	hash, err := writeHashedFile(hashedPath, "foo")
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  
   337  	emptyPath := filepath.Join(dir, "empty")
   338  	emptyHash, err := writeHashedFile(emptyPath, "")
   339  	if err != nil {
   340  		t.Fatal(err)
   341  	}
   342  
   343  	for _, tt := range []struct {
   344  		desc        string
   345  		path        string
   346  		hash        []byte
   347  		want        error
   348  		isHashValid bool
   349  		wantContent string
   350  	}{
   351  		{
   352  			desc:        "correct hash",
   353  			path:        hashedPath,
   354  			hash:        hash,
   355  			want:        nil,
   356  			isHashValid: true,
   357  			wantContent: "foo",
   358  		},
   359  		{
   360  			desc: "wrong hash",
   361  			path: hashedPath,
   362  			hash: []byte{0x99, 0x77},
   363  			want: ErrInvalidHash{
   364  				Path: hashedPath,
   365  				Err: ErrHashMismatch{
   366  					Got:  hash,
   367  					Want: []byte{0x99, 0x77},
   368  				},
   369  			},
   370  			isHashValid: false,
   371  			wantContent: "foo",
   372  		},
   373  		{
   374  			desc: "no hash",
   375  			path: hashedPath,
   376  			hash: []byte{},
   377  			want: ErrInvalidHash{
   378  				Path: hashedPath,
   379  				Err:  ErrNoExpectedHash,
   380  			},
   381  			isHashValid: false,
   382  			wantContent: "foo",
   383  		},
   384  		{
   385  			desc:        "empty file",
   386  			path:        emptyPath,
   387  			hash:        emptyHash,
   388  			want:        nil,
   389  			isHashValid: true,
   390  			wantContent: "",
   391  		},
   392  		{
   393  			desc: "nonexistent file",
   394  			path: filepath.Join(dir, "doesnotexist"),
   395  			hash: nil,
   396  			want: &os.PathError{
   397  				Op:   "open",
   398  				Path: filepath.Join(dir, "doesnotexist"),
   399  				Err:  syscall.ENOENT,
   400  			},
   401  			isHashValid: false,
   402  		},
   403  	} {
   404  		t.Run(tt.desc, func(t *testing.T) {
   405  			f, err := OpenHashedFile256(tt.path, tt.hash)
   406  			if !reflect.DeepEqual(err, tt.want) {
   407  				t.Errorf("OpenHashedFile256(%s, %x) = %v, want %v", tt.path, tt.hash, err, tt.want)
   408  			}
   409  
   410  			if isHashValid := (err == nil); isHashValid != tt.isHashValid {
   411  				t.Errorf("isHashValid(%v) = %v, want %v", err, isHashValid, tt.isHashValid)
   412  			}
   413  
   414  			// Make sure that the file is readable from position 0.
   415  			if f != nil {
   416  				content, err := io.ReadAll(f)
   417  				if err != nil {
   418  					t.Errorf("Could not read content: %v", err)
   419  				}
   420  				if got := string(content); got != tt.wantContent {
   421  					t.Errorf("ReadAll = %v, want %s", got, tt.wantContent)
   422  				}
   423  			}
   424  		})
   425  	}
   426  }