github.com/andrewsun2898/u-root@v6.0.1-0.20200616011413-4b2895c1b815+incompatible/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/u-root/u-root/pkg/cpio"
    18  	"github.com/u-root/u-root/pkg/uio"
    19  )
    20  
    21  func TestFilesAddFileNoFollow(t *testing.T) {
    22  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    23  	if err != nil {
    24  		t.Error(err)
    25  	}
    26  	defer os.RemoveAll(regularFile.Name())
    27  
    28  	dir, err := ioutil.TempDir("", "archive-add-files")
    29  	if err != nil {
    30  		t.Error(err)
    31  	}
    32  	defer os.RemoveAll(dir)
    33  
    34  	dir2, err := ioutil.TempDir("", "archive-add-files")
    35  	if err != nil {
    36  		t.Error(err)
    37  	}
    38  	defer os.RemoveAll(dir2)
    39  
    40  	os.Create(filepath.Join(dir, "foo2"))
    41  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    42  
    43  	for i, tt := range []struct {
    44  		name        string
    45  		af          *Files
    46  		src         string
    47  		dest        string
    48  		result      *Files
    49  		errContains string
    50  	}{
    51  		{
    52  			name: "just add a file",
    53  			af:   NewFiles(),
    54  
    55  			src:  regularFile.Name(),
    56  			dest: "bar/foo",
    57  
    58  			result: &Files{
    59  				Files: map[string]string{
    60  					"bar/foo": regularFile.Name(),
    61  				},
    62  				Records: map[string]cpio.Record{},
    63  			},
    64  		},
    65  		{
    66  			name: "add symlinked file, NOT following",
    67  			af:   NewFiles(),
    68  			src:  filepath.Join(dir2, "foo3"),
    69  			dest: "bar/foo",
    70  			result: &Files{
    71  				Files: map[string]string{
    72  					"bar/foo": filepath.Join(dir2, "foo3"),
    73  				},
    74  				Records: map[string]cpio.Record{},
    75  			},
    76  		},
    77  	} {
    78  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
    79  			err := tt.af.AddFileNoFollow(tt.src, tt.dest)
    80  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
    81  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
    82  			}
    83  			if err == nil && len(tt.errContains) > 0 {
    84  				t.Errorf("Got no error, want %v", tt.errContains)
    85  			}
    86  
    87  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
    88  				t.Errorf("got %v, want %v", tt.af, tt.result)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func TestFilesAddFile(t *testing.T) {
    95  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    96  	if err != nil {
    97  		t.Error(err)
    98  	}
    99  	defer os.RemoveAll(regularFile.Name())
   100  
   101  	dir, err := ioutil.TempDir("", "archive-add-files")
   102  	if err != nil {
   103  		t.Error(err)
   104  	}
   105  	defer os.RemoveAll(dir)
   106  
   107  	dir2, err := ioutil.TempDir("", "archive-add-files")
   108  	if err != nil {
   109  		t.Error(err)
   110  	}
   111  	defer os.RemoveAll(dir2)
   112  
   113  	dir3, err := ioutil.TempDir("", "archive-add-files")
   114  	if err != nil {
   115  		t.Error(err)
   116  	}
   117  	defer os.RemoveAll(dir3)
   118  
   119  	os.Create(filepath.Join(dir, "foo"))
   120  	os.Create(filepath.Join(dir, "foo2"))
   121  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
   122  
   123  	fooDir := filepath.Join(dir3, "fooDir")
   124  	os.Mkdir(fooDir, os.ModePerm)
   125  	symlinkToDir3 := filepath.Join(dir3, "fooSymDir/")
   126  	os.Symlink(fooDir, symlinkToDir3)
   127  	os.Create(filepath.Join(fooDir, "foo"))
   128  	os.Create(filepath.Join(fooDir, "bar"))
   129  
   130  	for i, tt := range []struct {
   131  		name        string
   132  		af          *Files
   133  		src         string
   134  		dest        string
   135  		result      *Files
   136  		errContains string
   137  	}{
   138  		{
   139  			name: "just add a file",
   140  			af:   NewFiles(),
   141  
   142  			src:  regularFile.Name(),
   143  			dest: "bar/foo",
   144  
   145  			result: &Files{
   146  				Files: map[string]string{
   147  					"bar/foo": regularFile.Name(),
   148  				},
   149  				Records: map[string]cpio.Record{},
   150  			},
   151  		},
   152  		{
   153  			name: "add symlinked file, following",
   154  			af:   NewFiles(),
   155  			src:  filepath.Join(dir2, "foo3"),
   156  			dest: "bar/foo",
   157  			result: &Files{
   158  				Files: map[string]string{
   159  					"bar/foo": filepath.Join(dir, "foo2"),
   160  				},
   161  				Records: map[string]cpio.Record{},
   162  			},
   163  		},
   164  		{
   165  			name: "add symlinked directory, following",
   166  			af:   NewFiles(),
   167  			src:  symlinkToDir3,
   168  			dest: "foo/",
   169  			result: &Files{
   170  				Files: map[string]string{
   171  					"foo":     fooDir,
   172  					"foo/foo": filepath.Join(fooDir, "foo"),
   173  					"foo/bar": filepath.Join(fooDir, "bar"),
   174  				},
   175  				Records: map[string]cpio.Record{},
   176  			},
   177  		},
   178  		{
   179  			name: "add file that exists in Files",
   180  			af: &Files{
   181  				Files: map[string]string{
   182  					"bar/foo": "/some/other/place",
   183  				},
   184  			},
   185  			src:  regularFile.Name(),
   186  			dest: "bar/foo",
   187  			result: &Files{
   188  				Files: map[string]string{
   189  					"bar/foo": "/some/other/place",
   190  				},
   191  			},
   192  			errContains: "already exists in archive",
   193  		},
   194  		{
   195  			name: "add a file that exists in Records",
   196  			af: &Files{
   197  				Records: map[string]cpio.Record{
   198  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   199  				},
   200  			},
   201  			src:  regularFile.Name(),
   202  			dest: "bar/foo",
   203  			result: &Files{
   204  				Records: map[string]cpio.Record{
   205  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   206  				},
   207  			},
   208  			errContains: "already exists in archive",
   209  		},
   210  		{
   211  			name: "add a file that already exists in Files, but is the same one",
   212  			af: &Files{
   213  				Files: map[string]string{
   214  					"bar/foo": regularFile.Name(),
   215  				},
   216  			},
   217  			src:  regularFile.Name(),
   218  			dest: "bar/foo",
   219  			result: &Files{
   220  				Files: map[string]string{
   221  					"bar/foo": regularFile.Name(),
   222  				},
   223  			},
   224  		},
   225  		{
   226  			name: "absolute destination paths are made relative",
   227  			af: &Files{
   228  				Files: map[string]string{},
   229  			},
   230  			src:  dir,
   231  			dest: "/bar/foo",
   232  			result: &Files{
   233  				Files: map[string]string{
   234  					"bar/foo":      dir,
   235  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   236  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   237  				},
   238  			},
   239  		},
   240  		{
   241  			name: "add a directory",
   242  			af: &Files{
   243  				Files: map[string]string{},
   244  			},
   245  			src:  dir,
   246  			dest: "bar/foo",
   247  			result: &Files{
   248  				Files: map[string]string{
   249  					"bar/foo":      dir,
   250  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   251  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   252  				},
   253  			},
   254  		},
   255  		{
   256  			name: "add a different directory to the same destination, no overlapping children",
   257  			af: &Files{
   258  				Files: map[string]string{
   259  					"bar/foo":     "/some/place/real",
   260  					"bar/foo/zed": "/some/place/real/zed",
   261  				},
   262  			},
   263  			src:  dir,
   264  			dest: "bar/foo",
   265  			result: &Files{
   266  				Files: map[string]string{
   267  					"bar/foo":      dir,
   268  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   269  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   270  					"bar/foo/zed":  "/some/place/real/zed",
   271  				},
   272  			},
   273  		},
   274  		{
   275  			name: "add a different directory to the same destination, overlapping children",
   276  			af: &Files{
   277  				Files: map[string]string{
   278  					"bar/foo":      "/some/place/real",
   279  					"bar/foo/foo2": "/some/place/real/zed",
   280  				},
   281  			},
   282  			src:         dir,
   283  			dest:        "bar/foo",
   284  			errContains: "already exists in archive",
   285  		},
   286  	} {
   287  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   288  			err := tt.af.AddFile(tt.src, tt.dest)
   289  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   290  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   291  			}
   292  			if err == nil && len(tt.errContains) > 0 {
   293  				t.Errorf("Got no error, want %v", tt.errContains)
   294  			}
   295  
   296  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   297  				t.Errorf("got %v, want %v", tt.af, tt.result)
   298  			}
   299  		})
   300  	}
   301  }
   302  
   303  func TestFilesAddRecord(t *testing.T) {
   304  	for i, tt := range []struct {
   305  		af     *Files
   306  		record cpio.Record
   307  
   308  		result      *Files
   309  		errContains string
   310  	}{
   311  		{
   312  			af:     NewFiles(),
   313  			record: cpio.Symlink("bar/foo", ""),
   314  			result: &Files{
   315  				Files: map[string]string{},
   316  				Records: map[string]cpio.Record{
   317  					"bar/foo": cpio.Symlink("bar/foo", ""),
   318  				},
   319  			},
   320  		},
   321  		{
   322  			af: &Files{
   323  				Files: map[string]string{
   324  					"bar/foo": "/some/other/place",
   325  				},
   326  			},
   327  			record: cpio.Symlink("bar/foo", ""),
   328  			result: &Files{
   329  				Files: map[string]string{
   330  					"bar/foo": "/some/other/place",
   331  				},
   332  			},
   333  			errContains: "already exists in archive",
   334  		},
   335  		{
   336  			af: &Files{
   337  				Records: map[string]cpio.Record{
   338  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   339  				},
   340  			},
   341  			record: cpio.Symlink("bar/foo", ""),
   342  			result: &Files{
   343  				Records: map[string]cpio.Record{
   344  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   345  				},
   346  			},
   347  			errContains: "already exists in archive",
   348  		},
   349  		{
   350  			af: &Files{
   351  				Records: map[string]cpio.Record{
   352  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   353  				},
   354  			},
   355  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   356  			result: &Files{
   357  				Records: map[string]cpio.Record{
   358  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   359  				},
   360  			},
   361  		},
   362  		{
   363  			record:      cpio.Symlink("/bar/foo", ""),
   364  			errContains: "must not be absolute",
   365  		},
   366  	} {
   367  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   368  			err := tt.af.AddRecord(tt.record)
   369  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   370  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   371  			}
   372  			if err == nil && len(tt.errContains) > 0 {
   373  				t.Errorf("Got no error, want %v", tt.errContains)
   374  			}
   375  
   376  			if !reflect.DeepEqual(tt.af, tt.result) {
   377  				t.Errorf("got %v, want %v", tt.af, tt.result)
   378  			}
   379  		})
   380  	}
   381  }
   382  
   383  func TestFilesfillInParent(t *testing.T) {
   384  	for i, tt := range []struct {
   385  		af     *Files
   386  		result *Files
   387  	}{
   388  		{
   389  			af: &Files{
   390  				Records: map[string]cpio.Record{
   391  					"foo/bar": cpio.Directory("foo/bar", 0777),
   392  				},
   393  			},
   394  			result: &Files{
   395  				Records: map[string]cpio.Record{
   396  					"foo/bar": cpio.Directory("foo/bar", 0777),
   397  					"foo":     cpio.Directory("foo", 0755),
   398  				},
   399  			},
   400  		},
   401  		{
   402  			af: &Files{
   403  				Files: map[string]string{
   404  					"baz/baz/baz": "/somewhere",
   405  				},
   406  				Records: map[string]cpio.Record{
   407  					"foo/bar": cpio.Directory("foo/bar", 0777),
   408  				},
   409  			},
   410  			result: &Files{
   411  				Files: map[string]string{
   412  					"baz/baz/baz": "/somewhere",
   413  				},
   414  				Records: map[string]cpio.Record{
   415  					"foo/bar": cpio.Directory("foo/bar", 0777),
   416  					"foo":     cpio.Directory("foo", 0755),
   417  					"baz":     cpio.Directory("baz", 0755),
   418  					"baz/baz": cpio.Directory("baz/baz", 0755),
   419  				},
   420  			},
   421  		},
   422  		{
   423  			af:     &Files{},
   424  			result: &Files{},
   425  		},
   426  	} {
   427  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   428  			tt.af.fillInParents()
   429  			if !reflect.DeepEqual(tt.af, tt.result) {
   430  				t.Errorf("got %v, want %v", tt.af, tt.result)
   431  			}
   432  		})
   433  	}
   434  }
   435  
   436  type MockArchiver struct {
   437  	Records      Records
   438  	FinishCalled bool
   439  	BaseArchive  []cpio.Record
   440  }
   441  
   442  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   443  	if _, ok := ma.Records[r.Name]; ok {
   444  		return fmt.Errorf("file exists")
   445  	}
   446  	ma.Records[r.Name] = r
   447  	return nil
   448  }
   449  
   450  func (ma *MockArchiver) Finish() error {
   451  	ma.FinishCalled = true
   452  	return nil
   453  }
   454  
   455  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   456  	if len(ma.BaseArchive) > 0 {
   457  		next := ma.BaseArchive[0]
   458  		ma.BaseArchive = ma.BaseArchive[1:]
   459  		return next, nil
   460  	}
   461  	return cpio.Record{}, io.EOF
   462  }
   463  
   464  type Records map[string]cpio.Record
   465  
   466  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   467  	for name, s1 := range r1 {
   468  		s2, ok := r2[name]
   469  		if !ok {
   470  			return false
   471  		}
   472  		if !recordEqual(s1, s2) {
   473  			return false
   474  		}
   475  	}
   476  	for name := range r2 {
   477  		if _, ok := r1[name]; !ok {
   478  			return false
   479  		}
   480  	}
   481  	return true
   482  }
   483  
   484  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   485  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   486  		return false
   487  	}
   488  	return uio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   489  }
   490  
   491  func TestOptsWrite(t *testing.T) {
   492  	for i, tt := range []struct {
   493  		desc string
   494  		opts *Opts
   495  		ma   *MockArchiver
   496  		want Records
   497  		err  error
   498  	}{
   499  		{
   500  			desc: "no conflicts, just records",
   501  			opts: &Opts{
   502  				Files: &Files{
   503  					Records: map[string]cpio.Record{
   504  						"foo": cpio.Symlink("foo", "elsewhere"),
   505  					},
   506  				},
   507  			},
   508  			ma: &MockArchiver{
   509  				Records: make(Records),
   510  				BaseArchive: []cpio.Record{
   511  					cpio.Directory("etc", 0777),
   512  					cpio.Directory("etc/nginx", 0777),
   513  				},
   514  			},
   515  			want: Records{
   516  				"foo":       cpio.Symlink("foo", "elsewhere"),
   517  				"etc":       cpio.Directory("etc", 0777),
   518  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   519  			},
   520  		},
   521  		{
   522  			desc: "default already exists",
   523  			opts: &Opts{
   524  				Files: &Files{
   525  					Records: map[string]cpio.Record{
   526  						"etc": cpio.Symlink("etc", "whatever"),
   527  					},
   528  				},
   529  			},
   530  			ma: &MockArchiver{
   531  				Records: make(Records),
   532  				BaseArchive: []cpio.Record{
   533  					cpio.Directory("etc", 0777),
   534  				},
   535  			},
   536  			want: Records{
   537  				"etc": cpio.Symlink("etc", "whatever"),
   538  			},
   539  		},
   540  		{
   541  			desc: "no conflicts, missing parent automatically created",
   542  			opts: &Opts{
   543  				Files: &Files{
   544  					Records: map[string]cpio.Record{
   545  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   546  					},
   547  				},
   548  			},
   549  			ma: &MockArchiver{
   550  				Records: make(Records),
   551  			},
   552  			want: Records{
   553  				"foo":         cpio.Directory("foo", 0755),
   554  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   555  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   556  			},
   557  		},
   558  		{
   559  			desc: "parent only automatically created if not already exists",
   560  			opts: &Opts{
   561  				Files: &Files{
   562  					Records: map[string]cpio.Record{
   563  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   564  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   565  					},
   566  				},
   567  			},
   568  			ma: &MockArchiver{
   569  				Records: make(Records),
   570  			},
   571  			want: Records{
   572  				"foo":         cpio.Directory("foo", 0755),
   573  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   574  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   575  			},
   576  		},
   577  		{
   578  			desc: "base archive",
   579  			opts: &Opts{
   580  				Files: &Files{
   581  					Records: map[string]cpio.Record{
   582  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   583  						"exists":  cpio.Directory("exists", 0777),
   584  					},
   585  				},
   586  			},
   587  			ma: &MockArchiver{
   588  				Records: make(Records),
   589  				BaseArchive: []cpio.Record{
   590  					cpio.Directory("etc", 0755),
   591  					cpio.Directory("foo", 0444),
   592  					cpio.Directory("exists", 0),
   593  				},
   594  			},
   595  			want: Records{
   596  				"etc":     cpio.Directory("etc", 0755),
   597  				"exists":  cpio.Directory("exists", 0777),
   598  				"foo":     cpio.Directory("foo", 0444),
   599  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   600  			},
   601  		},
   602  		{
   603  			desc: "base archive with init, no user init",
   604  			opts: &Opts{
   605  				Files: &Files{
   606  					Records: map[string]cpio.Record{},
   607  				},
   608  			},
   609  			ma: &MockArchiver{
   610  				Records: make(Records),
   611  				BaseArchive: []cpio.Record{
   612  					cpio.StaticFile("init", "boo", 0555),
   613  				},
   614  			},
   615  			want: Records{
   616  				"init": cpio.StaticFile("init", "boo", 0555),
   617  			},
   618  		},
   619  		{
   620  			desc: "base archive with init and user init",
   621  			opts: &Opts{
   622  				Files: &Files{
   623  					Records: map[string]cpio.Record{
   624  						"init": cpio.StaticFile("init", "bar", 0444),
   625  					},
   626  				},
   627  			},
   628  			ma: &MockArchiver{
   629  				Records: make(Records),
   630  				BaseArchive: []cpio.Record{
   631  					cpio.StaticFile("init", "boo", 0555),
   632  				},
   633  			},
   634  			want: Records{
   635  				"init":  cpio.StaticFile("init", "bar", 0444),
   636  				"inito": cpio.StaticFile("inito", "boo", 0555),
   637  			},
   638  		},
   639  		{
   640  			desc: "base archive with init, use existing init",
   641  			opts: &Opts{
   642  				Files: &Files{
   643  					Records: map[string]cpio.Record{},
   644  				},
   645  				UseExistingInit: true,
   646  			},
   647  			ma: &MockArchiver{
   648  				Records: make(Records),
   649  				BaseArchive: []cpio.Record{
   650  					cpio.StaticFile("init", "boo", 0555),
   651  				},
   652  			},
   653  			want: Records{
   654  				"init": cpio.StaticFile("init", "boo", 0555),
   655  			},
   656  		},
   657  		{
   658  			desc: "base archive with init and user init, use existing init",
   659  			opts: &Opts{
   660  				Files: &Files{
   661  					Records: map[string]cpio.Record{
   662  						"init": cpio.StaticFile("init", "huh", 0111),
   663  					},
   664  				},
   665  				UseExistingInit: true,
   666  			},
   667  			ma: &MockArchiver{
   668  				Records: make(Records),
   669  				BaseArchive: []cpio.Record{
   670  					cpio.StaticFile("init", "boo", 0555),
   671  				},
   672  			},
   673  			want: Records{
   674  				"init":  cpio.StaticFile("init", "boo", 0555),
   675  				"inito": cpio.StaticFile("inito", "huh", 0111),
   676  			},
   677  		},
   678  	} {
   679  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   680  			tt.opts.BaseArchive = tt.ma
   681  			tt.opts.OutputFile = tt.ma
   682  
   683  			if err := Write(tt.opts); err != tt.err {
   684  				t.Errorf("Write() = %v, want %v", err, tt.err)
   685  			} else if err == nil && !tt.ma.FinishCalled {
   686  				t.Errorf("Finish wasn't called on archive")
   687  			}
   688  
   689  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   690  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   691  			}
   692  		})
   693  	}
   694  }