github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/mvdan/u-root-coreutils/pkg/cpio"
    17  	"github.com/mvdan/u-root-coreutils/pkg/uio"
    18  )
    19  
    20  func TestFilesAddFileNoFollow(t *testing.T) {
    21  	regularFile, err := os.CreateTemp("", "archive-files-add-file")
    22  	if err != nil {
    23  		t.Error(err)
    24  	}
    25  	defer os.RemoveAll(regularFile.Name())
    26  
    27  	dir := t.TempDir()
    28  	dir2 := t.TempDir()
    29  
    30  	os.Create(filepath.Join(dir, "foo2"))
    31  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    32  
    33  	for i, tt := range []struct {
    34  		name        string
    35  		af          *Files
    36  		src         string
    37  		dest        string
    38  		result      *Files
    39  		errContains string
    40  	}{
    41  		{
    42  			name: "just add a file",
    43  			af:   NewFiles(),
    44  
    45  			src:  regularFile.Name(),
    46  			dest: "bar/foo",
    47  
    48  			result: &Files{
    49  				Files: map[string]string{
    50  					"bar/foo": regularFile.Name(),
    51  				},
    52  				Records: map[string]cpio.Record{},
    53  			},
    54  		},
    55  		{
    56  			name: "add symlinked file, NOT following",
    57  			af:   NewFiles(),
    58  			src:  filepath.Join(dir2, "foo3"),
    59  			dest: "bar/foo",
    60  			result: &Files{
    61  				Files: map[string]string{
    62  					"bar/foo": filepath.Join(dir2, "foo3"),
    63  				},
    64  				Records: map[string]cpio.Record{},
    65  			},
    66  		},
    67  	} {
    68  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
    69  			err := tt.af.AddFileNoFollow(tt.src, tt.dest)
    70  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
    71  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
    72  			}
    73  			if err == nil && len(tt.errContains) > 0 {
    74  				t.Errorf("Got no error, want %v", tt.errContains)
    75  			}
    76  
    77  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
    78  				t.Errorf("got %v, want %v", tt.af, tt.result)
    79  			}
    80  		})
    81  	}
    82  }
    83  
    84  func TestFilesAddFile(t *testing.T) {
    85  	regularFile, err := os.CreateTemp("", "archive-files-add-file")
    86  	if err != nil {
    87  		t.Error(err)
    88  	}
    89  	defer os.RemoveAll(regularFile.Name())
    90  
    91  	dir := t.TempDir()
    92  	dir2 := t.TempDir()
    93  	dir3 := t.TempDir()
    94  
    95  	os.Create(filepath.Join(dir, "foo"))
    96  	os.Create(filepath.Join(dir, "foo2"))
    97  	os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
    98  
    99  	fooDir := filepath.Join(dir3, "fooDir")
   100  	os.Mkdir(fooDir, os.ModePerm)
   101  	symlinkToDir3 := filepath.Join(dir3, "fooSymDir/")
   102  	os.Symlink(fooDir, symlinkToDir3)
   103  	os.Create(filepath.Join(fooDir, "foo"))
   104  	os.Create(filepath.Join(fooDir, "bar"))
   105  
   106  	for i, tt := range []struct {
   107  		name        string
   108  		af          *Files
   109  		src         string
   110  		dest        string
   111  		result      *Files
   112  		errContains string
   113  	}{
   114  		{
   115  			name: "just add a file",
   116  			af:   NewFiles(),
   117  
   118  			src:  regularFile.Name(),
   119  			dest: "bar/foo",
   120  
   121  			result: &Files{
   122  				Files: map[string]string{
   123  					"bar/foo": regularFile.Name(),
   124  				},
   125  				Records: map[string]cpio.Record{},
   126  			},
   127  		},
   128  		{
   129  			name: "add symlinked file, following",
   130  			af:   NewFiles(),
   131  			src:  filepath.Join(dir2, "foo3"),
   132  			dest: "bar/foo",
   133  			result: &Files{
   134  				Files: map[string]string{
   135  					"bar/foo": filepath.Join(dir, "foo2"),
   136  				},
   137  				Records: map[string]cpio.Record{},
   138  			},
   139  		},
   140  		{
   141  			name: "add symlinked directory, following",
   142  			af:   NewFiles(),
   143  			src:  symlinkToDir3,
   144  			dest: "foo/",
   145  			result: &Files{
   146  				Files: map[string]string{
   147  					"foo":     fooDir,
   148  					"foo/foo": filepath.Join(fooDir, "foo"),
   149  					"foo/bar": filepath.Join(fooDir, "bar"),
   150  				},
   151  				Records: map[string]cpio.Record{},
   152  			},
   153  		},
   154  		{
   155  			name: "add file that exists in Files",
   156  			af: &Files{
   157  				Files: map[string]string{
   158  					"bar/foo": "/some/other/place",
   159  				},
   160  			},
   161  			src:  regularFile.Name(),
   162  			dest: "bar/foo",
   163  			result: &Files{
   164  				Files: map[string]string{
   165  					"bar/foo": "/some/other/place",
   166  				},
   167  			},
   168  			errContains: "already exists in archive",
   169  		},
   170  		{
   171  			name: "add a file that exists in Records",
   172  			af: &Files{
   173  				Records: map[string]cpio.Record{
   174  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   175  				},
   176  			},
   177  			src:  regularFile.Name(),
   178  			dest: "bar/foo",
   179  			result: &Files{
   180  				Records: map[string]cpio.Record{
   181  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   182  				},
   183  			},
   184  			errContains: "already exists in archive",
   185  		},
   186  		{
   187  			name: "add a file that already exists in Files, but is the same one",
   188  			af: &Files{
   189  				Files: map[string]string{
   190  					"bar/foo": regularFile.Name(),
   191  				},
   192  			},
   193  			src:  regularFile.Name(),
   194  			dest: "bar/foo",
   195  			result: &Files{
   196  				Files: map[string]string{
   197  					"bar/foo": regularFile.Name(),
   198  				},
   199  			},
   200  		},
   201  		{
   202  			name: "absolute destination paths are made relative",
   203  			af: &Files{
   204  				Files: map[string]string{},
   205  			},
   206  			src:  dir,
   207  			dest: "/bar/foo",
   208  			result: &Files{
   209  				Files: map[string]string{
   210  					"bar/foo":      dir,
   211  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   212  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   213  				},
   214  			},
   215  		},
   216  		{
   217  			name: "add a directory",
   218  			af: &Files{
   219  				Files: map[string]string{},
   220  			},
   221  			src:  dir,
   222  			dest: "bar/foo",
   223  			result: &Files{
   224  				Files: map[string]string{
   225  					"bar/foo":      dir,
   226  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   227  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   228  				},
   229  			},
   230  		},
   231  		{
   232  			name: "add a different directory to the same destination, no overlapping children",
   233  			af: &Files{
   234  				Files: map[string]string{
   235  					"bar/foo":     "/some/place/real",
   236  					"bar/foo/zed": "/some/place/real/zed",
   237  				},
   238  			},
   239  			src:  dir,
   240  			dest: "bar/foo",
   241  			result: &Files{
   242  				Files: map[string]string{
   243  					"bar/foo":      dir,
   244  					"bar/foo/foo":  filepath.Join(dir, "foo"),
   245  					"bar/foo/foo2": filepath.Join(dir, "foo2"),
   246  					"bar/foo/zed":  "/some/place/real/zed",
   247  				},
   248  			},
   249  		},
   250  		{
   251  			name: "add a different directory to the same destination, overlapping children",
   252  			af: &Files{
   253  				Files: map[string]string{
   254  					"bar/foo":      "/some/place/real",
   255  					"bar/foo/foo2": "/some/place/real/zed",
   256  				},
   257  			},
   258  			src:         dir,
   259  			dest:        "bar/foo",
   260  			errContains: "already exists in archive",
   261  		},
   262  	} {
   263  		t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
   264  			err := tt.af.AddFile(tt.src, tt.dest)
   265  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   266  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   267  			}
   268  			if err == nil && len(tt.errContains) > 0 {
   269  				t.Errorf("Got no error, want %v", tt.errContains)
   270  			}
   271  
   272  			if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
   273  				t.Errorf("got %v, want %v", tt.af, tt.result)
   274  			}
   275  		})
   276  	}
   277  }
   278  
   279  func TestFilesAddRecord(t *testing.T) {
   280  	for i, tt := range []struct {
   281  		af     *Files
   282  		record cpio.Record
   283  
   284  		result      *Files
   285  		errContains string
   286  	}{
   287  		{
   288  			af:     NewFiles(),
   289  			record: cpio.Symlink("bar/foo", ""),
   290  			result: &Files{
   291  				Files: map[string]string{},
   292  				Records: map[string]cpio.Record{
   293  					"bar/foo": cpio.Symlink("bar/foo", ""),
   294  				},
   295  			},
   296  		},
   297  		{
   298  			af: &Files{
   299  				Files: map[string]string{
   300  					"bar/foo": "/some/other/place",
   301  				},
   302  			},
   303  			record: cpio.Symlink("bar/foo", ""),
   304  			result: &Files{
   305  				Files: map[string]string{
   306  					"bar/foo": "/some/other/place",
   307  				},
   308  			},
   309  			errContains: "already exists in archive",
   310  		},
   311  		{
   312  			af: &Files{
   313  				Records: map[string]cpio.Record{
   314  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   315  				},
   316  			},
   317  			record: cpio.Symlink("bar/foo", ""),
   318  			result: &Files{
   319  				Records: map[string]cpio.Record{
   320  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   321  				},
   322  			},
   323  			errContains: "already exists in archive",
   324  		},
   325  		{
   326  			af: &Files{
   327  				Records: map[string]cpio.Record{
   328  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   329  				},
   330  			},
   331  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   332  			result: &Files{
   333  				Records: map[string]cpio.Record{
   334  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   335  				},
   336  			},
   337  		},
   338  		{
   339  			record:      cpio.Symlink("/bar/foo", ""),
   340  			errContains: "must not be absolute",
   341  		},
   342  	} {
   343  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   344  			err := tt.af.AddRecord(tt.record)
   345  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   346  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   347  			}
   348  			if err == nil && len(tt.errContains) > 0 {
   349  				t.Errorf("Got no error, want %v", tt.errContains)
   350  			}
   351  
   352  			if !reflect.DeepEqual(tt.af, tt.result) {
   353  				t.Errorf("got %v, want %v", tt.af, tt.result)
   354  			}
   355  		})
   356  	}
   357  }
   358  
   359  func TestFilesfillInParent(t *testing.T) {
   360  	for i, tt := range []struct {
   361  		af     *Files
   362  		result *Files
   363  	}{
   364  		{
   365  			af: &Files{
   366  				Records: map[string]cpio.Record{
   367  					"foo/bar": cpio.Directory("foo/bar", 0o777),
   368  				},
   369  			},
   370  			result: &Files{
   371  				Records: map[string]cpio.Record{
   372  					"foo/bar": cpio.Directory("foo/bar", 0o777),
   373  					"foo":     cpio.Directory("foo", 0o755),
   374  				},
   375  			},
   376  		},
   377  		{
   378  			af: &Files{
   379  				Files: map[string]string{
   380  					"baz/baz/baz": "/somewhere",
   381  				},
   382  				Records: map[string]cpio.Record{
   383  					"foo/bar": cpio.Directory("foo/bar", 0o777),
   384  				},
   385  			},
   386  			result: &Files{
   387  				Files: map[string]string{
   388  					"baz/baz/baz": "/somewhere",
   389  				},
   390  				Records: map[string]cpio.Record{
   391  					"foo/bar": cpio.Directory("foo/bar", 0o777),
   392  					"foo":     cpio.Directory("foo", 0o755),
   393  					"baz":     cpio.Directory("baz", 0o755),
   394  					"baz/baz": cpio.Directory("baz/baz", 0o755),
   395  				},
   396  			},
   397  		},
   398  		{
   399  			af:     &Files{},
   400  			result: &Files{},
   401  		},
   402  	} {
   403  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   404  			tt.af.fillInParents()
   405  			if !reflect.DeepEqual(tt.af, tt.result) {
   406  				t.Errorf("got %v, want %v", tt.af, tt.result)
   407  			}
   408  		})
   409  	}
   410  }
   411  
   412  type MockArchiver struct {
   413  	Records      Records
   414  	FinishCalled bool
   415  	BaseArchive  []cpio.Record
   416  }
   417  
   418  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   419  	if _, ok := ma.Records[r.Name]; ok {
   420  		return fmt.Errorf("file exists")
   421  	}
   422  	ma.Records[r.Name] = r
   423  	return nil
   424  }
   425  
   426  func (ma *MockArchiver) Finish() error {
   427  	ma.FinishCalled = true
   428  	return nil
   429  }
   430  
   431  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   432  	if len(ma.BaseArchive) > 0 {
   433  		next := ma.BaseArchive[0]
   434  		ma.BaseArchive = ma.BaseArchive[1:]
   435  		return next, nil
   436  	}
   437  	return cpio.Record{}, io.EOF
   438  }
   439  
   440  type Records map[string]cpio.Record
   441  
   442  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   443  	for name, s1 := range r1 {
   444  		s2, ok := r2[name]
   445  		if !ok {
   446  			return false
   447  		}
   448  		if !recordEqual(s1, s2) {
   449  			return false
   450  		}
   451  	}
   452  	for name := range r2 {
   453  		if _, ok := r1[name]; !ok {
   454  			return false
   455  		}
   456  	}
   457  	return true
   458  }
   459  
   460  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   461  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   462  		return false
   463  	}
   464  	return uio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   465  }
   466  
   467  func TestOptsWrite(t *testing.T) {
   468  	for i, tt := range []struct {
   469  		desc string
   470  		opts *Opts
   471  		ma   *MockArchiver
   472  		want Records
   473  		err  error
   474  	}{
   475  		{
   476  			desc: "no conflicts, just records",
   477  			opts: &Opts{
   478  				Files: &Files{
   479  					Records: map[string]cpio.Record{
   480  						"foo": cpio.Symlink("foo", "elsewhere"),
   481  					},
   482  				},
   483  			},
   484  			ma: &MockArchiver{
   485  				Records: make(Records),
   486  				BaseArchive: []cpio.Record{
   487  					cpio.Directory("etc", 0o777),
   488  					cpio.Directory("etc/nginx", 0o777),
   489  				},
   490  			},
   491  			want: Records{
   492  				"foo":       cpio.Symlink("foo", "elsewhere"),
   493  				"etc":       cpio.Directory("etc", 0o777),
   494  				"etc/nginx": cpio.Directory("etc/nginx", 0o777),
   495  			},
   496  		},
   497  		{
   498  			desc: "default already exists",
   499  			opts: &Opts{
   500  				Files: &Files{
   501  					Records: map[string]cpio.Record{
   502  						"etc": cpio.Symlink("etc", "whatever"),
   503  					},
   504  				},
   505  			},
   506  			ma: &MockArchiver{
   507  				Records: make(Records),
   508  				BaseArchive: []cpio.Record{
   509  					cpio.Directory("etc", 0o777),
   510  				},
   511  			},
   512  			want: Records{
   513  				"etc": cpio.Symlink("etc", "whatever"),
   514  			},
   515  		},
   516  		{
   517  			desc: "no conflicts, missing parent automatically created",
   518  			opts: &Opts{
   519  				Files: &Files{
   520  					Records: map[string]cpio.Record{
   521  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   522  					},
   523  				},
   524  			},
   525  			ma: &MockArchiver{
   526  				Records: make(Records),
   527  			},
   528  			want: Records{
   529  				"foo":         cpio.Directory("foo", 0o755),
   530  				"foo/bar":     cpio.Directory("foo/bar", 0o755),
   531  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   532  			},
   533  		},
   534  		{
   535  			desc: "parent only automatically created if not already exists",
   536  			opts: &Opts{
   537  				Files: &Files{
   538  					Records: map[string]cpio.Record{
   539  						"foo/bar":     cpio.Directory("foo/bar", 0o444),
   540  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   541  					},
   542  				},
   543  			},
   544  			ma: &MockArchiver{
   545  				Records: make(Records),
   546  			},
   547  			want: Records{
   548  				"foo":         cpio.Directory("foo", 0o755),
   549  				"foo/bar":     cpio.Directory("foo/bar", 0o444),
   550  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   551  			},
   552  		},
   553  		{
   554  			desc: "base archive",
   555  			opts: &Opts{
   556  				Files: &Files{
   557  					Records: map[string]cpio.Record{
   558  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   559  						"exists":  cpio.Directory("exists", 0o777),
   560  					},
   561  				},
   562  			},
   563  			ma: &MockArchiver{
   564  				Records: make(Records),
   565  				BaseArchive: []cpio.Record{
   566  					cpio.Directory("etc", 0o755),
   567  					cpio.Directory("foo", 0o444),
   568  					cpio.Directory("exists", 0),
   569  				},
   570  			},
   571  			want: Records{
   572  				"etc":     cpio.Directory("etc", 0o755),
   573  				"exists":  cpio.Directory("exists", 0o777),
   574  				"foo":     cpio.Directory("foo", 0o444),
   575  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   576  			},
   577  		},
   578  		{
   579  			desc: "base archive with init, no user init",
   580  			opts: &Opts{
   581  				Files: &Files{
   582  					Records: map[string]cpio.Record{},
   583  				},
   584  			},
   585  			ma: &MockArchiver{
   586  				Records: make(Records),
   587  				BaseArchive: []cpio.Record{
   588  					cpio.StaticFile("init", "boo", 0o555),
   589  				},
   590  			},
   591  			want: Records{
   592  				"init": cpio.StaticFile("init", "boo", 0o555),
   593  			},
   594  		},
   595  		{
   596  			desc: "base archive with init and user init",
   597  			opts: &Opts{
   598  				Files: &Files{
   599  					Records: map[string]cpio.Record{
   600  						"init": cpio.StaticFile("init", "bar", 0o444),
   601  					},
   602  				},
   603  			},
   604  			ma: &MockArchiver{
   605  				Records: make(Records),
   606  				BaseArchive: []cpio.Record{
   607  					cpio.StaticFile("init", "boo", 0o555),
   608  				},
   609  			},
   610  			want: Records{
   611  				"init":  cpio.StaticFile("init", "bar", 0o444),
   612  				"inito": cpio.StaticFile("inito", "boo", 0o555),
   613  			},
   614  		},
   615  		{
   616  			desc: "base archive with init, use existing init",
   617  			opts: &Opts{
   618  				Files: &Files{
   619  					Records: map[string]cpio.Record{},
   620  				},
   621  				UseExistingInit: true,
   622  			},
   623  			ma: &MockArchiver{
   624  				Records: make(Records),
   625  				BaseArchive: []cpio.Record{
   626  					cpio.StaticFile("init", "boo", 0o555),
   627  				},
   628  			},
   629  			want: Records{
   630  				"init": cpio.StaticFile("init", "boo", 0o555),
   631  			},
   632  		},
   633  		{
   634  			desc: "base archive with init and user init, use existing init",
   635  			opts: &Opts{
   636  				Files: &Files{
   637  					Records: map[string]cpio.Record{
   638  						"init": cpio.StaticFile("init", "huh", 0o111),
   639  					},
   640  				},
   641  				UseExistingInit: true,
   642  			},
   643  			ma: &MockArchiver{
   644  				Records: make(Records),
   645  				BaseArchive: []cpio.Record{
   646  					cpio.StaticFile("init", "boo", 0o555),
   647  				},
   648  			},
   649  			want: Records{
   650  				"init":  cpio.StaticFile("init", "boo", 0o555),
   651  				"inito": cpio.StaticFile("inito", "huh", 0o111),
   652  			},
   653  		},
   654  	} {
   655  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   656  			tt.opts.BaseArchive = tt.ma
   657  			tt.opts.OutputFile = tt.ma
   658  
   659  			if err := Write(tt.opts); err != tt.err {
   660  				t.Errorf("Write() = %v, want %v", err, tt.err)
   661  			} else if err == nil && !tt.ma.FinishCalled {
   662  				t.Errorf("Finish wasn't called on archive")
   663  			}
   664  
   665  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   666  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   667  			}
   668  		})
   669  	}
   670  }