github.com/hugelgupf/u-root@v0.0.0-20191023214958-4807c632154c/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/u-root/u-root/pkg/cpio"
    18  	"github.com/u-root/u-root/pkg/uio"
    19  )
    20  
    21  func TestFilesAddFileNoFollow(t *testing.T) {
    22  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    23  	if err != nil {
    24  		t.Error(err)
    25  	}
    26  	defer os.RemoveAll(regularFile.Name())
    27  
    28  	dir, err := ioutil.TempDir("", "archive-add-files")
    29  	if err != nil {
    30  		t.Error(err)
    31  	}
    32  	defer os.RemoveAll(dir)
    33  
    34  	dir2, err := ioutil.TempDir("", "archive-add-files")
    35  	if err != nil {
    36  		t.Error(err)
    37  	}
    38  	defer os.RemoveAll(dir2)
    39  
    40  	os.Create(filepath.Join(dir, "foo2"))
    41  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    42  
    43  	for i, tt := range []struct {
    44  		name        string
    45  		af          *Files
    46  		src         string
    47  		dest        string
    48  		result      *Files
    49  		errContains string
    50  	}{
    51  		{
    52  			name: "just add a file",
    53  			af:   NewFiles(),
    54  
    55  			src:  regularFile.Name(),
    56  			dest: "bar/foo",
    57  
    58  			result: &Files{
    59  				Files: map[string]string{
    60  					"bar/foo": regularFile.Name(),
    61  				},
    62  				Records: map[string]cpio.Record{},
    63  			},
    64  		},
    65  		{
    66  			name: "add symlinked file, NOT following",
    67  			af:   NewFiles(),
    68  			src:  filepath.Join(dir2, "foo3"),
    69  			dest: "bar/foo",
    70  			result: &Files{
    71  				Files: map[string]string{
    72  					"bar/foo": filepath.Join(dir2, "foo3"),
    73  				},
    74  				Records: map[string]cpio.Record{},
    75  			},
    76  		},
    77  	} {
    78  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
    79  			err := tt.af.AddFileNoFollow(tt.src, tt.dest)
    80  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
    81  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
    82  			}
    83  			if err == nil && len(tt.errContains) > 0 {
    84  				t.Errorf("Got no error, want %v", tt.errContains)
    85  			}
    86  
    87  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
    88  				t.Errorf("got %v, want %v", tt.af, tt.result)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func TestFilesAddFile(t *testing.T) {
    95  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    96  	if err != nil {
    97  		t.Error(err)
    98  	}
    99  	defer os.RemoveAll(regularFile.Name())
   100  
   101  	dir, err := ioutil.TempDir("", "archive-add-files")
   102  	if err != nil {
   103  		t.Error(err)
   104  	}
   105  	defer os.RemoveAll(dir)
   106  
   107  	dir2, err := ioutil.TempDir("", "archive-add-files")
   108  	if err != nil {
   109  		t.Error(err)
   110  	}
   111  	defer os.RemoveAll(dir2)
   112  
   113  	os.Create(filepath.Join(dir, "foo"))
   114  	os.Create(filepath.Join(dir, "foo2"))
   115  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
   116  
   117  	for i, tt := range []struct {
   118  		name        string
   119  		af          *Files
   120  		src         string
   121  		dest        string
   122  		result      *Files
   123  		errContains string
   124  	}{
   125  		{
   126  			name: "just add a file",
   127  			af:   NewFiles(),
   128  
   129  			src:  regularFile.Name(),
   130  			dest: "bar/foo",
   131  
   132  			result: &Files{
   133  				Files: map[string]string{
   134  					"bar/foo": regularFile.Name(),
   135  				},
   136  				Records: map[string]cpio.Record{},
   137  			},
   138  		},
   139  		{
   140  			name: "add symlinked file, following",
   141  			af:   NewFiles(),
   142  			src:  filepath.Join(dir2, "foo3"),
   143  			dest: "bar/foo",
   144  			result: &Files{
   145  				Files: map[string]string{
   146  					"bar/foo": filepath.Join(dir, "foo2"),
   147  				},
   148  				Records: map[string]cpio.Record{},
   149  			},
   150  		},
   151  		{
   152  			name: "add file that exists in Files",
   153  			af: &Files{
   154  				Files: map[string]string{
   155  					"bar/foo": "/some/other/place",
   156  				},
   157  			},
   158  			src:  regularFile.Name(),
   159  			dest: "bar/foo",
   160  			result: &Files{
   161  				Files: map[string]string{
   162  					"bar/foo": "/some/other/place",
   163  				},
   164  			},
   165  			errContains: "already exists in archive",
   166  		},
   167  		{
   168  			name: "add a file that exists in Records",
   169  			af: &Files{
   170  				Records: map[string]cpio.Record{
   171  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   172  				},
   173  			},
   174  			src:  regularFile.Name(),
   175  			dest: "bar/foo",
   176  			result: &Files{
   177  				Records: map[string]cpio.Record{
   178  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   179  				},
   180  			},
   181  			errContains: "already exists in archive",
   182  		},
   183  		{
   184  			name: "add a file that already exists in Files, but is the same one",
   185  			af: &Files{
   186  				Files: map[string]string{
   187  					"bar/foo": regularFile.Name(),
   188  				},
   189  			},
   190  			src:  regularFile.Name(),
   191  			dest: "bar/foo",
   192  			result: &Files{
   193  				Files: map[string]string{
   194  					"bar/foo": regularFile.Name(),
   195  				},
   196  			},
   197  		},
   198  		{
   199  			name: "absolute destination paths are made relative",
   200  			af: &Files{
   201  				Files: map[string]string{},
   202  			},
   203  			src:  dir,
   204  			dest: "/bar/foo",
   205  			result: &Files{
   206  				Files: map[string]string{
   207  					"bar/foo":      dir,
   208  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   209  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   210  				},
   211  			},
   212  		},
   213  		{
   214  			name: "add a directory",
   215  			af: &Files{
   216  				Files: map[string]string{},
   217  			},
   218  			src:  dir,
   219  			dest: "bar/foo",
   220  			result: &Files{
   221  				Files: map[string]string{
   222  					"bar/foo":      dir,
   223  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   224  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   225  				},
   226  			},
   227  		},
   228  		{
   229  			name: "add a different directory to the same destination, no overlapping children",
   230  			af: &Files{
   231  				Files: map[string]string{
   232  					"bar/foo":     "/some/place/real",
   233  					"bar/foo/zed": "/some/place/real/zed",
   234  				},
   235  			},
   236  			src:  dir,
   237  			dest: "bar/foo",
   238  			result: &Files{
   239  				Files: map[string]string{
   240  					"bar/foo":      dir,
   241  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   242  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   243  					"bar/foo/zed":  "/some/place/real/zed",
   244  				},
   245  			},
   246  		},
   247  		{
   248  			name: "add a different directory to the same destination, overlapping children",
   249  			af: &Files{
   250  				Files: map[string]string{
   251  					"bar/foo":      "/some/place/real",
   252  					"bar/foo/foo2": "/some/place/real/zed",
   253  				},
   254  			},
   255  			src:         dir,
   256  			dest:        "bar/foo",
   257  			errContains: "already exists in archive",
   258  		},
   259  	} {
   260  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   261  			err := tt.af.AddFile(tt.src, tt.dest)
   262  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   263  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   264  			}
   265  			if err == nil && len(tt.errContains) > 0 {
   266  				t.Errorf("Got no error, want %v", tt.errContains)
   267  			}
   268  
   269  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   270  				t.Errorf("got %v, want %v", tt.af, tt.result)
   271  			}
   272  		})
   273  	}
   274  }
   275  
   276  func TestFilesAddRecord(t *testing.T) {
   277  	for i, tt := range []struct {
   278  		af     *Files
   279  		record cpio.Record
   280  
   281  		result      *Files
   282  		errContains string
   283  	}{
   284  		{
   285  			af:     NewFiles(),
   286  			record: cpio.Symlink("bar/foo", ""),
   287  			result: &Files{
   288  				Files: map[string]string{},
   289  				Records: map[string]cpio.Record{
   290  					"bar/foo": cpio.Symlink("bar/foo", ""),
   291  				},
   292  			},
   293  		},
   294  		{
   295  			af: &Files{
   296  				Files: map[string]string{
   297  					"bar/foo": "/some/other/place",
   298  				},
   299  			},
   300  			record: cpio.Symlink("bar/foo", ""),
   301  			result: &Files{
   302  				Files: map[string]string{
   303  					"bar/foo": "/some/other/place",
   304  				},
   305  			},
   306  			errContains: "already exists in archive",
   307  		},
   308  		{
   309  			af: &Files{
   310  				Records: map[string]cpio.Record{
   311  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   312  				},
   313  			},
   314  			record: cpio.Symlink("bar/foo", ""),
   315  			result: &Files{
   316  				Records: map[string]cpio.Record{
   317  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   318  				},
   319  			},
   320  			errContains: "already exists in archive",
   321  		},
   322  		{
   323  			af: &Files{
   324  				Records: map[string]cpio.Record{
   325  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   326  				},
   327  			},
   328  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   329  			result: &Files{
   330  				Records: map[string]cpio.Record{
   331  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   332  				},
   333  			},
   334  		},
   335  		{
   336  			record:      cpio.Symlink("/bar/foo", ""),
   337  			errContains: "must not be absolute",
   338  		},
   339  	} {
   340  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   341  			err := tt.af.AddRecord(tt.record)
   342  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   343  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   344  			}
   345  			if err == nil && len(tt.errContains) > 0 {
   346  				t.Errorf("Got no error, want %v", tt.errContains)
   347  			}
   348  
   349  			if !reflect.DeepEqual(tt.af, tt.result) {
   350  				t.Errorf("got %v, want %v", tt.af, tt.result)
   351  			}
   352  		})
   353  	}
   354  }
   355  
   356  func TestFilesfillInParent(t *testing.T) {
   357  	for i, tt := range []struct {
   358  		af     *Files
   359  		result *Files
   360  	}{
   361  		{
   362  			af: &Files{
   363  				Records: map[string]cpio.Record{
   364  					"foo/bar": cpio.Directory("foo/bar", 0777),
   365  				},
   366  			},
   367  			result: &Files{
   368  				Records: map[string]cpio.Record{
   369  					"foo/bar": cpio.Directory("foo/bar", 0777),
   370  					"foo":     cpio.Directory("foo", 0755),
   371  				},
   372  			},
   373  		},
   374  		{
   375  			af: &Files{
   376  				Files: map[string]string{
   377  					"baz/baz/baz": "/somewhere",
   378  				},
   379  				Records: map[string]cpio.Record{
   380  					"foo/bar": cpio.Directory("foo/bar", 0777),
   381  				},
   382  			},
   383  			result: &Files{
   384  				Files: map[string]string{
   385  					"baz/baz/baz": "/somewhere",
   386  				},
   387  				Records: map[string]cpio.Record{
   388  					"foo/bar": cpio.Directory("foo/bar", 0777),
   389  					"foo":     cpio.Directory("foo", 0755),
   390  					"baz":     cpio.Directory("baz", 0755),
   391  					"baz/baz": cpio.Directory("baz/baz", 0755),
   392  				},
   393  			},
   394  		},
   395  		{
   396  			af:     &Files{},
   397  			result: &Files{},
   398  		},
   399  	} {
   400  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   401  			tt.af.fillInParents()
   402  			if !reflect.DeepEqual(tt.af, tt.result) {
   403  				t.Errorf("got %v, want %v", tt.af, tt.result)
   404  			}
   405  		})
   406  	}
   407  }
   408  
   409  type MockArchiver struct {
   410  	Records      Records
   411  	FinishCalled bool
   412  	BaseArchive  []cpio.Record
   413  }
   414  
   415  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   416  	if _, ok := ma.Records[r.Name]; ok {
   417  		return fmt.Errorf("file exists")
   418  	}
   419  	ma.Records[r.Name] = r
   420  	return nil
   421  }
   422  
   423  func (ma *MockArchiver) Finish() error {
   424  	ma.FinishCalled = true
   425  	return nil
   426  }
   427  
   428  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   429  	if len(ma.BaseArchive) > 0 {
   430  		next := ma.BaseArchive[0]
   431  		ma.BaseArchive = ma.BaseArchive[1:]
   432  		return next, nil
   433  	}
   434  	return cpio.Record{}, io.EOF
   435  }
   436  
   437  type Records map[string]cpio.Record
   438  
   439  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   440  	for name, s1 := range r1 {
   441  		s2, ok := r2[name]
   442  		if !ok {
   443  			return false
   444  		}
   445  		if !recordEqual(s1, s2) {
   446  			return false
   447  		}
   448  	}
   449  	for name := range r2 {
   450  		if _, ok := r1[name]; !ok {
   451  			return false
   452  		}
   453  	}
   454  	return true
   455  }
   456  
   457  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   458  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   459  		return false
   460  	}
   461  	return uio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   462  }
   463  
   464  func TestOptsWrite(t *testing.T) {
   465  	for i, tt := range []struct {
   466  		desc string
   467  		opts *Opts
   468  		ma   *MockArchiver
   469  		want Records
   470  		err  error
   471  	}{
   472  		{
   473  			desc: "no conflicts, just records",
   474  			opts: &Opts{
   475  				Files: &Files{
   476  					Records: map[string]cpio.Record{
   477  						"foo": cpio.Symlink("foo", "elsewhere"),
   478  					},
   479  				},
   480  			},
   481  			ma: &MockArchiver{
   482  				Records: make(Records),
   483  				BaseArchive: []cpio.Record{
   484  					cpio.Directory("etc", 0777),
   485  					cpio.Directory("etc/nginx", 0777),
   486  				},
   487  			},
   488  			want: Records{
   489  				"foo":       cpio.Symlink("foo", "elsewhere"),
   490  				"etc":       cpio.Directory("etc", 0777),
   491  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   492  			},
   493  		},
   494  		{
   495  			desc: "default already exists",
   496  			opts: &Opts{
   497  				Files: &Files{
   498  					Records: map[string]cpio.Record{
   499  						"etc": cpio.Symlink("etc", "whatever"),
   500  					},
   501  				},
   502  			},
   503  			ma: &MockArchiver{
   504  				Records: make(Records),
   505  				BaseArchive: []cpio.Record{
   506  					cpio.Directory("etc", 0777),
   507  				},
   508  			},
   509  			want: Records{
   510  				"etc": cpio.Symlink("etc", "whatever"),
   511  			},
   512  		},
   513  		{
   514  			desc: "no conflicts, missing parent automatically created",
   515  			opts: &Opts{
   516  				Files: &Files{
   517  					Records: map[string]cpio.Record{
   518  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   519  					},
   520  				},
   521  			},
   522  			ma: &MockArchiver{
   523  				Records: make(Records),
   524  			},
   525  			want: Records{
   526  				"foo":         cpio.Directory("foo", 0755),
   527  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   528  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   529  			},
   530  		},
   531  		{
   532  			desc: "parent only automatically created if not already exists",
   533  			opts: &Opts{
   534  				Files: &Files{
   535  					Records: map[string]cpio.Record{
   536  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   537  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   538  					},
   539  				},
   540  			},
   541  			ma: &MockArchiver{
   542  				Records: make(Records),
   543  			},
   544  			want: Records{
   545  				"foo":         cpio.Directory("foo", 0755),
   546  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   547  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   548  			},
   549  		},
   550  		{
   551  			desc: "base archive",
   552  			opts: &Opts{
   553  				Files: &Files{
   554  					Records: map[string]cpio.Record{
   555  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   556  						"exists":  cpio.Directory("exists", 0777),
   557  					},
   558  				},
   559  			},
   560  			ma: &MockArchiver{
   561  				Records: make(Records),
   562  				BaseArchive: []cpio.Record{
   563  					cpio.Directory("etc", 0755),
   564  					cpio.Directory("foo", 0444),
   565  					cpio.Directory("exists", 0),
   566  				},
   567  			},
   568  			want: Records{
   569  				"etc":     cpio.Directory("etc", 0755),
   570  				"exists":  cpio.Directory("exists", 0777),
   571  				"foo":     cpio.Directory("foo", 0444),
   572  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   573  			},
   574  		},
   575  		{
   576  			desc: "base archive with init, no user init",
   577  			opts: &Opts{
   578  				Files: &Files{
   579  					Records: map[string]cpio.Record{},
   580  				},
   581  			},
   582  			ma: &MockArchiver{
   583  				Records: make(Records),
   584  				BaseArchive: []cpio.Record{
   585  					cpio.StaticFile("init", "boo", 0555),
   586  				},
   587  			},
   588  			want: Records{
   589  				"init": cpio.StaticFile("init", "boo", 0555),
   590  			},
   591  		},
   592  		{
   593  			desc: "base archive with init and user init",
   594  			opts: &Opts{
   595  				Files: &Files{
   596  					Records: map[string]cpio.Record{
   597  						"init": cpio.StaticFile("init", "bar", 0444),
   598  					},
   599  				},
   600  			},
   601  			ma: &MockArchiver{
   602  				Records: make(Records),
   603  				BaseArchive: []cpio.Record{
   604  					cpio.StaticFile("init", "boo", 0555),
   605  				},
   606  			},
   607  			want: Records{
   608  				"init":  cpio.StaticFile("init", "bar", 0444),
   609  				"inito": cpio.StaticFile("inito", "boo", 0555),
   610  			},
   611  		},
   612  		{
   613  			desc: "base archive with init, use existing init",
   614  			opts: &Opts{
   615  				Files: &Files{
   616  					Records: map[string]cpio.Record{},
   617  				},
   618  				UseExistingInit: true,
   619  			},
   620  			ma: &MockArchiver{
   621  				Records: make(Records),
   622  				BaseArchive: []cpio.Record{
   623  					cpio.StaticFile("init", "boo", 0555),
   624  				},
   625  			},
   626  			want: Records{
   627  				"init": cpio.StaticFile("init", "boo", 0555),
   628  			},
   629  		},
   630  		{
   631  			desc: "base archive with init and user init, use existing init",
   632  			opts: &Opts{
   633  				Files: &Files{
   634  					Records: map[string]cpio.Record{
   635  						"init": cpio.StaticFile("init", "huh", 0111),
   636  					},
   637  				},
   638  				UseExistingInit: true,
   639  			},
   640  			ma: &MockArchiver{
   641  				Records: make(Records),
   642  				BaseArchive: []cpio.Record{
   643  					cpio.StaticFile("init", "boo", 0555),
   644  				},
   645  			},
   646  			want: Records{
   647  				"init":  cpio.StaticFile("init", "boo", 0555),
   648  				"inito": cpio.StaticFile("inito", "huh", 0111),
   649  			},
   650  		},
   651  	} {
   652  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   653  			tt.opts.BaseArchive = tt.ma
   654  			tt.opts.OutputFile = tt.ma
   655  
   656  			if err := Write(tt.opts); err != tt.err {
   657  				t.Errorf("Write() = %v, want %v", err, tt.err)
   658  			} else if err == nil && !tt.ma.FinishCalled {
   659  				t.Errorf("Finish wasn't called on archive")
   660  			}
   661  
   662  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   663  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   664  			}
   665  		})
   666  	}
   667  }