github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/internal/statemachine/mount_helper_test.go (about)

     1  package statemachine
     2  
     3  import (
     4  	"os"
     5  	"testing"
     6  
     7  	"github.com/google/go-cmp/cmp"
     8  
     9  	"github.com/canonical/ubuntu-image/internal/helper"
    10  )
    11  
    12  func Test_getMountCmd(t *testing.T) {
    13  	tests := []struct {
    14  		name           string
    15  		mp             mountPoint
    16  		wantMountCmds  []string
    17  		wantUmountCmds []string
    18  		expectedError  string
    19  	}{
    20  		{
    21  			name: "happy path",
    22  			mp: mountPoint{
    23  				src:      "src",
    24  				basePath: "targetDir",
    25  				relpath:  "mountpoint",
    26  				typ:      "devtmps",
    27  				opts:     []string{"nodev", "nosuid"},
    28  			},
    29  			wantMountCmds: []string{"/usr/bin/mount -t devtmps src -o nodev,nosuid targetDir/mountpoint"},
    30  			wantUmountCmds: []string{
    31  				"/usr/bin/mount --make-rprivate targetDir/mountpoint",
    32  				"/usr/bin/umount --recursive targetDir/mountpoint",
    33  			},
    34  		},
    35  		{
    36  			name: "no type",
    37  			mp: mountPoint{
    38  				src:      "src",
    39  				basePath: "targetDir",
    40  				relpath:  "mountpoint",
    41  				typ:      "",
    42  			},
    43  			wantMountCmds: []string{"/usr/bin/mount src targetDir/mountpoint"},
    44  			wantUmountCmds: []string{
    45  				"/usr/bin/mount --make-rprivate targetDir/mountpoint",
    46  				"/usr/bin/umount --recursive targetDir/mountpoint",
    47  			},
    48  		},
    49  		{
    50  			name: "bind mount",
    51  			mp: mountPoint{
    52  				src:      "src",
    53  				basePath: "targetDir",
    54  				relpath:  "mountpoint",
    55  				typ:      "",
    56  				bind:     true,
    57  			},
    58  			wantMountCmds: []string{"/usr/bin/mount --bind src targetDir/mountpoint"},
    59  			wantUmountCmds: []string{
    60  				"/usr/bin/mount --make-rprivate targetDir/mountpoint",
    61  				"/usr/bin/umount --recursive targetDir/mountpoint",
    62  			},
    63  		},
    64  		{
    65  			name: "no src",
    66  			mp: mountPoint{
    67  				src:      "",
    68  				basePath: "targetDir",
    69  				relpath:  "mountpoint",
    70  				typ:      "",
    71  				bind:     true,
    72  			},
    73  			wantMountCmds: []string{"/usr/bin/mount --bind  targetDir/mountpoint"},
    74  			wantUmountCmds: []string{
    75  				"/usr/bin/mount --make-rprivate targetDir/mountpoint",
    76  				"/usr/bin/umount --recursive targetDir/mountpoint",
    77  			},
    78  		},
    79  		{
    80  			name: "fail with bind and type",
    81  			mp: mountPoint{
    82  				src:      "src",
    83  				basePath: "targetDir",
    84  				relpath:  "mountpoint",
    85  				typ:      "devtmps",
    86  				bind:     true,
    87  			},
    88  			wantMountCmds:  []string{},
    89  			wantUmountCmds: []string{},
    90  			expectedError:  "invalid mount arguments. Cannot use --bind and -t at the same time.",
    91  		},
    92  	}
    93  	for _, tc := range tests {
    94  		t.Run(tc.name, func(t *testing.T) {
    95  			asserter := helper.Asserter{T: t}
    96  			gotMountCmds, gotUmountCmds, err := tc.mp.getMountCmd()
    97  
    98  			if len(tc.expectedError) == 0 {
    99  				asserter.AssertErrNil(err, true)
   100  			} else {
   101  				asserter.AssertErrContains(err, tc.expectedError)
   102  			}
   103  
   104  			gotMountCmdsStr := make([]string, 0)
   105  			gotUmountCmdsStr := make([]string, 0)
   106  
   107  			for _, c := range gotMountCmds {
   108  				gotMountCmdsStr = append(gotMountCmdsStr, c.String())
   109  			}
   110  
   111  			for _, c := range gotUmountCmds {
   112  				gotUmountCmdsStr = append(gotUmountCmdsStr, c.String())
   113  			}
   114  			asserter.AssertEqual(tc.wantMountCmds, gotMountCmdsStr)
   115  			asserter.AssertEqual(tc.wantUmountCmds, gotUmountCmdsStr)
   116  
   117  		})
   118  	}
   119  }
   120  
   121  func Test_getMountCmd_fail(t *testing.T) {
   122  	asserter := helper.Asserter{T: t}
   123  
   124  	// mock os.Mkdir
   125  	osMkdirAll = mockMkdirAll
   126  	t.Cleanup(func() {
   127  		osMkdirAll = os.MkdirAll
   128  	})
   129  
   130  	mp := mountPoint{
   131  		typ:      "devtmps",
   132  		basePath: "/tmp",
   133  		relpath:  "1234567",
   134  		src:      "src",
   135  	}
   136  
   137  	gotMountCmds, gotUmountCmds, err := mp.getMountCmd()
   138  	asserter.AssertErrContains(err, "Error creating mountpoint")
   139  	if gotMountCmds != nil {
   140  		asserter.Errorf("gotMountCmds should be nil but is %s", gotMountCmds)
   141  	}
   142  	if gotUmountCmds != nil {
   143  		asserter.Errorf("gotUmountCmds should be nil but is %s", gotUmountCmds)
   144  	}
   145  }
   146  
   147  var (
   148  	mp1 = mountPoint{
   149  		src:      "srcmp1",
   150  		path:     "src1basePath/src1relpath",
   151  		basePath: "src1basePath",
   152  		relpath:  "src1relpath",
   153  		typ:      "devtmpfs",
   154  	}
   155  	mp2 = mountPoint{
   156  		src:      "srcmp2",
   157  		path:     "src2basePath/src2relpath",
   158  		basePath: "src2basePath",
   159  		relpath:  "src2relpath",
   160  		typ:      "devpts",
   161  	}
   162  	mp3 = mountPoint{
   163  		src:      "srcmp3",
   164  		path:     "src3basePath/src3relpath",
   165  		basePath: "src3basePath",
   166  		relpath:  "src3relpath",
   167  		typ:      "proc",
   168  	}
   169  	mp4 = mountPoint{
   170  		src:      "srcmp4",
   171  		path:     "src4basePath/src4relpath",
   172  		basePath: "src4basePath",
   173  		relpath:  "src4relpath",
   174  		typ:      "cgroup2",
   175  	}
   176  	mp11 = mountPoint{
   177  		src:      "srcmp12",
   178  		path:     "src1basePath/src1relpath",
   179  		basePath: "src1basePath",
   180  		relpath:  "src1relpath",
   181  		typ:      "devtmpfs",
   182  	}
   183  	mp21 = mountPoint{
   184  		src:      "srcmp2",
   185  		path:     "",
   186  		basePath: "src21basePath",
   187  		relpath:  "src2relpath",
   188  		typ:      "devpts",
   189  	}
   190  	mp31 = mountPoint{
   191  		src:      "srcmp3",
   192  		path:     "",
   193  		basePath: "src3basePath",
   194  		relpath:  "src31relpath",
   195  		typ:      "proc",
   196  	}
   197  	mp41 = mountPoint{
   198  		src:      "srcmp4",
   199  		path:     "src4basePath/src4relpath",
   200  		basePath: "src4basePath",
   201  		relpath:  "src4relpath",
   202  		typ:      "anotherType",
   203  	}
   204  )
   205  
   206  func Test_diffMountPoints(t *testing.T) {
   207  	asserter := helper.Asserter{T: t}
   208  	type args struct {
   209  		olds     []*mountPoint
   210  		currents []*mountPoint
   211  	}
   212  
   213  	cmpOpts := []cmp.Option{
   214  		cmp.AllowUnexported(
   215  			mountPoint{},
   216  		),
   217  	}
   218  
   219  	tests := []struct {
   220  		name string
   221  		args args
   222  		want []*mountPoint
   223  	}{
   224  		{
   225  			name: "same mounpoints, ignoring list order",
   226  			args: args{
   227  				olds: []*mountPoint{
   228  					&mp1,
   229  					&mp2,
   230  					&mp3,
   231  					&mp4,
   232  				},
   233  				currents: []*mountPoint{
   234  					&mp4,
   235  					&mp1,
   236  					&mp3,
   237  					&mp2,
   238  				},
   239  			},
   240  			want: nil,
   241  		},
   242  		{
   243  			name: "add some",
   244  			args: args{
   245  				olds: []*mountPoint{
   246  					&mp1,
   247  					&mp2,
   248  				},
   249  				currents: []*mountPoint{
   250  					&mp3,
   251  					&mp4,
   252  				},
   253  			},
   254  			want: []*mountPoint{
   255  				&mp3,
   256  				&mp4,
   257  			},
   258  		},
   259  		{
   260  			name: "no old ones",
   261  			args: args{
   262  				olds: nil,
   263  				currents: []*mountPoint{
   264  					&mp3,
   265  					&mp4,
   266  				},
   267  			},
   268  			want: []*mountPoint{
   269  				&mp3,
   270  				&mp4,
   271  			},
   272  		},
   273  		{
   274  			name: "no current ones",
   275  			args: args{
   276  				olds: []*mountPoint{
   277  					&mp1,
   278  					&mp2,
   279  				},
   280  				currents: nil,
   281  			},
   282  			want: nil,
   283  		},
   284  		{
   285  			name: "difference in src, relpath, basepath and typ",
   286  			args: args{
   287  				olds: []*mountPoint{
   288  					&mp1,
   289  					&mp2,
   290  					&mp3,
   291  					&mp4,
   292  				},
   293  				currents: []*mountPoint{
   294  					&mp11,
   295  					&mp21,
   296  					&mp31,
   297  					&mp41,
   298  				},
   299  			},
   300  			want: []*mountPoint{
   301  				&mp11,
   302  				&mp21,
   303  				&mp31,
   304  				&mp41,
   305  			},
   306  		},
   307  	}
   308  	for _, tt := range tests {
   309  		t.Run(tt.name, func(t *testing.T) {
   310  			got := diffMountPoints(tt.args.olds, tt.args.currents)
   311  			asserter.AssertEqual(tt.want, got, cmpOpts...)
   312  		})
   313  	}
   314  }