github.com/oweisse/u-root@v0.0.0-20181109060735-d005ad25fef1/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/u-root/u-root/pkg/cpio"
    18  )
    19  
    20  func TestFilesAddFileNoFollow(t *testing.T) {
    21  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    22  	if err != nil {
    23  		t.Error(err)
    24  	}
    25  	defer os.RemoveAll(regularFile.Name())
    26  
    27  	dir, err := ioutil.TempDir("", "archive-add-files")
    28  	if err != nil {
    29  		t.Error(err)
    30  	}
    31  	defer os.RemoveAll(dir)
    32  
    33  	dir2, err := ioutil.TempDir("", "archive-add-files")
    34  	if err != nil {
    35  		t.Error(err)
    36  	}
    37  	defer os.RemoveAll(dir2)
    38  
    39  	os.Create(filepath.Join(dir, "foo2"))
    40  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    41  
    42  	for i, tt := range []struct {
    43  		name        string
    44  		af          *Files
    45  		src         string
    46  		dest        string
    47  		result      *Files
    48  		errContains string
    49  	}{
    50  		{
    51  			name: "just add a file",
    52  			af:   NewFiles(),
    53  
    54  			src:  regularFile.Name(),
    55  			dest: "bar/foo",
    56  
    57  			result: &Files{
    58  				Files: map[string]string{
    59  					"bar/foo": regularFile.Name(),
    60  				},
    61  				Records: map[string]cpio.Record{},
    62  			},
    63  		},
    64  		{
    65  			name: "add symlinked file, NOT following",
    66  			af:   NewFiles(),
    67  			src:  filepath.Join(dir2, "foo3"),
    68  			dest: "bar/foo",
    69  			result: &Files{
    70  				Files: map[string]string{
    71  					"bar/foo": filepath.Join(dir2, "foo3"),
    72  				},
    73  				Records: map[string]cpio.Record{},
    74  			},
    75  		},
    76  	} {
    77  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
    78  			err := tt.af.AddFileNoFollow(tt.src, tt.dest)
    79  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
    80  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
    81  			}
    82  			if err == nil && len(tt.errContains) > 0 {
    83  				t.Errorf("Got no error, want %v", tt.errContains)
    84  			}
    85  
    86  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
    87  				t.Errorf("got %v, want %v", tt.af, tt.result)
    88  			}
    89  		})
    90  	}
    91  }
    92  
    93  func TestFilesAddFile(t *testing.T) {
    94  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    95  	if err != nil {
    96  		t.Error(err)
    97  	}
    98  	defer os.RemoveAll(regularFile.Name())
    99  
   100  	dir, err := ioutil.TempDir("", "archive-add-files")
   101  	if err != nil {
   102  		t.Error(err)
   103  	}
   104  	defer os.RemoveAll(dir)
   105  
   106  	dir2, err := ioutil.TempDir("", "archive-add-files")
   107  	if err != nil {
   108  		t.Error(err)
   109  	}
   110  	defer os.RemoveAll(dir2)
   111  
   112  	os.Create(filepath.Join(dir, "foo"))
   113  	os.Create(filepath.Join(dir, "foo2"))
   114  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
   115  
   116  	for i, tt := range []struct {
   117  		name        string
   118  		af          *Files
   119  		src         string
   120  		dest        string
   121  		result      *Files
   122  		errContains string
   123  	}{
   124  		{
   125  			name: "just add a file",
   126  			af:   NewFiles(),
   127  
   128  			src:  regularFile.Name(),
   129  			dest: "bar/foo",
   130  
   131  			result: &Files{
   132  				Files: map[string]string{
   133  					"bar/foo": regularFile.Name(),
   134  				},
   135  				Records: map[string]cpio.Record{},
   136  			},
   137  		},
   138  		{
   139  			name: "add symlinked file, following",
   140  			af:   NewFiles(),
   141  			src:  filepath.Join(dir2, "foo3"),
   142  			dest: "bar/foo",
   143  			result: &Files{
   144  				Files: map[string]string{
   145  					"bar/foo": filepath.Join(dir, "foo2"),
   146  				},
   147  				Records: map[string]cpio.Record{},
   148  			},
   149  		},
   150  		{
   151  			name: "add file that exists in Files",
   152  			af: &Files{
   153  				Files: map[string]string{
   154  					"bar/foo": "/some/other/place",
   155  				},
   156  			},
   157  			src:  regularFile.Name(),
   158  			dest: "bar/foo",
   159  			result: &Files{
   160  				Files: map[string]string{
   161  					"bar/foo": "/some/other/place",
   162  				},
   163  			},
   164  			errContains: "already exists in archive",
   165  		},
   166  		{
   167  			name: "add a file that exists in Records",
   168  			af: &Files{
   169  				Records: map[string]cpio.Record{
   170  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   171  				},
   172  			},
   173  			src:  regularFile.Name(),
   174  			dest: "bar/foo",
   175  			result: &Files{
   176  				Records: map[string]cpio.Record{
   177  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   178  				},
   179  			},
   180  			errContains: "already exists in archive",
   181  		},
   182  		{
   183  			name: "add a file that already exists in Files, but is the same one",
   184  			af: &Files{
   185  				Files: map[string]string{
   186  					"bar/foo": regularFile.Name(),
   187  				},
   188  			},
   189  			src:  regularFile.Name(),
   190  			dest: "bar/foo",
   191  			result: &Files{
   192  				Files: map[string]string{
   193  					"bar/foo": regularFile.Name(),
   194  				},
   195  			},
   196  		},
   197  		{
   198  			name:        "destination path must not be absolute",
   199  			src:         regularFile.Name(),
   200  			dest:        "/bar/foo",
   201  			errContains: "must not be absolute",
   202  		},
   203  		{
   204  			name: "add a directory",
   205  			af: &Files{
   206  				Files: map[string]string{},
   207  			},
   208  			src:  dir,
   209  			dest: "bar/foo",
   210  			result: &Files{
   211  				Files: map[string]string{
   212  					"bar/foo":      dir,
   213  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   214  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   215  				},
   216  			},
   217  		},
   218  		{
   219  			name: "add a different directory to the same destination, no overlapping children",
   220  			af: &Files{
   221  				Files: map[string]string{
   222  					"bar/foo":     "/some/place/real",
   223  					"bar/foo/zed": "/some/place/real/zed",
   224  				},
   225  			},
   226  			src:  dir,
   227  			dest: "bar/foo",
   228  			result: &Files{
   229  				Files: map[string]string{
   230  					"bar/foo":      dir,
   231  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   232  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   233  					"bar/foo/zed":  "/some/place/real/zed",
   234  				},
   235  			},
   236  		},
   237  		{
   238  			name: "add a different directory to the same destination, overlapping children",
   239  			af: &Files{
   240  				Files: map[string]string{
   241  					"bar/foo":      "/some/place/real",
   242  					"bar/foo/foo2": "/some/place/real/zed",
   243  				},
   244  			},
   245  			src:         dir,
   246  			dest:        "bar/foo",
   247  			errContains: "already exists in archive",
   248  		},
   249  	} {
   250  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   251  			err := tt.af.AddFile(tt.src, tt.dest)
   252  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   253  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   254  			}
   255  			if err == nil && len(tt.errContains) > 0 {
   256  				t.Errorf("Got no error, want %v", tt.errContains)
   257  			}
   258  
   259  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   260  				t.Errorf("got %v, want %v", tt.af, tt.result)
   261  			}
   262  		})
   263  	}
   264  }
   265  
   266  func TestFilesAddRecord(t *testing.T) {
   267  	for i, tt := range []struct {
   268  		af     *Files
   269  		record cpio.Record
   270  
   271  		result      *Files
   272  		errContains string
   273  	}{
   274  		{
   275  			af:     NewFiles(),
   276  			record: cpio.Symlink("bar/foo", ""),
   277  			result: &Files{
   278  				Files: map[string]string{},
   279  				Records: map[string]cpio.Record{
   280  					"bar/foo": cpio.Symlink("bar/foo", ""),
   281  				},
   282  			},
   283  		},
   284  		{
   285  			af: &Files{
   286  				Files: map[string]string{
   287  					"bar/foo": "/some/other/place",
   288  				},
   289  			},
   290  			record: cpio.Symlink("bar/foo", ""),
   291  			result: &Files{
   292  				Files: map[string]string{
   293  					"bar/foo": "/some/other/place",
   294  				},
   295  			},
   296  			errContains: "already exists in archive",
   297  		},
   298  		{
   299  			af: &Files{
   300  				Records: map[string]cpio.Record{
   301  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   302  				},
   303  			},
   304  			record: cpio.Symlink("bar/foo", ""),
   305  			result: &Files{
   306  				Records: map[string]cpio.Record{
   307  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   308  				},
   309  			},
   310  			errContains: "already exists in archive",
   311  		},
   312  		{
   313  			af: &Files{
   314  				Records: map[string]cpio.Record{
   315  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   316  				},
   317  			},
   318  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   319  			result: &Files{
   320  				Records: map[string]cpio.Record{
   321  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   322  				},
   323  			},
   324  		},
   325  		{
   326  			record:      cpio.Symlink("/bar/foo", ""),
   327  			errContains: "must not be absolute",
   328  		},
   329  	} {
   330  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   331  			err := tt.af.AddRecord(tt.record)
   332  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   333  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   334  			}
   335  			if err == nil && len(tt.errContains) > 0 {
   336  				t.Errorf("Got no error, want %v", tt.errContains)
   337  			}
   338  
   339  			if !reflect.DeepEqual(tt.af, tt.result) {
   340  				t.Errorf("got %v, want %v", tt.af, tt.result)
   341  			}
   342  		})
   343  	}
   344  }
   345  
   346  func TestFilesfillInParent(t *testing.T) {
   347  	for i, tt := range []struct {
   348  		af     *Files
   349  		result *Files
   350  	}{
   351  		{
   352  			af: &Files{
   353  				Records: map[string]cpio.Record{
   354  					"foo/bar": cpio.Directory("foo/bar", 0777),
   355  				},
   356  			},
   357  			result: &Files{
   358  				Records: map[string]cpio.Record{
   359  					"foo/bar": cpio.Directory("foo/bar", 0777),
   360  					"foo":     cpio.Directory("foo", 0755),
   361  				},
   362  			},
   363  		},
   364  		{
   365  			af: &Files{
   366  				Files: map[string]string{
   367  					"baz/baz/baz": "/somewhere",
   368  				},
   369  				Records: map[string]cpio.Record{
   370  					"foo/bar": cpio.Directory("foo/bar", 0777),
   371  				},
   372  			},
   373  			result: &Files{
   374  				Files: map[string]string{
   375  					"baz/baz/baz": "/somewhere",
   376  				},
   377  				Records: map[string]cpio.Record{
   378  					"foo/bar": cpio.Directory("foo/bar", 0777),
   379  					"foo":     cpio.Directory("foo", 0755),
   380  					"baz":     cpio.Directory("baz", 0755),
   381  					"baz/baz": cpio.Directory("baz/baz", 0755),
   382  				},
   383  			},
   384  		},
   385  		{
   386  			af:     &Files{},
   387  			result: &Files{},
   388  		},
   389  	} {
   390  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   391  			tt.af.fillInParents()
   392  			if !reflect.DeepEqual(tt.af, tt.result) {
   393  				t.Errorf("got %v, want %v", tt.af, tt.result)
   394  			}
   395  		})
   396  	}
   397  }
   398  
   399  type MockArchiver struct {
   400  	Records      Records
   401  	FinishCalled bool
   402  	BaseArchive  []cpio.Record
   403  }
   404  
   405  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   406  	if _, ok := ma.Records[r.Name]; ok {
   407  		return fmt.Errorf("file exists")
   408  	}
   409  	ma.Records[r.Name] = r
   410  	return nil
   411  }
   412  
   413  func (ma *MockArchiver) Finish() error {
   414  	ma.FinishCalled = true
   415  	return nil
   416  }
   417  
   418  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   419  	if len(ma.BaseArchive) > 0 {
   420  		next := ma.BaseArchive[0]
   421  		ma.BaseArchive = ma.BaseArchive[1:]
   422  		return next, nil
   423  	}
   424  	return cpio.Record{}, io.EOF
   425  }
   426  
// Records maps a record name (its path within the archive) to the cpio
// record stored under it.
type Records map[string]cpio.Record
   428  
   429  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   430  	for name, s1 := range r1 {
   431  		s2, ok := r2[name]
   432  		if !ok {
   433  			return false
   434  		}
   435  		if !recordEqual(s1, s2) {
   436  			return false
   437  		}
   438  	}
   439  	for name := range r2 {
   440  		if _, ok := r1[name]; !ok {
   441  			return false
   442  		}
   443  	}
   444  	return true
   445  }
   446  
   447  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   448  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   449  		return false
   450  	}
   451  	return cpio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   452  }
   453  
   454  func TestOptsWrite(t *testing.T) {
   455  	for i, tt := range []struct {
   456  		desc string
   457  		opts *Opts
   458  		ma   *MockArchiver
   459  		want Records
   460  		err  error
   461  	}{
   462  		{
   463  			desc: "no conflicts, just records",
   464  			opts: &Opts{
   465  				Files: &Files{
   466  					Records: map[string]cpio.Record{
   467  						"foo": cpio.Symlink("foo", "elsewhere"),
   468  					},
   469  				},
   470  			},
   471  			ma: &MockArchiver{
   472  				Records: make(Records),
   473  				BaseArchive: []cpio.Record{
   474  					cpio.Directory("etc", 0777),
   475  					cpio.Directory("etc/nginx", 0777),
   476  				},
   477  			},
   478  			want: Records{
   479  				"foo":       cpio.Symlink("foo", "elsewhere"),
   480  				"etc":       cpio.Directory("etc", 0777),
   481  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   482  			},
   483  		},
   484  		{
   485  			desc: "default already exists",
   486  			opts: &Opts{
   487  				Files: &Files{
   488  					Records: map[string]cpio.Record{
   489  						"etc": cpio.Symlink("etc", "whatever"),
   490  					},
   491  				},
   492  			},
   493  			ma: &MockArchiver{
   494  				Records: make(Records),
   495  				BaseArchive: []cpio.Record{
   496  					cpio.Directory("etc", 0777),
   497  				},
   498  			},
   499  			want: Records{
   500  				"etc": cpio.Symlink("etc", "whatever"),
   501  			},
   502  		},
   503  		{
   504  			desc: "no conflicts, missing parent automatically created",
   505  			opts: &Opts{
   506  				Files: &Files{
   507  					Records: map[string]cpio.Record{
   508  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   509  					},
   510  				},
   511  			},
   512  			ma: &MockArchiver{
   513  				Records: make(Records),
   514  			},
   515  			want: Records{
   516  				"foo":         cpio.Directory("foo", 0755),
   517  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   518  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   519  			},
   520  		},
   521  		{
   522  			desc: "parent only automatically created if not already exists",
   523  			opts: &Opts{
   524  				Files: &Files{
   525  					Records: map[string]cpio.Record{
   526  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   527  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   528  					},
   529  				},
   530  			},
   531  			ma: &MockArchiver{
   532  				Records: make(Records),
   533  			},
   534  			want: Records{
   535  				"foo":         cpio.Directory("foo", 0755),
   536  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   537  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   538  			},
   539  		},
   540  		{
   541  			desc: "base archive",
   542  			opts: &Opts{
   543  				Files: &Files{
   544  					Records: map[string]cpio.Record{
   545  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   546  						"exists":  cpio.Directory("exists", 0777),
   547  					},
   548  				},
   549  			},
   550  			ma: &MockArchiver{
   551  				Records: make(Records),
   552  				BaseArchive: []cpio.Record{
   553  					cpio.Directory("etc", 0755),
   554  					cpio.Directory("foo", 0444),
   555  					cpio.Directory("exists", 0),
   556  				},
   557  			},
   558  			want: Records{
   559  				"etc":     cpio.Directory("etc", 0755),
   560  				"exists":  cpio.Directory("exists", 0777),
   561  				"foo":     cpio.Directory("foo", 0444),
   562  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   563  			},
   564  		},
   565  		{
   566  			desc: "base archive with init, no user init",
   567  			opts: &Opts{
   568  				Files: &Files{
   569  					Records: map[string]cpio.Record{},
   570  				},
   571  			},
   572  			ma: &MockArchiver{
   573  				Records: make(Records),
   574  				BaseArchive: []cpio.Record{
   575  					cpio.StaticFile("init", "boo", 0555),
   576  				},
   577  			},
   578  			want: Records{
   579  				"init": cpio.StaticFile("init", "boo", 0555),
   580  			},
   581  		},
   582  		{
   583  			desc: "base archive with init and user init",
   584  			opts: &Opts{
   585  				Files: &Files{
   586  					Records: map[string]cpio.Record{
   587  						"init": cpio.StaticFile("init", "bar", 0444),
   588  					},
   589  				},
   590  			},
   591  			ma: &MockArchiver{
   592  				Records: make(Records),
   593  				BaseArchive: []cpio.Record{
   594  					cpio.StaticFile("init", "boo", 0555),
   595  				},
   596  			},
   597  			want: Records{
   598  				"init":  cpio.StaticFile("init", "bar", 0444),
   599  				"inito": cpio.StaticFile("inito", "boo", 0555),
   600  			},
   601  		},
   602  		{
   603  			desc: "base archive with init, use existing init",
   604  			opts: &Opts{
   605  				Files: &Files{
   606  					Records: map[string]cpio.Record{},
   607  				},
   608  				UseExistingInit: true,
   609  			},
   610  			ma: &MockArchiver{
   611  				Records: make(Records),
   612  				BaseArchive: []cpio.Record{
   613  					cpio.StaticFile("init", "boo", 0555),
   614  				},
   615  			},
   616  			want: Records{
   617  				"init": cpio.StaticFile("init", "boo", 0555),
   618  			},
   619  		},
   620  		{
   621  			desc: "base archive with init and user init, use existing init",
   622  			opts: &Opts{
   623  				Files: &Files{
   624  					Records: map[string]cpio.Record{
   625  						"init": cpio.StaticFile("init", "huh", 0111),
   626  					},
   627  				},
   628  				UseExistingInit: true,
   629  			},
   630  			ma: &MockArchiver{
   631  				Records: make(Records),
   632  				BaseArchive: []cpio.Record{
   633  					cpio.StaticFile("init", "boo", 0555),
   634  				},
   635  			},
   636  			want: Records{
   637  				"init":  cpio.StaticFile("init", "boo", 0555),
   638  				"inito": cpio.StaticFile("inito", "huh", 0111),
   639  			},
   640  		},
   641  	} {
   642  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   643  			tt.opts.BaseArchive = tt.ma
   644  			tt.opts.OutputFile = tt.ma
   645  
   646  			if err := Write(tt.opts); err != tt.err {
   647  				t.Errorf("Write() = %v, want %v", err, tt.err)
   648  			} else if err == nil && !tt.ma.FinishCalled {
   649  				t.Errorf("Finish wasn't called on archive")
   650  			}
   651  
   652  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   653  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   654  			}
   655  		})
   656  	}
   657  }