gopkg.in/hugelgupf/u-root.v2@v2.0.0-20180831055005-3f8fdb0ce09d/pkg/uroot/initramfs/files_test.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package initramfs
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"os"
    13  	"path/filepath"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/u-root/u-root/pkg/cpio"
    19  	"golang.org/x/sys/unix"
    20  )
    21  
    22  func TestFilesAddFile(t *testing.T) {
    23  	for i, tt := range []struct {
    24  		af          Files
    25  		src         string
    26  		dest        string
    27  		result      Files
    28  		errContains string
    29  	}{
    30  		{
    31  			af:   NewFiles(),
    32  			src:  "/foo/bar",
    33  			dest: "bar/foo",
    34  
    35  			result: Files{
    36  				Files: map[string]string{
    37  					"bar/foo": "/foo/bar",
    38  				},
    39  				Records: map[string]cpio.Record{},
    40  			},
    41  		},
    42  		{
    43  			af: Files{
    44  				Files: map[string]string{
    45  					"bar/foo": "/some/other/place",
    46  				},
    47  			},
    48  			src:  "/foo/bar",
    49  			dest: "bar/foo",
    50  			result: Files{
    51  				Files: map[string]string{
    52  					"bar/foo": "/some/other/place",
    53  				},
    54  			},
    55  			errContains: "already exists in archive",
    56  		},
    57  		{
    58  			af: Files{
    59  				Records: map[string]cpio.Record{
    60  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
    61  				},
    62  			},
    63  			src:  "/foo/bar",
    64  			dest: "bar/foo",
    65  			result: Files{
    66  				Records: map[string]cpio.Record{
    67  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
    68  				},
    69  			},
    70  			errContains: "already exists in archive",
    71  		},
    72  		{
    73  			af: Files{
    74  				Files: map[string]string{
    75  					"bar/foo": "/foo/bar",
    76  				},
    77  			},
    78  			src:  "/foo/bar",
    79  			dest: "bar/foo",
    80  			result: Files{
    81  				Files: map[string]string{
    82  					"bar/foo": "/foo/bar",
    83  				},
    84  			},
    85  		},
    86  		{
    87  			src:         "/foo/bar",
    88  			dest:        "/bar/foo",
    89  			errContains: "must not be absolute",
    90  		},
    91  		{
    92  			src:         "foo/bar",
    93  			dest:        "bar/foo",
    94  			errContains: "must be absolute",
    95  		},
    96  	} {
    97  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
    98  			err := tt.af.AddFile(tt.src, tt.dest)
    99  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   100  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   101  			}
   102  			if err == nil && len(tt.errContains) > 0 {
   103  				t.Errorf("Got no error, want %v", tt.errContains)
   104  			}
   105  
   106  			if !reflect.DeepEqual(tt.af, tt.result) {
   107  				t.Errorf("got %v, want %v", tt.af, tt.result)
   108  			}
   109  		})
   110  	}
   111  }
   112  
   113  func TestFilesAddRecord(t *testing.T) {
   114  	for i, tt := range []struct {
   115  		af     Files
   116  		record cpio.Record
   117  
   118  		result      Files
   119  		errContains string
   120  	}{
   121  		{
   122  			af:     NewFiles(),
   123  			record: cpio.Symlink("bar/foo", ""),
   124  			result: Files{
   125  				Files: map[string]string{},
   126  				Records: map[string]cpio.Record{
   127  					"bar/foo": cpio.Symlink("bar/foo", ""),
   128  				},
   129  			},
   130  		},
   131  		{
   132  			af: Files{
   133  				Files: map[string]string{
   134  					"bar/foo": "/some/other/place",
   135  				},
   136  			},
   137  			record: cpio.Symlink("bar/foo", ""),
   138  			result: Files{
   139  				Files: map[string]string{
   140  					"bar/foo": "/some/other/place",
   141  				},
   142  			},
   143  			errContains: "already exists in archive",
   144  		},
   145  		{
   146  			af: Files{
   147  				Records: map[string]cpio.Record{
   148  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   149  				},
   150  			},
   151  			record: cpio.Symlink("bar/foo", ""),
   152  			result: Files{
   153  				Records: map[string]cpio.Record{
   154  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   155  				},
   156  			},
   157  			errContains: "already exists in archive",
   158  		},
   159  		{
   160  			af: Files{
   161  				Records: map[string]cpio.Record{
   162  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   163  				},
   164  			},
   165  			record: cpio.Symlink("bar/foo", "/some/other/place"),
   166  			result: Files{
   167  				Records: map[string]cpio.Record{
   168  					"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
   169  				},
   170  			},
   171  		},
   172  		{
   173  			record:      cpio.Symlink("/bar/foo", ""),
   174  			errContains: "must not be absolute",
   175  		},
   176  	} {
   177  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   178  			err := tt.af.AddRecord(tt.record)
   179  			if err != nil && !strings.Contains(err.Error(), tt.errContains) {
   180  				t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
   181  			}
   182  			if err == nil && len(tt.errContains) > 0 {
   183  				t.Errorf("Got no error, want %v", tt.errContains)
   184  			}
   185  
   186  			if !reflect.DeepEqual(tt.af, tt.result) {
   187  				t.Errorf("got %v, want %v", tt.af, tt.result)
   188  			}
   189  		})
   190  	}
   191  }
   192  
   193  func TestFilesfillInParent(t *testing.T) {
   194  	for i, tt := range []struct {
   195  		af     Files
   196  		result Files
   197  	}{
   198  		{
   199  			af: Files{
   200  				Records: map[string]cpio.Record{
   201  					"foo/bar": cpio.Directory("foo/bar", 0777),
   202  				},
   203  			},
   204  			result: Files{
   205  				Records: map[string]cpio.Record{
   206  					"foo/bar": cpio.Directory("foo/bar", 0777),
   207  					"foo":     cpio.Directory("foo", 0755),
   208  				},
   209  			},
   210  		},
   211  		{
   212  			af: Files{
   213  				Files: map[string]string{
   214  					"baz/baz/baz": "/somewhere",
   215  				},
   216  				Records: map[string]cpio.Record{
   217  					"foo/bar": cpio.Directory("foo/bar", 0777),
   218  				},
   219  			},
   220  			result: Files{
   221  				Files: map[string]string{
   222  					"baz/baz/baz": "/somewhere",
   223  				},
   224  				Records: map[string]cpio.Record{
   225  					"foo/bar": cpio.Directory("foo/bar", 0777),
   226  					"foo":     cpio.Directory("foo", 0755),
   227  					"baz":     cpio.Directory("baz", 0755),
   228  					"baz/baz": cpio.Directory("baz/baz", 0755),
   229  				},
   230  			},
   231  		},
   232  		{
   233  			af:     Files{},
   234  			result: Files{},
   235  		},
   236  	} {
   237  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   238  			tt.af.fillInParents()
   239  			if !reflect.DeepEqual(tt.af, tt.result) {
   240  				t.Errorf("got %v, want %v", tt.af, tt.result)
   241  			}
   242  		})
   243  	}
   244  }
   245  
   246  type MockArchiver struct {
   247  	Records      Records
   248  	FinishCalled bool
   249  	BaseArchive  []cpio.Record
   250  }
   251  
   252  func (ma *MockArchiver) WriteRecord(r cpio.Record) error {
   253  	if _, ok := ma.Records[r.Name]; ok {
   254  		return fmt.Errorf("file exists")
   255  	}
   256  	ma.Records[r.Name] = r
   257  	return nil
   258  }
   259  
   260  func (ma *MockArchiver) Finish() error {
   261  	ma.FinishCalled = true
   262  	return nil
   263  }
   264  
   265  func (ma *MockArchiver) ReadRecord() (cpio.Record, error) {
   266  	if len(ma.BaseArchive) > 0 {
   267  		next := ma.BaseArchive[0]
   268  		ma.BaseArchive = ma.BaseArchive[1:]
   269  		return next, nil
   270  	}
   271  	return cpio.Record{}, io.EOF
   272  }
   273  
   274  type Records map[string]cpio.Record
   275  
   276  func RecordsEqual(r1, r2 Records, recordEqual func(cpio.Record, cpio.Record) bool) bool {
   277  	for name, s1 := range r1 {
   278  		s2, ok := r2[name]
   279  		if !ok {
   280  			return false
   281  		}
   282  		if !recordEqual(s1, s2) {
   283  			return false
   284  		}
   285  	}
   286  	for name := range r2 {
   287  		if _, ok := r1[name]; !ok {
   288  			return false
   289  		}
   290  	}
   291  	return true
   292  }
   293  
   294  func sameNameModeContent(r1 cpio.Record, r2 cpio.Record) bool {
   295  	if r1.Name != r2.Name || r1.Mode != r2.Mode {
   296  		return false
   297  	}
   298  	return cpio.ReaderAtEqual(r1.ReaderAt, r2.ReaderAt)
   299  }
   300  
   301  func TestWriteFile(t *testing.T) {
   302  	unix.Umask(0)
   303  
   304  	for i, tt := range []struct {
   305  		ma   *MockArchiver
   306  		src  func() string
   307  		dest string
   308  		err  error
   309  		want Records
   310  	}{
   311  		{
   312  			ma: &MockArchiver{
   313  				Records: make(Records),
   314  			},
   315  			src: func() string {
   316  				f, err := ioutil.TempFile("", "foo")
   317  				if err != nil {
   318  					panic(err)
   319  				}
   320  				n := f.Name()
   321  				f.Close()
   322  				return n
   323  			},
   324  			dest: "foo/whatever",
   325  			want: Records{
   326  				"foo/whatever": cpio.Record{
   327  					Info: cpio.Info{
   328  						Name:  "foo/whatever",
   329  						Mode:  unix.S_IFREG | 0600,
   330  						UID:   uint64(os.Geteuid()),
   331  						GID:   uint64(os.Getegid()),
   332  						NLink: 1,
   333  						Major: 253,
   334  						Minor: 1,
   335  					},
   336  				},
   337  			},
   338  		},
   339  		{
   340  			ma: &MockArchiver{
   341  				Records: make(Records),
   342  			},
   343  			src: func() string {
   344  				f, err := ioutil.TempDir("", "foo")
   345  				if err != nil {
   346  					panic(err)
   347  				}
   348  				if err := ioutil.WriteFile(filepath.Join(f, "bla"), []byte("foo"), 0644); err != nil {
   349  					panic(err)
   350  				}
   351  				if err := ioutil.WriteFile(filepath.Join(f, "bla2"), []byte("foo2"), 0644); err != nil {
   352  					panic(err)
   353  				}
   354  				return f
   355  			},
   356  			dest: "etc",
   357  			want: Records{
   358  				"etc": cpio.Record{
   359  					Info: cpio.Info{
   360  						Name: "etc",
   361  						Mode: unix.S_IFDIR | 0700,
   362  					},
   363  				},
   364  				"etc/bla": cpio.Record{
   365  					Info: cpio.Info{
   366  						Name: "etc/bla",
   367  						Mode: unix.S_IFREG | 0644,
   368  					},
   369  					ReaderAt: bytes.NewReader([]byte("foo")),
   370  				},
   371  				"etc/bla2": cpio.Record{
   372  					Info: cpio.Info{
   373  						Name: "etc/bla2",
   374  						Mode: unix.S_IFREG | 0644,
   375  					},
   376  					ReaderAt: bytes.NewReader([]byte("foo2")),
   377  				},
   378  			},
   379  		},
   380  	} {
   381  		t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
   382  			src := tt.src()
   383  			defer os.RemoveAll(src)
   384  			if err := WriteFile(tt.ma, src, tt.dest); err != tt.err {
   385  				t.Errorf("WriteFile() = %v, want %v", err, tt.err)
   386  			}
   387  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   388  				t.Errorf("WriteFile() = %v, want %v", tt.ma.Records, tt.want)
   389  			}
   390  		})
   391  	}
   392  }
   393  
   394  func TestOptsWrite(t *testing.T) {
   395  	for i, tt := range []struct {
   396  		desc string
   397  		opts *Opts
   398  		ma   *MockArchiver
   399  		want Records
   400  		err  error
   401  	}{
   402  		{
   403  			desc: "no conflicts, just records",
   404  			opts: &Opts{
   405  				Files: Files{
   406  					Records: map[string]cpio.Record{
   407  						"foo": cpio.Symlink("foo", "elsewhere"),
   408  					},
   409  				},
   410  				DefaultRecords: []cpio.Record{
   411  					cpio.Directory("etc", 0777),
   412  					cpio.Directory("etc/nginx", 0777),
   413  				},
   414  			},
   415  			ma: &MockArchiver{
   416  				Records: make(Records),
   417  			},
   418  			want: Records{
   419  				"foo":       cpio.Symlink("foo", "elsewhere"),
   420  				"etc":       cpio.Directory("etc", 0777),
   421  				"etc/nginx": cpio.Directory("etc/nginx", 0777),
   422  			},
   423  		},
   424  		{
   425  			desc: "default already exists",
   426  			opts: &Opts{
   427  				Files: Files{
   428  					Records: map[string]cpio.Record{
   429  						"etc": cpio.Symlink("etc", "whatever"),
   430  					},
   431  				},
   432  				DefaultRecords: []cpio.Record{
   433  					cpio.Directory("etc", 0777),
   434  				},
   435  			},
   436  			ma: &MockArchiver{
   437  				Records: make(Records),
   438  			},
   439  			want: Records{
   440  				"etc": cpio.Symlink("etc", "whatever"),
   441  			},
   442  		},
   443  		{
   444  			desc: "no conflicts, missing parent automatically created",
   445  			opts: &Opts{
   446  				Files: Files{
   447  					Records: map[string]cpio.Record{
   448  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   449  					},
   450  				},
   451  			},
   452  			ma: &MockArchiver{
   453  				Records: make(Records),
   454  			},
   455  			want: Records{
   456  				"foo":         cpio.Directory("foo", 0755),
   457  				"foo/bar":     cpio.Directory("foo/bar", 0755),
   458  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   459  			},
   460  		},
   461  		{
   462  			desc: "parent only automatically created if not already exists",
   463  			opts: &Opts{
   464  				Files: Files{
   465  					Records: map[string]cpio.Record{
   466  						"foo/bar":     cpio.Directory("foo/bar", 0444),
   467  						"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   468  					},
   469  				},
   470  			},
   471  			ma: &MockArchiver{
   472  				Records: make(Records),
   473  			},
   474  			want: Records{
   475  				"foo":         cpio.Directory("foo", 0755),
   476  				"foo/bar":     cpio.Directory("foo/bar", 0444),
   477  				"foo/bar/baz": cpio.Symlink("foo/bar/baz", "elsewhere"),
   478  			},
   479  		},
   480  		{
   481  			desc: "base archive",
   482  			opts: &Opts{
   483  				Files: Files{
   484  					Records: map[string]cpio.Record{
   485  						"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   486  						"exists":  cpio.Directory("exists", 0777),
   487  					},
   488  				},
   489  			},
   490  			ma: &MockArchiver{
   491  				Records: make(Records),
   492  				BaseArchive: []cpio.Record{
   493  					cpio.Directory("etc", 0755),
   494  					cpio.Directory("foo", 0444),
   495  					cpio.Directory("exists", 0),
   496  				},
   497  			},
   498  			want: Records{
   499  				"etc":     cpio.Directory("etc", 0755),
   500  				"exists":  cpio.Directory("exists", 0777),
   501  				"foo":     cpio.Directory("foo", 0444),
   502  				"foo/bar": cpio.Symlink("foo/bar", "elsewhere"),
   503  			},
   504  		},
   505  		{
   506  			desc: "base archive with init, no user init",
   507  			opts: &Opts{
   508  				Files: Files{
   509  					Records: map[string]cpio.Record{},
   510  				},
   511  			},
   512  			ma: &MockArchiver{
   513  				Records: make(Records),
   514  				BaseArchive: []cpio.Record{
   515  					cpio.StaticFile("init", "boo", 0555),
   516  				},
   517  			},
   518  			want: Records{
   519  				"init": cpio.StaticFile("init", "boo", 0555),
   520  			},
   521  		},
   522  		{
   523  			desc: "base archive with init and user init",
   524  			opts: &Opts{
   525  				Files: Files{
   526  					Records: map[string]cpio.Record{
   527  						"init": cpio.StaticFile("init", "bar", 0444),
   528  					},
   529  				},
   530  			},
   531  			ma: &MockArchiver{
   532  				Records: make(Records),
   533  				BaseArchive: []cpio.Record{
   534  					cpio.StaticFile("init", "boo", 0555),
   535  				},
   536  			},
   537  			want: Records{
   538  				"init":  cpio.StaticFile("init", "bar", 0444),
   539  				"inito": cpio.StaticFile("inito", "boo", 0555),
   540  			},
   541  		},
   542  		{
   543  			desc: "base archive with init, use existing init",
   544  			opts: &Opts{
   545  				Files: Files{
   546  					Records: map[string]cpio.Record{},
   547  				},
   548  				UseExistingInit: true,
   549  			},
   550  			ma: &MockArchiver{
   551  				Records: make(Records),
   552  				BaseArchive: []cpio.Record{
   553  					cpio.StaticFile("init", "boo", 0555),
   554  				},
   555  			},
   556  			want: Records{
   557  				"init": cpio.StaticFile("init", "boo", 0555),
   558  			},
   559  		},
   560  		{
   561  			desc: "base archive with init and user init, use existing init",
   562  			opts: &Opts{
   563  				Files: Files{
   564  					Records: map[string]cpio.Record{
   565  						"init": cpio.StaticFile("init", "huh", 0111),
   566  					},
   567  				},
   568  				UseExistingInit: true,
   569  			},
   570  			ma: &MockArchiver{
   571  				Records: make(Records),
   572  				BaseArchive: []cpio.Record{
   573  					cpio.StaticFile("init", "boo", 0555),
   574  				},
   575  			},
   576  			want: Records{
   577  				"init":  cpio.StaticFile("init", "boo", 0555),
   578  				"inito": cpio.StaticFile("inito", "huh", 0111),
   579  			},
   580  		},
   581  	} {
   582  		t.Run(fmt.Sprintf("Test %02d (%s)", i, tt.desc), func(t *testing.T) {
   583  			tt.opts.BaseArchive = tt.ma
   584  			tt.opts.OutputFile = tt.ma
   585  
   586  			if err := Write(tt.opts); err != tt.err {
   587  				t.Errorf("Write() = %v, want %v", err, tt.err)
   588  			} else if err == nil && !tt.ma.FinishCalled {
   589  				t.Errorf("Finish wasn't called on archive")
   590  			}
   591  
   592  			if !RecordsEqual(tt.ma.Records, tt.want, sameNameModeContent) {
   593  				t.Errorf("Write() = %v, want %v", tt.ma.Records, tt.want)
   594  			}
   595  		})
   596  	}
   597  }