github.com/shaardie/u-root@v4.0.1-0.20190127173353-f24a1c26aa2e+incompatible/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/u-root/u-root/pkg/cpio"
    18  	"github.com/u-root/u-root/pkg/uio"
    19  )
    20  
    21  func TestFilesAddFileNoFollow(t *testing.T) {
    22  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    23  	if err != nil {
    24  		t.Error(err)
    25  	}
    26  	defer os.RemoveAll(regularFile.Name())
    27  
    28  	dir, err := ioutil.TempDir("", "archive-add-files")
    29  	if err != nil {
    30  		t.Error(err)
    31  	}
    32  	defer os.RemoveAll(dir)
    33  
    34  	dir2, err := ioutil.TempDir("", "archive-add-files")
    35  	if err != nil {
    36  		t.Error(err)
    37  	}
    38  	defer os.RemoveAll(dir2)
    39  
    40  	os.Create(filepath.Join(dir, "foo2"))
    41  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    42  
    43  	for i, tt := range []struct {
    44  		name        string
    45  		af          *Files
    46  		src         string
    47  		dest        string
    48  		result      *Files
    49  		errContains string
    50  	}{
    51  		{
    52  			name: "just add a file",
    53  			af:   NewFiles(),
    54  
    55  			src:  regularFile.Name(),
    56  			dest: "bar/foo",
    57  
    58  			result: &Files{
    59  				Files: map[string]string{
    60  					"bar/foo": regularFile.Name(),
    61  				},
    62  				Records: map[string]cpio.Record{},
    63  			},
    64  		},
    65  		{
    66  			name: "add symlinked file, NOT following",
    67  			af:   NewFiles(),
    68  			src:  filepath.Join(dir2, "foo3"),
    69  			dest: "bar/foo",
    70  			result: &Files{
    71  				Files: map[string]string{
    72  					"bar/foo": filepath.Join(dir2, "foo3"),
    73  				},
    74  				Records: map[string]cpio.Record{},
    75  			},
    76  		},
    77  	} {
    78  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
    79  			err := tt.af.AddFileNoFollow(tt.src, tt.dest)
    80  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
    81  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
    82  			}
    83  			if err == nil && len(tt.errContains) > 0 {
    84  				t.Errorf("Got no error, want %v", tt.errContains)
    85  			}
    86  
    87  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
    88  				t.Errorf("got %v, want %v", tt.af, tt.result)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func TestFilesAddFile(t *testing.T) {
    95  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    96  	if err != nil {
    97  		t.Error(err)
    98  	}
    99  	defer os.RemoveAll(regularFile.Name())
   100  
   101  	dir, err := ioutil.TempDir("", "archive-add-files")
   102  	if err != nil {
   103  		t.Error(err)
   104  	}
   105  	defer os.RemoveAll(dir)
   106  
   107  	dir2, err := ioutil.TempDir("", "archive-add-files")
   108  	if err != nil {
   109  		t.Error(err)
   110  	}
   111  	defer os.RemoveAll(dir2)
   112  
   113  	os.Create(filepath.Join(dir, "foo"))
   114  	os.Create(filepath.Join(dir, "foo2"))
   115  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
   116  
   117  	for i, tt := range []struct {
   118  		name        string
   119  		af          *Files
   120  		src         string
   121  		dest        string
   122  		result      *Files
   123  		errContains string
   124  	}{
   125  		{
   126  			name: "just add a file",
   127  			af:   NewFiles(),
   128  
   129  			src:  regularFile.Name(),
   130  			dest: "bar/foo",
   131  
   132  			result: &Files{
   133  				Files: map[string]string{
   134  					"bar/foo": regularFile.Name(),
   135  				},
   136  				Records: map[string]cpio.Record{},
   137  			},
   138  		},
   139  		{
   140  			name: "add symlinked file, following",
   141  			af:   NewFiles(),
   142  			src:  filepath.Join(dir2, "foo3"),
   143  			dest: "bar/foo",
   144  			result: &Files{
   145  				Files: map[string]string{
   146  					"bar/foo": filepath.Join(dir, "foo2"),
   147  				},
   148  				Records: map[string]cpio.Record{},
   149  			},
   150  		},
   151  		{
   152  			name: "add file that exists in Files",
   153  			af: &Files{
   154  				Files: map[string]string{
   155  					"bar/foo": "/some/other/place",
   156  				},
   157  			},
   158  			src:  regularFile.Name(),
   159  			dest: "bar/foo",
   160  			result: &Files{
   161  				Files: map[string]string{
   162  					"bar/foo": "/some/other/place",
   163  				},
   164  			},
   165  			errContains: "already exists in archive",
   166  		},
   167  		{
   168  			name: "add a file that exists in Records",
   169  			af: &Files{
   170  				Records: map[string]cpio.Record{
   171  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   172  				},
   173  			},
   174  			src:  regularFile.Name(),
   175  			dest: "bar/foo",
   176  			result: &Files{
   177  				Records: map[string]cpio.Record{
   178  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   179  				},
   180  			},
   181  			errContains: "already exists in archive",
   182  		},
   183  		{
   184  			name: "add a file that already exists in Files, but is the same one",
   185  			af: &Files{
   186  				Files: map[string]string{
   187  					"bar/foo": regularFile.Name(),
   188  				},
   189  			},
   190  			src:  regularFile.Name(),
   191  			dest: "bar/foo",
   192  			result: &Files{
   193  				Files: map[string]string{
   194  					"bar/foo": regularFile.Name(),
   195  				},
   196  			},
   197  		},
   198  		{
   199  			name:        "destination path must not be absolute",
   200  			src:         regularFile.Name(),
   201  			dest:        "/bar/foo",
   202  			errContains: "must not be absolute",
   203  		},
   204  		{
   205  			name: "add a directory",
   206  			af: &Files{
   207  				Files: map[string]string{},
   208  			},
   209  			src:  dir,
   210  			dest: "bar/foo",
   211  			result: &Files{
   212  				Files: map[string]string{
   213  					"bar/foo":      dir,
   214  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   215  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   216  				},
   217  			},
   218  		},
   219  		{
   220  			name: "add a different directory to the same destination, no overlapping children",
   221  			af: &Files{
   222  				Files: map[string]string{
   223  					"bar/foo":     "/some/place/real",
   224  					"bar/foo/zed": "/some/place/real/zed",
   225  				},
   226  			},
   227  			src:  dir,
   228  			dest: "bar/foo",
   229  			result: &Files{
   230  				Files: map[string]string{
   231  					"bar/foo":      dir,
   232  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   233  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   234  					"bar/foo/zed":  "/some/place/real/zed",
   235  				},
   236  			},
   237  		},
   238  		{
   239  			name: "add a different directory to the same destination, overlapping children",
   240  			af: &Files{
   241  				Files: map[string]string{
   242  					"bar/foo":      "/some/place/real",
   243  					"bar/foo/foo2": "/some/place/real/zed",
   244  				},
   245  			},
   246  			src:         dir,
   247  			dest:        "bar/foo",
   248  			errContains: "already exists in archive",
   249  		},
   250  	} {
   251  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   252  			err := tt.af.AddFile(tt.src, tt.dest)
   253  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   254  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   255  			}
   256  			if err == nil && len(tt.errContains) > 0 {
   257  				t.Errorf("Got no error, want %v", tt.errContains)
   258  			}
   259  
   260  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   261  				t.Errorf("got %v, want %v", tt.af, tt.result)
   262  			}
   263  		})
   264  	}
   265  }
   266  
   267  func TestFilesAddRecord(t *testing.T) {
   268  	for i, tt := range []struct {
   269  		af     *Files
   270  		record cpio.Record
   271  
   272  		result      *Files
   273  		errContains string
   274  	}{
   275  		{
   276  			af:     NewFiles(),
   277  			record: cpio.Symlink("bar/foo", ""),
   278  			result: &Files{
   279  				Files: map[string]string{},
   280  				Records: map[string]cpio.Record{
   281  					"bar/foo": cpio.Symlink("bar/foo", ""),
   282  				},
   283  			},
   284  		},
   285  		{
   286  			af: &Files{
   287  				Files: map[string]string{
   288  					"bar/foo": "/some/other/place",
   289  				},
   290  			},
   291  			record: cpio.Symlink("bar/foo", ""),
   292  			result: &Files{
   293  				Files: map[string]string{
   294  					"bar/foo": "/some/other/place",
   295  				},
   296  			},
   297  			errContains: "already exists in archive",
   298  		},
   299  		{
   300  			af: &Files{
   301  				Records: map[string]cpio.Record{
   302  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   303  				},
   304  			},
   305  			record: cpio.Symlink("bar/foo", ""),
   306  			result: &Files{
   307  				Records: map[string]cpio.Record{
   308  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   309  				},
   310  			},
   311  			errContains: "already exists in archive",
   312  		},
   313  		{
   314  			af: &Files{
   315  				Records: map[string]cpio.Record{
   316  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   317  				},
   318  			},
   319  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   320  			result: &Files{
   321  				Records: map[string]cpio.Record{
   322  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   323  				},
   324  			},
   325  		},
   326  		{
   327  			record:      cpio.Symlink("/bar/foo", ""),
   328  			errContains: "must not be absolute",
   329  		},
   330  	} {
   331  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   332  			err := tt.af.AddRecord(tt.record)
   333  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   334  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   335  			}
   336  			if err == nil && len(tt.errContains) > 0 {
   337  				t.Errorf("Got no error, want %v", tt.errContains)
   338  			}
   339  
   340  			if !reflect.DeepEqual(tt.af, tt.result) {
   341  				t.Errorf("got %v, want %v", tt.af, tt.result)
   342  			}
   343  		})
   344  	}
   345  }
   346  
   347  func TestFilesfillInParent(t *testing.T) {
   348  	for i, tt := range []struct {
   349  		af     *Files
   350  		result *Files
   351  	}{
   352  		{
   353  			af: &Files{
   354  				Records: map[string]cpio.Record{
   355  					"foo/bar": cpio.Directory("foo/bar", 0777),
   356  				},
   357  			},
   358  			result: &Files{
   359  				Records: map[string]cpio.Record{
   360  					"foo/bar": cpio.Directory("foo/bar", 0777),
   361  					"foo":     cpio.Directory("foo", 0755),
   362  				},
   363  			},
   364  		},
   365  		{
   366  			af: &Files{
   367  				Files: map[string]string{
   368  					"baz/baz/baz": "/somewhere",
   369  				},
   370  				Records: map[string]cpio.Record{
   371  					"foo/bar": cpio.Directory("foo/bar", 0777),
   372  				},
   373  			},
   374  			result: &Files{
   375  				Files: map[string]string{
   376  					"baz/baz/baz": "/somewhere",
   377  				},
   378  				Records: map[string]cpio.Record{
   379  					"foo/bar": cpio.Directory("foo/bar", 0777),
   380  					"foo":     cpio.Directory("foo", 0755),
   381  					"baz":     cpio.Directory("baz", 0755),
   382  					"baz/baz": cpio.Directory("baz/baz", 0755),
   383  				},
   384  			},
   385  		},
   386  		{
   387  			af:     &Files{},
   388  			result: &Files{},
   389  		},
   390  	} {
   391  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   392  			tt.af.fillInParents()
   393  			if !reflect.DeepEqual(tt.af, tt.result) {
   394  				t.Errorf("got %v, want %v", tt.af, tt.result)
   395  			}
   396  		})
   397  	}
   398  }
   399  
   400  type MockArchiver struct {
   401  	Records      Records
   402  	FinishCalled bool
   403  	BaseArchive  []cpio.Record
   404  }
   405  
   406  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   407  	if _, ok := ma.Records[r.Name]; ok {
   408  		return fmt.Errorf("file exists")
   409  	}
   410  	ma.Records[r.Name] = r
   411  	return nil
   412  }
   413  
   414  func (ma *MockArchiver) Finish() error {
   415  	ma.FinishCalled = true
   416  	return nil
   417  }
   418  
   419  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   420  	if len(ma.BaseArchive) > 0 {
   421  		next := ma.BaseArchive[0]
   422  		ma.BaseArchive = ma.BaseArchive[1:]
   423  		return next, nil
   424  	}
   425  	return cpio.Record{}, io.EOF
   426  }
   427  
   428  type Records map[string]cpio.Record
   429  
   430  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   431  	for name, s1 := range r1 {
   432  		s2, ok := r2[name]
   433  		if !ok {
   434  			return false
   435  		}
   436  		if !recordEqual(s1, s2) {
   437  			return false
   438  		}
   439  	}
   440  	for name := range r2 {
   441  		if _, ok := r1[name]; !ok {
   442  			return false
   443  		}
   444  	}
   445  	return true
   446  }
   447  
   448  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   449  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   450  		return false
   451  	}
   452  	return uio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   453  }
   454  
   455  func TestOptsWrite(t *testing.T) {
   456  	for i, tt := range []struct {
   457  		desc string
   458  		opts *Opts
   459  		ma   *MockArchiver
   460  		want Records
   461  		err  error
   462  	}{
   463  		{
   464  			desc: "no conflicts, just records",
   465  			opts: &Opts{
   466  				Files: &Files{
   467  					Records: map[string]cpio.Record{
   468  						"foo": cpio.Symlink("foo", "elsewhere"),
   469  					},
   470  				},
   471  			},
   472  			ma: &MockArchiver{
   473  				Records: make(Records),
   474  				BaseArchive: []cpio.Record{
   475  					cpio.Directory("etc", 0777),
   476  					cpio.Directory("etc/nginx", 0777),
   477  				},
   478  			},
   479  			want: Records{
   480  				"foo":       cpio.Symlink("foo", "elsewhere"),
   481  				"etc":       cpio.Directory("etc", 0777),
   482  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   483  			},
   484  		},
   485  		{
   486  			desc: "default already exists",
   487  			opts: &Opts{
   488  				Files: &Files{
   489  					Records: map[string]cpio.Record{
   490  						"etc": cpio.Symlink("etc", "whatever"),
   491  					},
   492  				},
   493  			},
   494  			ma: &MockArchiver{
   495  				Records: make(Records),
   496  				BaseArchive: []cpio.Record{
   497  					cpio.Directory("etc", 0777),
   498  				},
   499  			},
   500  			want: Records{
   501  				"etc": cpio.Symlink("etc", "whatever"),
   502  			},
   503  		},
   504  		{
   505  			desc: "no conflicts, missing parent automatically created",
   506  			opts: &Opts{
   507  				Files: &Files{
   508  					Records: map[string]cpio.Record{
   509  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   510  					},
   511  				},
   512  			},
   513  			ma: &MockArchiver{
   514  				Records: make(Records),
   515  			},
   516  			want: Records{
   517  				"foo":         cpio.Directory("foo", 0755),
   518  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   519  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   520  			},
   521  		},
   522  		{
   523  			desc: "parent only automatically created if not already exists",
   524  			opts: &Opts{
   525  				Files: &Files{
   526  					Records: map[string]cpio.Record{
   527  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   528  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   529  					},
   530  				},
   531  			},
   532  			ma: &MockArchiver{
   533  				Records: make(Records),
   534  			},
   535  			want: Records{
   536  				"foo":         cpio.Directory("foo", 0755),
   537  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   538  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   539  			},
   540  		},
   541  		{
   542  			desc: "base archive",
   543  			opts: &Opts{
   544  				Files: &Files{
   545  					Records: map[string]cpio.Record{
   546  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   547  						"exists":  cpio.Directory("exists", 0777),
   548  					},
   549  				},
   550  			},
   551  			ma: &MockArchiver{
   552  				Records: make(Records),
   553  				BaseArchive: []cpio.Record{
   554  					cpio.Directory("etc", 0755),
   555  					cpio.Directory("foo", 0444),
   556  					cpio.Directory("exists", 0),
   557  				},
   558  			},
   559  			want: Records{
   560  				"etc":     cpio.Directory("etc", 0755),
   561  				"exists":  cpio.Directory("exists", 0777),
   562  				"foo":     cpio.Directory("foo", 0444),
   563  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   564  			},
   565  		},
   566  		{
   567  			desc: "base archive with init, no user init",
   568  			opts: &Opts{
   569  				Files: &Files{
   570  					Records: map[string]cpio.Record{},
   571  				},
   572  			},
   573  			ma: &MockArchiver{
   574  				Records: make(Records),
   575  				BaseArchive: []cpio.Record{
   576  					cpio.StaticFile("init", "boo", 0555),
   577  				},
   578  			},
   579  			want: Records{
   580  				"init": cpio.StaticFile("init", "boo", 0555),
   581  			},
   582  		},
   583  		{
   584  			desc: "base archive with init and user init",
   585  			opts: &Opts{
   586  				Files: &Files{
   587  					Records: map[string]cpio.Record{
   588  						"init": cpio.StaticFile("init", "bar", 0444),
   589  					},
   590  				},
   591  			},
   592  			ma: &MockArchiver{
   593  				Records: make(Records),
   594  				BaseArchive: []cpio.Record{
   595  					cpio.StaticFile("init", "boo", 0555),
   596  				},
   597  			},
   598  			want: Records{
   599  				"init":  cpio.StaticFile("init", "bar", 0444),
   600  				"inito": cpio.StaticFile("inito", "boo", 0555),
   601  			},
   602  		},
   603  		{
   604  			desc: "base archive with init, use existing init",
   605  			opts: &Opts{
   606  				Files: &Files{
   607  					Records: map[string]cpio.Record{},
   608  				},
   609  				UseExistingInit: true,
   610  			},
   611  			ma: &MockArchiver{
   612  				Records: make(Records),
   613  				BaseArchive: []cpio.Record{
   614  					cpio.StaticFile("init", "boo", 0555),
   615  				},
   616  			},
   617  			want: Records{
   618  				"init": cpio.StaticFile("init", "boo", 0555),
   619  			},
   620  		},
   621  		{
   622  			desc: "base archive with init and user init, use existing init",
   623  			opts: &Opts{
   624  				Files: &Files{
   625  					Records: map[string]cpio.Record{
   626  						"init": cpio.StaticFile("init", "huh", 0111),
   627  					},
   628  				},
   629  				UseExistingInit: true,
   630  			},
   631  			ma: &MockArchiver{
   632  				Records: make(Records),
   633  				BaseArchive: []cpio.Record{
   634  					cpio.StaticFile("init", "boo", 0555),
   635  				},
   636  			},
   637  			want: Records{
   638  				"init":  cpio.StaticFile("init", "boo", 0555),
   639  				"inito": cpio.StaticFile("inito", "huh", 0111),
   640  			},
   641  		},
   642  	} {
   643  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   644  			tt.opts.BaseArchive = tt.ma
   645  			tt.opts.OutputFile = tt.ma
   646  
   647  			if err := Write(tt.opts); err != tt.err {
   648  				t.Errorf("Write() = %v, want %v", err, tt.err)
   649  			} else if err == nil && !tt.ma.FinishCalled {
   650  				t.Errorf("Finish wasn't called on archive")
   651  			}
   652  
   653  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   654  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   655  			}
   656  		})
   657  	}
   658  }