github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/vfile/vfile_test.go (about) 1 // Copyright 2020 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vfile 6 7 import ( 8 "bytes" 9 "crypto/sha256" 10 "fmt" 11 "io" 12 "math/rand" 13 "os" 14 "path/filepath" 15 "reflect" 16 "strings" 17 "syscall" 18 "testing" 19 "time" 20 21 "github.com/ProtonMail/go-crypto/openpgp" 22 "github.com/ProtonMail/go-crypto/openpgp/errors" 23 "github.com/ProtonMail/go-crypto/openpgp/packet" 24 ) 25 26 type signedFile struct { 27 signers []*openpgp.Entity 28 content string 29 } 30 31 func (s signedFile) write(path string) error { 32 f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) 33 if err != nil { 34 return err 35 } 36 defer f.Close() 37 38 if _, err := f.Write([]byte(s.content)); err != nil { 39 return err 40 } 41 42 sigf, err := os.OpenFile(fmt.Sprintf("%s.sig", path), os.O_RDWR|os.O_CREATE, 0o600) 43 if err != nil { 44 return err 45 } 46 defer sigf.Close() 47 for _, signer := range s.signers { 48 if err := openpgp.DetachSign(sigf, signer, strings.NewReader(s.content), nil); err != nil { 49 return err 50 } 51 } 52 return nil 53 } 54 55 type normalFile struct { 56 content string 57 } 58 59 func (n normalFile) write(path string) error { 60 return os.WriteFile(path, []byte(n.content), 0o600) 61 } 62 63 func writeHashedFile(path, content string) ([]byte, error) { 64 c := []byte(content) 65 if err := os.WriteFile(path, c, 0o600); err != nil { 66 return nil, err 67 } 68 hash := sha256.Sum256(c) 69 return hash[:], nil 70 } 71 72 func TestOpenSignedFile(t *testing.T) { 73 keyFiles := []string{"key0", "key1"} 74 75 // EntityGenerate generates the entities in testdata/. The entities are 76 // cached because they take 40+ seconds to generate in arm64 QEMU. 77 t.Run("EntityGenerate", func(t *testing.T) { 78 t.Skip("uncomment this to generate the entities") 79 80 if err := os.MkdirAll("testdata", 0o777); err != nil { 81 t.Fatal(err) 82 } 83 84 for i, k := range keyFiles { 85 // You would think this Config would be sufficient to 86 // generate the same each time for the test, but it is 87 // not the case (and I don't know why). 88 var s int64 89 conf := &packet.Config{ 90 Rand: rand.New(rand.NewSource(int64(i))), 91 Time: func() time.Time { 92 s++ 93 return time.Unix(s, 0) 94 }, 95 } 96 key, err := openpgp.NewEntity("goog", "goog", "goog@goog", conf) 97 if err != nil { 98 t.Fatal(err) 99 } 100 101 f, err := os.Create(filepath.Join("testdata", k)) 102 if err != nil { 103 t.Fatal(err) 104 } 105 if err := key.SerializePrivate(f, conf); err != nil { 106 f.Close() 107 t.Fatal(err) 108 } 109 if err := f.Close(); err != nil { 110 t.Fatal(err) 111 } 112 } 113 }) 114 115 // This depends on the keys generated by EntityGenerate. 116 var keys []*openpgp.Entity 117 for _, k := range keyFiles { 118 b, err := os.ReadFile(filepath.Join("testdata", k)) 119 if err != nil { 120 t.Fatal(err) 121 } 122 key, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(b))) 123 if err != nil { 124 t.Fatal(err) 125 } 126 keys = append(keys, key) 127 } 128 129 ring := openpgp.EntityList{keys[0]} 130 131 dir := t.TempDir() 132 133 signed := signedFile{ 134 signers: openpgp.EntityList{keys[0]}, 135 content: "foo", 136 } 137 signedPath := filepath.Join(dir, "signed_by_key0") 138 if err := signed.write(signedPath); err != nil { 139 t.Fatal(err) 140 } 141 142 signed2 := signedFile{ 143 signers: openpgp.EntityList{keys[1]}, 144 content: "foo", 145 } 146 signed2Path := filepath.Join(dir, "signed_by_key1") 147 if err := signed2.write(signed2Path); err != nil { 148 t.Fatal(err) 149 } 150 151 signed12 := signedFile{ 152 signers: openpgp.EntityList{keys[0], keys[1]}, 153 content: "foo", 154 } 155 signed12Path := filepath.Join(dir, "signed_by_both.sig") 156 if err := signed12.write(signed12Path); err != nil { 157 t.Fatal(err) 158 } 159 160 normalPath := filepath.Join(dir, "unsigned") 161 if err := os.WriteFile(normalPath, []byte("foo"), 0o777); err != nil { 162 t.Fatal(err) 163 } 164 165 for _, tt := range []struct { 166 desc string 167 path string 168 keyring openpgp.KeyRing 169 want error 170 isSignatureValid bool 171 }{ 172 { 173 desc: "signed file", 174 keyring: ring, 175 path: signedPath, 176 want: nil, 177 isSignatureValid: true, 178 }, 179 { 180 desc: "signed file w/ two signatures (key0 ring)", 181 keyring: ring, 182 path: signed12Path, 183 want: nil, 184 isSignatureValid: true, 185 }, 186 { 187 desc: "signed file w/ two signatures (key1 ring)", 188 keyring: openpgp.EntityList{keys[1]}, 189 path: signed12Path, 190 want: nil, 191 isSignatureValid: true, 192 }, 193 { 194 desc: "nil keyring", 195 keyring: nil, 196 path: signed2Path, 197 want: ErrUnsigned{ 198 Path: signed2Path, 199 Err: ErrNoKeyRing, 200 }, 201 isSignatureValid: false, 202 }, 203 { 204 desc: "non-nil empty keyring", 205 keyring: openpgp.EntityList{}, 206 path: signed2Path, 207 want: ErrUnsigned{ 208 Path: signed2Path, 209 Err: errors.ErrUnknownIssuer, 210 }, 211 isSignatureValid: false, 212 }, 213 { 214 desc: "signed file does not match keyring", 215 keyring: openpgp.EntityList{keys[1]}, 216 path: signedPath, 217 want: ErrUnsigned{ 218 Path: signedPath, 219 Err: errors.ErrUnknownIssuer, 220 }, 221 isSignatureValid: false, 222 }, 223 { 224 desc: "unsigned file", 225 keyring: ring, 226 path: normalPath, 227 want: ErrUnsigned{ 228 Path: normalPath, 229 Err: &os.PathError{ 230 Op: "open", 231 Path: fmt.Sprintf("%s.sig", normalPath), 232 Err: syscall.ENOENT, 233 }, 234 }, 235 isSignatureValid: false, 236 }, 237 { 238 desc: "file does not exist", 239 keyring: ring, 240 path: filepath.Join(dir, "foo"), 241 want: &os.PathError{ 242 Op: "open", 243 Path: filepath.Join(dir, "foo"), 244 Err: syscall.ENOENT, 245 }, 246 isSignatureValid: false, 247 }, 248 } { 249 t.Run(tt.desc, func(t *testing.T) { 250 f, gotErr := OpenSignedSigFile(tt.keyring, tt.path) 251 if !reflect.DeepEqual(gotErr, tt.want) { 252 t.Errorf("openSignedFile(%v, %q) = %v, want %v", tt.keyring, tt.path, gotErr, tt.want) 253 } 254 255 if isSignatureValid := (gotErr == nil); isSignatureValid != tt.isSignatureValid { 256 t.Errorf("isSignatureValid(%v) = %v, want %v", gotErr, isSignatureValid, tt.isSignatureValid) 257 } 258 259 // Make sure that the file is readable from position 0. 260 if f != nil { 261 content, err := io.ReadAll(f) 262 if err != nil { 263 t.Errorf("Could not read content: %v", err) 264 } 265 if got := string(content); got != "foo" { 266 t.Errorf("ReadAll = %v, want \"foo\"", got) 267 } 268 } 269 }) 270 } 271 } 272 273 func TestReadSignedImage(t *testing.T) { 274 for _, tt := range []struct { 275 desc string 276 path string 277 wantKeyCnt int 278 wantError bool 279 }{ 280 { 281 desc: "Correct read key0", 282 path: "testdata/key0", 283 wantError: false, 284 wantKeyCnt: 2, 285 }, 286 { 287 desc: "Correct read key1", 288 path: "testdata/key1", 289 wantError: false, 290 wantKeyCnt: 2, 291 }, 292 { 293 desc: "Read nonRSA key", 294 path: "testdata/dsakey", 295 wantError: true, 296 wantKeyCnt: 0, 297 }, 298 { 299 desc: "Multikey ring", 300 path: "testdata/keyring0+1+dsa", 301 wantError: false, 302 wantKeyCnt: 4, 303 }, 304 } { 305 t.Run(tt.desc, func(t *testing.T) { 306 ring, err := GetKeyRing(tt.path) 307 if err != nil { 308 t.Fatalf("GetKeyRing(%s) Failed with err: %v", tt.path, err) 309 } 310 gotKeys, gotErr := GetRSAKeysFromRing(ring) 311 if (gotErr == nil) == tt.wantError { 312 t.Errorf("GetRSAKeysFromRing(%s) = %v, want %v", tt.path, gotErr, tt.wantError) 313 } 314 var gotCnt int 315 if gotKeys == nil { 316 gotCnt = 0 317 } else { 318 gotCnt = len(gotKeys) 319 } 320 321 if tt.wantKeyCnt != gotCnt { 322 t.Errorf("GetRSAKeysFromRing(%s) returned %d keys, want %d", tt.path, gotCnt, tt.wantKeyCnt) 323 } 324 }) 325 } 326 } 327 328 func TestOpenHashedFile(t *testing.T) { 329 dir := t.TempDir() 330 331 hashedPath := filepath.Join(dir, "hashed") 332 hash, err := writeHashedFile(hashedPath, "foo") 333 if err != nil { 334 t.Fatal(err) 335 } 336 337 emptyPath := filepath.Join(dir, "empty") 338 emptyHash, err := writeHashedFile(emptyPath, "") 339 if err != nil { 340 t.Fatal(err) 341 } 342 343 for _, tt := range []struct { 344 desc string 345 path string 346 hash []byte 347 want error 348 isHashValid bool 349 wantContent string 350 }{ 351 { 352 desc: "correct hash", 353 path: hashedPath, 354 hash: hash, 355 want: nil, 356 isHashValid: true, 357 wantContent: "foo", 358 }, 359 { 360 desc: "wrong hash", 361 path: hashedPath, 362 hash: []byte{0x99, 0x77}, 363 want: ErrInvalidHash{ 364 Path: hashedPath, 365 Err: ErrHashMismatch{ 366 Got: hash, 367 Want: []byte{0x99, 0x77}, 368 }, 369 }, 370 isHashValid: false, 371 wantContent: "foo", 372 }, 373 { 374 desc: "no hash", 375 path: hashedPath, 376 hash: []byte{}, 377 want: ErrInvalidHash{ 378 Path: hashedPath, 379 Err: ErrNoExpectedHash, 380 }, 381 isHashValid: false, 382 wantContent: "foo", 383 }, 384 { 385 desc: "empty file", 386 path: emptyPath, 387 hash: emptyHash, 388 want: nil, 389 isHashValid: true, 390 wantContent: "", 391 }, 392 { 393 desc: "nonexistent file", 394 path: filepath.Join(dir, "doesnotexist"), 395 hash: nil, 396 want: &os.PathError{ 397 Op: "open", 398 Path: filepath.Join(dir, "doesnotexist"), 399 Err: syscall.ENOENT, 400 }, 401 isHashValid: false, 402 }, 403 } { 404 t.Run(tt.desc, func(t *testing.T) { 405 f, err := OpenHashedFile256(tt.path, tt.hash) 406 if !reflect.DeepEqual(err, tt.want) { 407 t.Errorf("OpenHashedFile256(%s, %x) = %v, want %v", tt.path, tt.hash, err, tt.want) 408 } 409 410 if isHashValid := (err == nil); isHashValid != tt.isHashValid { 411 t.Errorf("isHashValid(%v) = %v, want %v", err, isHashValid, tt.isHashValid) 412 } 413 414 // Make sure that the file is readable from position 0. 415 if f != nil { 416 content, err := io.ReadAll(f) 417 if err != nil { 418 t.Errorf("Could not read content: %v", err) 419 } 420 if got := string(content); got != tt.wantContent { 421 t.Errorf("ReadAll = %v, want %s", got, tt.wantContent) 422 } 423 } 424 }) 425 } 426 }