github.com/andrewsun2898/u-root@v6.0.1-0.20200616011413-4b2895c1b815+incompatible/pkg/uroot/initramfs/files_test.go (about) 1 // Copyright 2018 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package initramfs 6 7 import ( 8 "fmt" 9 "io" 10 "io/ioutil" 11 "os" 12 "path/filepath" 13 "reflect" 14 "strings" 15 "testing" 16 17 "github.com/u-root/u-root/pkg/cpio" 18 "github.com/u-root/u-root/pkg/uio" 19 ) 20 21 func TestFilesAddFileNoFollow(t *testing.T) { 22 regularFile, err := ioutil.TempFile("", "archive-files-add-file") 23 if err != nil { 24 t.Error(err) 25 } 26 defer os.RemoveAll(regularFile.Name()) 27 28 dir, err := ioutil.TempDir("", "archive-add-files") 29 if err != nil { 30 t.Error(err) 31 } 32 defer os.RemoveAll(dir) 33 34 dir2, err := ioutil.TempDir("", "archive-add-files") 35 if err != nil { 36 t.Error(err) 37 } 38 defer os.RemoveAll(dir2) 39 40 os.Create(filepath.Join(dir, "foo2")) 41 os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3")) 42 43 for i, tt := range []struct { 44 name string 45 af *Files 46 src string 47 dest string 48 result *Files 49 errContains string 50 }{ 51 { 52 name: "just add a file", 53 af: NewFiles(), 54 55 src: regularFile.Name(), 56 dest: "bar/foo", 57 58 result: &Files{ 59 Files: map[string]string{ 60 "bar/foo": regularFile.Name(), 61 }, 62 Records: map[string]cpio.Record{}, 63 }, 64 }, 65 { 66 name: "add symlinked file, NOT following", 67 af: NewFiles(), 68 src: filepath.Join(dir2, "foo3"), 69 dest: "bar/foo", 70 result: &Files{ 71 Files: map[string]string{ 72 "bar/foo": filepath.Join(dir2, "foo3"), 73 }, 74 Records: map[string]cpio.Record{}, 75 }, 76 }, 77 } { 78 t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) { 79 err := tt.af.AddFileNoFollow(tt.src, tt.dest) 80 if err != nil && !strings.Contains(err.Error(), tt.errContains) { 81 t.Errorf("Error is %v, does not contain %v", err, tt.errContains) 82 } 83 if err == nil && len(tt.errContains) > 0 { 84 t.Errorf("Got no error, want %v", tt.errContains) 85 } 86 87 if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) { 88 t.Errorf("got %v, want %v", tt.af, tt.result) 89 } 90 }) 91 } 92 } 93 94 func TestFilesAddFile(t *testing.T) { 95 regularFile, err := ioutil.TempFile("", "archive-files-add-file") 96 if err != nil { 97 t.Error(err) 98 } 99 defer os.RemoveAll(regularFile.Name()) 100 101 dir, err := ioutil.TempDir("", "archive-add-files") 102 if err != nil { 103 t.Error(err) 104 } 105 defer os.RemoveAll(dir) 106 107 dir2, err := ioutil.TempDir("", "archive-add-files") 108 if err != nil { 109 t.Error(err) 110 } 111 defer os.RemoveAll(dir2) 112 113 dir3, err := ioutil.TempDir("", "archive-add-files") 114 if err != nil { 115 t.Error(err) 116 } 117 defer os.RemoveAll(dir3) 118 119 os.Create(filepath.Join(dir, "foo")) 120 os.Create(filepath.Join(dir, "foo2")) 121 os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3")) 122 123 fooDir := filepath.Join(dir3, "fooDir") 124 os.Mkdir(fooDir, os.ModePerm) 125 symlinkToDir3 := filepath.Join(dir3, "fooSymDir/") 126 os.Symlink(fooDir, symlinkToDir3) 127 os.Create(filepath.Join(fooDir, "foo")) 128 os.Create(filepath.Join(fooDir, "bar")) 129 130 for i, tt := range []struct { 131 name string 132 af *Files 133 src string 134 dest string 135 result *Files 136 errContains string 137 }{ 138 { 139 name: "just add a file", 140 af: NewFiles(), 141 142 src: regularFile.Name(), 143 dest: "bar/foo", 144 145 result: &Files{ 146 Files: map[string]string{ 147 "bar/foo": regularFile.Name(), 148 }, 149 Records: map[string]cpio.Record{}, 150 }, 151 }, 152 { 153 name: "add symlinked file, following", 154 af: NewFiles(), 155 src: filepath.Join(dir2, "foo3"), 156 dest: "bar/foo", 157 result: &Files{ 158 Files: map[string]string{ 159 "bar/foo": filepath.Join(dir, "foo2"), 160 }, 161 Records: map[string]cpio.Record{}, 162 }, 163 }, 164 { 165 name: "add symlinked directory, following", 166 af: NewFiles(), 167 src: symlinkToDir3, 168 dest: "foo/", 169 result: &Files{ 170 Files: map[string]string{ 171 "foo": fooDir, 172 "foo/foo": filepath.Join(fooDir, "foo"), 173 "foo/bar": filepath.Join(fooDir, "bar"), 174 }, 175 Records: map[string]cpio.Record{}, 176 }, 177 }, 178 { 179 name: "add file that exists in Files", 180 af: &Files{ 181 Files: map[string]string{ 182 "bar/foo": "/some/other/place", 183 }, 184 }, 185 src: regularFile.Name(), 186 dest: "bar/foo", 187 result: &Files{ 188 Files: map[string]string{ 189 "bar/foo": "/some/other/place", 190 }, 191 }, 192 errContains: "already exists in archive", 193 }, 194 { 195 name: "add a file that exists in Records", 196 af: &Files{ 197 Records: map[string]cpio.Record{ 198 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 199 }, 200 }, 201 src: regularFile.Name(), 202 dest: "bar/foo", 203 result: &Files{ 204 Records: map[string]cpio.Record{ 205 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 206 }, 207 }, 208 errContains: "already exists in archive", 209 }, 210 { 211 name: "add a file that already exists in Files, but is the same one", 212 af: &Files{ 213 Files: map[string]string{ 214 "bar/foo": regularFile.Name(), 215 }, 216 }, 217 src: regularFile.Name(), 218 dest: "bar/foo", 219 result: &Files{ 220 Files: map[string]string{ 221 "bar/foo": regularFile.Name(), 222 }, 223 }, 224 }, 225 { 226 name: "absolute destination paths are made relative", 227 af: &Files{ 228 Files: map[string]string{}, 229 }, 230 src: dir, 231 dest: "/bar/foo", 232 result: &Files{ 233 Files: map[string]string{ 234 "bar/foo": dir, 235 "bar/foo/foo": filepath.Join(dir, "foo"), 236 "bar/foo/foo2": filepath.Join(dir, "foo2"), 237 }, 238 }, 239 }, 240 { 241 name: "add a directory", 242 af: &Files{ 243 Files: map[string]string{}, 244 }, 245 src: dir, 246 dest: "bar/foo", 247 result: &Files{ 248 Files: map[string]string{ 249 "bar/foo": dir, 250 "bar/foo/foo": filepath.Join(dir, "foo"), 251 "bar/foo/foo2": filepath.Join(dir, "foo2"), 252 }, 253 }, 254 }, 255 { 256 name: "add a different directory to the same destination, no overlapping children", 257 af: &Files{ 258 Files: map[string]string{ 259 "bar/foo": "/some/place/real", 260 "bar/foo/zed": "/some/place/real/zed", 261 }, 262 }, 263 src: dir, 264 dest: "bar/foo", 265 result: &Files{ 266 Files: map[string]string{ 267 "bar/foo": dir, 268 "bar/foo/foo": filepath.Join(dir, "foo"), 269 "bar/foo/foo2": filepath.Join(dir, "foo2"), 270 "bar/foo/zed": "/some/place/real/zed", 271 }, 272 }, 273 }, 274 { 275 name: "add a different directory to the same destination, overlapping children", 276 af: &Files{ 277 Files: map[string]string{ 278 "bar/foo": "/some/place/real", 279 "bar/foo/foo2": "/some/place/real/zed", 280 }, 281 }, 282 src: dir, 283 dest: "bar/foo", 284 errContains: "already exists in archive", 285 }, 286 } { 287 t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) { 288 err := tt.af.AddFile(tt.src, tt.dest) 289 if err != nil && !strings.Contains(err.Error(), tt.errContains) { 290 t.Errorf("Error is %v, does not contain %v", err, tt.errContains) 291 } 292 if err == nil && len(tt.errContains) > 0 { 293 t.Errorf("Got no error, want %v", tt.errContains) 294 } 295 296 if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) { 297 t.Errorf("got %v, want %v", tt.af, tt.result) 298 } 299 }) 300 } 301 } 302 303 func TestFilesAddRecord(t *testing.T) { 304 for i, tt := range []struct { 305 af *Files 306 record cpio.Record 307 308 result *Files 309 errContains string 310 }{ 311 { 312 af: NewFiles(), 313 record: cpio.Symlink("bar/foo", ""), 314 result: &Files{ 315 Files: map[string]string{}, 316 Records: map[string]cpio.Record{ 317 "bar/foo": cpio.Symlink("bar/foo", ""), 318 }, 319 }, 320 }, 321 { 322 af: &Files{ 323 Files: map[string]string{ 324 "bar/foo": "/some/other/place", 325 }, 326 }, 327 record: cpio.Symlink("bar/foo", ""), 328 result: &Files{ 329 Files: map[string]string{ 330 "bar/foo": "/some/other/place", 331 }, 332 }, 333 errContains: "already exists in archive", 334 }, 335 { 336 af: &Files{ 337 Records: map[string]cpio.Record{ 338 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 339 }, 340 }, 341 record: cpio.Symlink("bar/foo", ""), 342 result: &Files{ 343 Records: map[string]cpio.Record{ 344 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 345 }, 346 }, 347 errContains: "already exists in archive", 348 }, 349 { 350 af: &Files{ 351 Records: map[string]cpio.Record{ 352 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 353 }, 354 }, 355 record: cpio.Symlink("bar/foo", "/some/other/place"), 356 result: &Files{ 357 Records: map[string]cpio.Record{ 358 "bar/foo": cpio.Symlink("bar/foo", "/some/other/place"), 359 }, 360 }, 361 }, 362 { 363 record: cpio.Symlink("/bar/foo", ""), 364 errContains: "must not be absolute", 365 }, 366 } { 367 t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { 368 err := tt.af.AddRecord(tt.record) 369 if err != nil && !strings.Contains(err.Error(), tt.errContains) { 370 t.Errorf("Error is %v, does not contain %v", err, tt.errContains) 371 } 372 if err == nil && len(tt.errContains) > 0 { 373 t.Errorf("Got no error, want %v", tt.errContains) 374 } 375 376 if !reflect.DeepEqual(tt.af, tt.result) { 377 t.Errorf("got %v, want %v", tt.af, tt.result) 378 } 379 }) 380 } 381 } 382 383 func TestFilesfillInParent(t *testing.T) { 384 for i, tt := range []struct { 385 af *Files 386 result *Files 387 }{ 388 { 389 af: &Files{ 390 Records: map[string]cpio.Record{ 391 "foo/bar": cpio.Directory("foo/bar", 0777), 392 }, 393 }, 394 result: &Files{ 395 Records: map[string]cpio.Record{ 396 "foo/bar": cpio.Directory("foo/bar", 0777), 397 "foo": cpio.Directory("foo", 0755), 398 }, 399 }, 400 }, 401 { 402 af: &Files{ 403 Files: map[string]string{ 404 "baz/baz/baz": "/somewhere", 405 }, 406 Records: map[string]cpio.Record{ 407 "foo/bar": cpio.Directory("foo/bar", 0777), 408 }, 409 }, 410 result: &Files{ 411 Files: map[string]string{ 412 "baz/baz/baz": "/somewhere", 413 }, 414 Records: map[string]cpio.Record{ 415 "foo/bar": cpio.Directory("foo/bar", 0777), 416 "foo": cpio.Directory("foo", 0755), 417 "baz": cpio.Directory("baz", 0755), 418 "baz/baz": cpio.Directory("baz/baz", 0755), 419 }, 420 }, 421 }, 422 { 423 af: &Files{}, 424 result: &Files{}, 425 }, 426 } { 427 t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { 428 tt.af.fillInParents() 429 if !reflect.DeepEqual(tt.af, tt.result) { 430 t.Errorf("got %v, want %v", tt.af, tt.result) 431 } 432 }) 433 } 434 } 435 436 type MockArchiver struct { 437 Records Records 438 FinishCalled bool 439 BaseArchive []cpio.Record 440 } 441 442 func (ma *MockArchiver) WriteRecord(r cpio.Record) error { 443 if _, ok := ma.Records[r.Name]; ok { 444 return fmt.Errorf("file exists") 445 } 446 ma.Records[r.Name] = r 447 return nil 448 } 449 450 func (ma *MockArchiver) Finish() error { 451 ma.FinishCalled = true 452 return nil 453 } 454 455 func (ma *MockArchiver) ReadRecord() (cpio.Record, error) { 456 if len(ma.BaseArchive) > 0 { 457 next := ma.BaseArchive[0] 458 ma.BaseArchive = ma.BaseArchive[1:] 459 return next, nil 460 } 461 return cpio.Record{}, io.EOF 462 } 463 464 type Records map[string]cpio.Record 465 466 func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool { 467 for name, s1 := range r1 { 468 s2, ok := r2[name] 469 if !ok { 470 return false 471 } 472 if !recordEqual(s1, s2) { 473 return false 474 } 475 } 476 for name := range r2 { 477 if _, ok := r1[name]; !ok { 478 return false 479 } 480 } 481 return true 482 } 483 484 func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool { 485 if r1.Name != r2.Name || r1.Mode != r2.Mode { 486 return false 487 } 488 return uio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt) 489 } 490 491 func TestOptsWrite(t *testing.T) { 492 for i, tt := range []struct { 493 desc string 494 opts *Opts 495 ma *MockArchiver 496 want Records 497 err error 498 }{ 499 { 500 desc: "no conflicts, just records", 501 opts: &Opts{ 502 Files: &Files{ 503 Records: map[string]cpio.Record{ 504 "foo": cpio.Symlink("foo", "elsewhere"), 505 }, 506 }, 507 }, 508 ma: &MockArchiver{ 509 Records: make(Records), 510 BaseArchive: []cpio.Record{ 511 cpio.Directory("etc", 0777), 512 cpio.Directory("etc/nginx", 0777), 513 }, 514 }, 515 want: Records{ 516 "foo": cpio.Symlink("foo", "elsewhere"), 517 "etc": cpio.Directory("etc", 0777), 518 "etc/nginx": cpio.Directory("etc/nginx", 0777), 519 }, 520 }, 521 { 522 desc: "default already exists", 523 opts: &Opts{ 524 Files: &Files{ 525 Records: map[string]cpio.Record{ 526 "etc": cpio.Symlink("etc", "whatever"), 527 }, 528 }, 529 }, 530 ma: &MockArchiver{ 531 Records: make(Records), 532 BaseArchive: []cpio.Record{ 533 cpio.Directory("etc", 0777), 534 }, 535 }, 536 want: Records{ 537 "etc": cpio.Symlink("etc", "whatever"), 538 }, 539 }, 540 { 541 desc: "no conflicts, missing parent automatically created", 542 opts: &Opts{ 543 Files: &Files{ 544 Records: map[string]cpio.Record{ 545 "foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"), 546 }, 547 }, 548 }, 549 ma: &MockArchiver{ 550 Records: make(Records), 551 }, 552 want: Records{ 553 "foo": cpio.Directory("foo", 0755), 554 "foo/bar": cpio.Directory("foo/bar", 0755), 555 "foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"), 556 }, 557 }, 558 { 559 desc: "parent only automatically created if not already exists", 560 opts: &Opts{ 561 Files: &Files{ 562 Records: map[string]cpio.Record{ 563 "foo/bar": cpio.Directory("foo/bar", 0444), 564 "foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"), 565 }, 566 }, 567 }, 568 ma: &MockArchiver{ 569 Records: make(Records), 570 }, 571 want: Records{ 572 "foo": cpio.Directory("foo", 0755), 573 "foo/bar": cpio.Directory("foo/bar", 0444), 574 "foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"), 575 }, 576 }, 577 { 578 desc: "base archive", 579 opts: &Opts{ 580 Files: &Files{ 581 Records: map[string]cpio.Record{ 582 "foo/bar": cpio.Symlink("foo/bar", "elsewhere"), 583 "exists": cpio.Directory("exists", 0777), 584 }, 585 }, 586 }, 587 ma: &MockArchiver{ 588 Records: make(Records), 589 BaseArchive: []cpio.Record{ 590 cpio.Directory("etc", 0755), 591 cpio.Directory("foo", 0444), 592 cpio.Directory("exists", 0), 593 }, 594 }, 595 want: Records{ 596 "etc": cpio.Directory("etc", 0755), 597 "exists": cpio.Directory("exists", 0777), 598 "foo": cpio.Directory("foo", 0444), 599 "foo/bar": cpio.Symlink("foo/bar", "elsewhere"), 600 }, 601 }, 602 { 603 desc: "base archive with init, no user init", 604 opts: &Opts{ 605 Files: &Files{ 606 Records: map[string]cpio.Record{}, 607 }, 608 }, 609 ma: &MockArchiver{ 610 Records: make(Records), 611 BaseArchive: []cpio.Record{ 612 cpio.StaticFile("init", "boo", 0555), 613 }, 614 }, 615 want: Records{ 616 "init": cpio.StaticFile("init", "boo", 0555), 617 }, 618 }, 619 { 620 desc: "base archive with init and user init", 621 opts: &Opts{ 622 Files: &Files{ 623 Records: map[string]cpio.Record{ 624 "init": cpio.StaticFile("init", "bar", 0444), 625 }, 626 }, 627 }, 628 ma: &MockArchiver{ 629 Records: make(Records), 630 BaseArchive: []cpio.Record{ 631 cpio.StaticFile("init", "boo", 0555), 632 }, 633 }, 634 want: Records{ 635 "init": cpio.StaticFile("init", "bar", 0444), 636 "inito": cpio.StaticFile("inito", "boo", 0555), 637 }, 638 }, 639 { 640 desc: "base archive with init, use existing init", 641 opts: &Opts{ 642 Files: &Files{ 643 Records: map[string]cpio.Record{}, 644 }, 645 UseExistingInit: true, 646 }, 647 ma: &MockArchiver{ 648 Records: make(Records), 649 BaseArchive: []cpio.Record{ 650 cpio.StaticFile("init", "boo", 0555), 651 }, 652 }, 653 want: Records{ 654 "init": cpio.StaticFile("init", "boo", 0555), 655 }, 656 }, 657 { 658 desc: "base archive with init and user init, use existing init", 659 opts: &Opts{ 660 Files: &Files{ 661 Records: map[string]cpio.Record{ 662 "init": cpio.StaticFile("init", "huh", 0111), 663 }, 664 }, 665 UseExistingInit: true, 666 }, 667 ma: &MockArchiver{ 668 Records: make(Records), 669 BaseArchive: []cpio.Record{ 670 cpio.StaticFile("init", "boo", 0555), 671 }, 672 }, 673 want: Records{ 674 "init": cpio.StaticFile("init", "boo", 0555), 675 "inito": cpio.StaticFile("inito", "huh", 0111), 676 }, 677 }, 678 } { 679 t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) { 680 tt.opts.BaseArchive = tt.ma 681 tt.opts.OutputFile = tt.ma 682 683 if err := Write(tt.opts); err != tt.err { 684 t.Errorf("Write() = %v, want %v", err, tt.err) 685 } else if err == nil && !tt.ma.FinishCalled { 686 t.Errorf("Finish wasn't called on archive") 687 } 688 689 if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) { 690 t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want) 691 } 692 }) 693 } 694 }