github.com/mem/u-root@v2.0.1-0.20181004165302-9b18b4636a33+incompatible/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/u-root/u-root/pkg/cpio"
    18  )
    19  
    20  func TestFilesAddFile(t *testing.T) {
    21  	regularFile, err := ioutil.TempFile("", "archive-files-add-file")
    22  	if err != nil {
    23  		t.Error(err)
    24  	}
    25  	defer os.RemoveAll(regularFile.Name())
    26  
    27  	dir, err := ioutil.TempDir("", "archive-add-files")
    28  	if err != nil {
    29  		t.Error(err)
    30  	}
    31  	defer os.RemoveAll(dir)
    32  
    33  	os.Create(filepath.Join(dir, "foo"))
    34  	os.Create(filepath.Join(dir, "foo2"))
    35  
    36  	for i, tt := range []struct {
    37  		name        string
    38  		af          *Files
    39  		src         string
    40  		dest        string
    41  		result      *Files
    42  		errContains string
    43  	}{
    44  		{
    45  			name: "just add a file",
    46  			af:   NewFiles(),
    47  
    48  			src:  regularFile.Name(),
    49  			dest: "bar/foo",
    50  
    51  			result: &Files{
    52  				Files: map[string]string{
    53  					"bar/foo": regularFile.Name(),
    54  				},
    55  				Records: map[string]cpio.Record{},
    56  			},
    57  		},
    58  		{
    59  			name: "add file that exists in Files",
    60  			af: &Files{
    61  				Files: map[string]string{
    62  					"bar/foo": "/some/other/place",
    63  				},
    64  			},
    65  			src:  regularFile.Name(),
    66  			dest: "bar/foo",
    67  			result: &Files{
    68  				Files: map[string]string{
    69  					"bar/foo": "/some/other/place",
    70  				},
    71  			},
    72  			errContains: "already exists in archive",
    73  		},
    74  		{
    75  			name: "add a file that exists in Records",
    76  			af: &Files{
    77  				Records: map[string]cpio.Record{
    78  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
    79  				},
    80  			},
    81  			src:  regularFile.Name(),
    82  			dest: "bar/foo",
    83  			result: &Files{
    84  				Records: map[string]cpio.Record{
    85  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
    86  				},
    87  			},
    88  			errContains: "already exists in archive",
    89  		},
    90  		{
    91  			name: "add a file that already exists in Files, but is the same one",
    92  			af: &Files{
    93  				Files: map[string]string{
    94  					"bar/foo": regularFile.Name(),
    95  				},
    96  			},
    97  			src:  regularFile.Name(),
    98  			dest: "bar/foo",
    99  			result: &Files{
   100  				Files: map[string]string{
   101  					"bar/foo": regularFile.Name(),
   102  				},
   103  			},
   104  		},
   105  		{
   106  			name:        "destination path must not be absolute",
   107  			src:         regularFile.Name(),
   108  			dest:        "/bar/foo",
   109  			errContains: "must not be absolute",
   110  		},
   111  		{
   112  			name:        "source path must be absolute",
   113  			src:         "foo/bar",
   114  			dest:        "bar/foo",
   115  			errContains: "must be absolute",
   116  		},
   117  		{
   118  			name: "add a directory",
   119  			af: &Files{
   120  				Files: map[string]string{},
   121  			},
   122  			src:  dir,
   123  			dest: "bar/foo",
   124  			result: &Files{
   125  				Files: map[string]string{
   126  					"bar/foo":      dir,
   127  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   128  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   129  				},
   130  			},
   131  		},
   132  		{
   133  			name: "add a different directory to the same destination, no overlapping children",
   134  			af: &Files{
   135  				Files: map[string]string{
   136  					"bar/foo":     "/some/place/real",
   137  					"bar/foo/zed": "/some/place/real/zed",
   138  				},
   139  			},
   140  			src:  dir,
   141  			dest: "bar/foo",
   142  			result: &Files{
   143  				Files: map[string]string{
   144  					"bar/foo":      dir,
   145  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   146  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   147  					"bar/foo/zed":  "/some/place/real/zed",
   148  				},
   149  			},
   150  		},
   151  		{
   152  			name: "add a different directory to the same destination, overlapping children",
   153  			af: &Files{
   154  				Files: map[string]string{
   155  					"bar/foo":      "/some/place/real",
   156  					"bar/foo/foo2": "/some/place/real/zed",
   157  				},
   158  			},
   159  			src:         dir,
   160  			dest:        "bar/foo",
   161  			errContains: "already exists in archive",
   162  		},
   163  	} {
   164  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   165  			err := tt.af.AddFile(tt.src, tt.dest)
   166  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   167  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   168  			}
   169  			if err == nil && len(tt.errContains) > 0 {
   170  				t.Errorf("Got no error, want %v", tt.errContains)
   171  			}
   172  
   173  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   174  				t.Errorf("got %v, want %v", tt.af, tt.result)
   175  			}
   176  		})
   177  	}
   178  }
   179  
   180  func TestFilesAddRecord(t *testing.T) {
   181  	for i, tt := range []struct {
   182  		af     *Files
   183  		record cpio.Record
   184  
   185  		result      *Files
   186  		errContains string
   187  	}{
   188  		{
   189  			af:     NewFiles(),
   190  			record: cpio.Symlink("bar/foo", ""),
   191  			result: &Files{
   192  				Files: map[string]string{},
   193  				Records: map[string]cpio.Record{
   194  					"bar/foo": cpio.Symlink("bar/foo", ""),
   195  				},
   196  			},
   197  		},
   198  		{
   199  			af: &Files{
   200  				Files: map[string]string{
   201  					"bar/foo": "/some/other/place",
   202  				},
   203  			},
   204  			record: cpio.Symlink("bar/foo", ""),
   205  			result: &Files{
   206  				Files: map[string]string{
   207  					"bar/foo": "/some/other/place",
   208  				},
   209  			},
   210  			errContains: "already exists in archive",
   211  		},
   212  		{
   213  			af: &Files{
   214  				Records: map[string]cpio.Record{
   215  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   216  				},
   217  			},
   218  			record: cpio.Symlink("bar/foo", ""),
   219  			result: &Files{
   220  				Records: map[string]cpio.Record{
   221  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   222  				},
   223  			},
   224  			errContains: "already exists in archive",
   225  		},
   226  		{
   227  			af: &Files{
   228  				Records: map[string]cpio.Record{
   229  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   230  				},
   231  			},
   232  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   233  			result: &Files{
   234  				Records: map[string]cpio.Record{
   235  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   236  				},
   237  			},
   238  		},
   239  		{
   240  			record:      cpio.Symlink("/bar/foo", ""),
   241  			errContains: "must not be absolute",
   242  		},
   243  	} {
   244  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   245  			err := tt.af.AddRecord(tt.record)
   246  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   247  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   248  			}
   249  			if err == nil && len(tt.errContains) > 0 {
   250  				t.Errorf("Got no error, want %v", tt.errContains)
   251  			}
   252  
   253  			if !reflect.DeepEqual(tt.af, tt.result) {
   254  				t.Errorf("got %v, want %v", tt.af, tt.result)
   255  			}
   256  		})
   257  	}
   258  }
   259  
   260  func TestFilesfillInParent(t *testing.T) {
   261  	for i, tt := range []struct {
   262  		af     *Files
   263  		result *Files
   264  	}{
   265  		{
   266  			af: &Files{
   267  				Records: map[string]cpio.Record{
   268  					"foo/bar": cpio.Directory("foo/bar", 0777),
   269  				},
   270  			},
   271  			result: &Files{
   272  				Records: map[string]cpio.Record{
   273  					"foo/bar": cpio.Directory("foo/bar", 0777),
   274  					"foo":     cpio.Directory("foo", 0755),
   275  				},
   276  			},
   277  		},
   278  		{
   279  			af: &Files{
   280  				Files: map[string]string{
   281  					"baz/baz/baz": "/somewhere",
   282  				},
   283  				Records: map[string]cpio.Record{
   284  					"foo/bar": cpio.Directory("foo/bar", 0777),
   285  				},
   286  			},
   287  			result: &Files{
   288  				Files: map[string]string{
   289  					"baz/baz/baz": "/somewhere",
   290  				},
   291  				Records: map[string]cpio.Record{
   292  					"foo/bar": cpio.Directory("foo/bar", 0777),
   293  					"foo":     cpio.Directory("foo", 0755),
   294  					"baz":     cpio.Directory("baz", 0755),
   295  					"baz/baz": cpio.Directory("baz/baz", 0755),
   296  				},
   297  			},
   298  		},
   299  		{
   300  			af:     &Files{},
   301  			result: &Files{},
   302  		},
   303  	} {
   304  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   305  			tt.af.fillInParents()
   306  			if !reflect.DeepEqual(tt.af, tt.result) {
   307  				t.Errorf("got %v, want %v", tt.af, tt.result)
   308  			}
   309  		})
   310  	}
   311  }
   312  
   313  type MockArchiver struct {
   314  	Records      Records
   315  	FinishCalled bool
   316  	BaseArchive  []cpio.Record
   317  }
   318  
   319  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   320  	if _, ok := ma.Records[r.Name]; ok {
   321  		return fmt.Errorf("file exists")
   322  	}
   323  	ma.Records[r.Name] = r
   324  	return nil
   325  }
   326  
   327  func (ma *MockArchiver) Finish() error {
   328  	ma.FinishCalled = true
   329  	return nil
   330  }
   331  
   332  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   333  	if len(ma.BaseArchive) > 0 {
   334  		next := ma.BaseArchive[0]
   335  		ma.BaseArchive = ma.BaseArchive[1:]
   336  		return next, nil
   337  	}
   338  	return cpio.Record{}, io.EOF
   339  }
   340  
// Records maps a record's Name to the record itself; MockArchiver.WriteRecord
// fills it in this way, and RecordsEqual compares two such maps.
type Records map[string]cpio.Record
   342  
   343  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   344  	for name, s1 := range r1 {
   345  		s2, ok := r2[name]
   346  		if !ok {
   347  			return false
   348  		}
   349  		if !recordEqual(s1, s2) {
   350  			return false
   351  		}
   352  	}
   353  	for name := range r2 {
   354  		if _, ok := r1[name]; !ok {
   355  			return false
   356  		}
   357  	}
   358  	return true
   359  }
   360  
   361  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   362  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   363  		return false
   364  	}
   365  	return cpio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   366  }
   367  
   368  func TestOptsWrite(t *testing.T) {
   369  	for i, tt := range []struct {
   370  		desc string
   371  		opts *Opts
   372  		ma   *MockArchiver
   373  		want Records
   374  		err  error
   375  	}{
   376  		{
   377  			desc: "no conflicts, just records",
   378  			opts: &Opts{
   379  				Files: &Files{
   380  					Records: map[string]cpio.Record{
   381  						"foo": cpio.Symlink("foo", "elsewhere"),
   382  					},
   383  				},
   384  			},
   385  			ma: &MockArchiver{
   386  				Records: make(Records),
   387  				BaseArchive: []cpio.Record{
   388  					cpio.Directory("etc", 0777),
   389  					cpio.Directory("etc/nginx", 0777),
   390  				},
   391  			},
   392  			want: Records{
   393  				"foo":       cpio.Symlink("foo", "elsewhere"),
   394  				"etc":       cpio.Directory("etc", 0777),
   395  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   396  			},
   397  		},
   398  		{
   399  			desc: "default already exists",
   400  			opts: &Opts{
   401  				Files: &Files{
   402  					Records: map[string]cpio.Record{
   403  						"etc": cpio.Symlink("etc", "whatever"),
   404  					},
   405  				},
   406  			},
   407  			ma: &MockArchiver{
   408  				Records: make(Records),
   409  				BaseArchive: []cpio.Record{
   410  					cpio.Directory("etc", 0777),
   411  				},
   412  			},
   413  			want: Records{
   414  				"etc": cpio.Symlink("etc", "whatever"),
   415  			},
   416  		},
   417  		{
   418  			desc: "no conflicts, missing parent automatically created",
   419  			opts: &Opts{
   420  				Files: &Files{
   421  					Records: map[string]cpio.Record{
   422  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   423  					},
   424  				},
   425  			},
   426  			ma: &MockArchiver{
   427  				Records: make(Records),
   428  			},
   429  			want: Records{
   430  				"foo":         cpio.Directory("foo", 0755),
   431  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   432  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   433  			},
   434  		},
   435  		{
   436  			desc: "parent only automatically created if not already exists",
   437  			opts: &Opts{
   438  				Files: &Files{
   439  					Records: map[string]cpio.Record{
   440  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   441  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   442  					},
   443  				},
   444  			},
   445  			ma: &MockArchiver{
   446  				Records: make(Records),
   447  			},
   448  			want: Records{
   449  				"foo":         cpio.Directory("foo", 0755),
   450  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   451  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   452  			},
   453  		},
   454  		{
   455  			desc: "base archive",
   456  			opts: &Opts{
   457  				Files: &Files{
   458  					Records: map[string]cpio.Record{
   459  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   460  						"exists":  cpio.Directory("exists", 0777),
   461  					},
   462  				},
   463  			},
   464  			ma: &MockArchiver{
   465  				Records: make(Records),
   466  				BaseArchive: []cpio.Record{
   467  					cpio.Directory("etc", 0755),
   468  					cpio.Directory("foo", 0444),
   469  					cpio.Directory("exists", 0),
   470  				},
   471  			},
   472  			want: Records{
   473  				"etc":     cpio.Directory("etc", 0755),
   474  				"exists":  cpio.Directory("exists", 0777),
   475  				"foo":     cpio.Directory("foo", 0444),
   476  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   477  			},
   478  		},
   479  		{
   480  			desc: "base archive with init, no user init",
   481  			opts: &Opts{
   482  				Files: &Files{
   483  					Records: map[string]cpio.Record{},
   484  				},
   485  			},
   486  			ma: &MockArchiver{
   487  				Records: make(Records),
   488  				BaseArchive: []cpio.Record{
   489  					cpio.StaticFile("init", "boo", 0555),
   490  				},
   491  			},
   492  			want: Records{
   493  				"init": cpio.StaticFile("init", "boo", 0555),
   494  			},
   495  		},
   496  		{
   497  			desc: "base archive with init and user init",
   498  			opts: &Opts{
   499  				Files: &Files{
   500  					Records: map[string]cpio.Record{
   501  						"init": cpio.StaticFile("init", "bar", 0444),
   502  					},
   503  				},
   504  			},
   505  			ma: &MockArchiver{
   506  				Records: make(Records),
   507  				BaseArchive: []cpio.Record{
   508  					cpio.StaticFile("init", "boo", 0555),
   509  				},
   510  			},
   511  			want: Records{
   512  				"init":  cpio.StaticFile("init", "bar", 0444),
   513  				"inito": cpio.StaticFile("inito", "boo", 0555),
   514  			},
   515  		},
   516  		{
   517  			desc: "base archive with init, use existing init",
   518  			opts: &Opts{
   519  				Files: &Files{
   520  					Records: map[string]cpio.Record{},
   521  				},
   522  				UseExistingInit: true,
   523  			},
   524  			ma: &MockArchiver{
   525  				Records: make(Records),
   526  				BaseArchive: []cpio.Record{
   527  					cpio.StaticFile("init", "boo", 0555),
   528  				},
   529  			},
   530  			want: Records{
   531  				"init": cpio.StaticFile("init", "boo", 0555),
   532  			},
   533  		},
   534  		{
   535  			desc: "base archive with init and user init, use existing init",
   536  			opts: &Opts{
   537  				Files: &Files{
   538  					Records: map[string]cpio.Record{
   539  						"init": cpio.StaticFile("init", "huh", 0111),
   540  					},
   541  				},
   542  				UseExistingInit: true,
   543  			},
   544  			ma: &MockArchiver{
   545  				Records: make(Records),
   546  				BaseArchive: []cpio.Record{
   547  					cpio.StaticFile("init", "boo", 0555),
   548  				},
   549  			},
   550  			want: Records{
   551  				"init":  cpio.StaticFile("init", "boo", 0555),
   552  				"inito": cpio.StaticFile("inito", "huh", 0111),
   553  			},
   554  		},
   555  	} {
   556  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   557  			tt.opts.BaseArchive = tt.ma
   558  			tt.opts.OutputFile = tt.ma
   559  
   560  			if err := Write(tt.opts); err != tt.err {
   561  				t.Errorf("Write() = %v, want %v", err, tt.err)
   562  			} else if err == nil && !tt.ma.FinishCalled {
   563  				t.Errorf("Finish wasn't called on archive")
   564  			}
   565  
   566  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   567  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   568  			}
   569  		})
   570  	}
   571  }