github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/cmd/ubuntu-image/main_test.go (about)

     1  package main
     2  
     3  import (
     4  	"errors"
     5  	"flag"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/google/go-cmp/cmp"
    11  	"github.com/google/go-cmp/cmp/cmpopts"
    12  	"github.com/jessevdk/go-flags"
    13  	"github.com/snapcore/snapd/gadget"
    14  
    15  	"github.com/canonical/ubuntu-image/internal/commands"
    16  	"github.com/canonical/ubuntu-image/internal/helper"
    17  	"github.com/canonical/ubuntu-image/internal/statemachine"
    18  	"github.com/canonical/ubuntu-image/internal/testhelper"
    19  )
    20  
    21  var (
    22  	ErrAtSetup    = errors.New("Fail at Setup")
    23  	ErrAtRun      = errors.New("Fail at Run")
    24  	ErrAtTeardown = errors.New("Fail at Teardown")
    25  )
    26  
    27  type mockedStateMachine struct {
    28  	whenToFail string
    29  }
    30  
    31  func (mockSM *mockedStateMachine) Setup() error {
    32  	if mockSM.whenToFail == "Setup" {
    33  		return ErrAtSetup
    34  	}
    35  	return nil
    36  }
    37  
    38  func (mockSM *mockedStateMachine) Run() error {
    39  	if mockSM.whenToFail == "Run" {
    40  		return ErrAtRun
    41  	}
    42  	return nil
    43  }
    44  
    45  func (mockSM *mockedStateMachine) Teardown() error {
    46  	if mockSM.whenToFail == "Teardown" {
    47  		return ErrAtTeardown
    48  	}
    49  	return nil
    50  }
    51  
    52  func (mockSM *mockedStateMachine) SetCommonOpts(commonOpts *commands.CommonOpts, stateMachineOpts *commands.StateMachineOpts) {
    53  }
    54  
    55  // TestValidCommands tests that certain valid commands are parsed correctly
    56  func TestValidCommands(t *testing.T) {
    57  	t.Parallel()
    58  	testCases := []struct {
    59  		name    string
    60  		command string
    61  		flags   []string
    62  		field   func(*commands.UbuntuImageCommand) string
    63  		want    string
    64  	}{
    65  		{
    66  			name:    "valid_snap_command",
    67  			command: "snap",
    68  			flags:   []string{"model_assertion.yml"},
    69  			field: func(u *commands.UbuntuImageCommand) string {
    70  				return u.Snap.SnapArgsPassed.ModelAssertion
    71  			},
    72  			want: "model_assertion.yml",
    73  		},
    74  		{
    75  			name:    "valid_classic_command",
    76  			command: "classic",
    77  			flags:   []string{"image_defintion.yml"},
    78  			field: func(u *commands.UbuntuImageCommand) string {
    79  				return u.Classic.ClassicArgsPassed.ImageDefinition
    80  			},
    81  			want: "image_defintion.yml",
    82  		},
    83  		{
    84  			name:    "valid_pack_command",
    85  			command: "pack",
    86  			flags:   []string{"--artifact-type", "raw", "--gadget-dir", "./test-gadget-dir", "--rootfs-dir", "./test"},
    87  			field: func(u *commands.UbuntuImageCommand) string {
    88  				return u.Pack.PackOptsPassed.GadgetDir
    89  			},
    90  			want: "./test-gadget-dir",
    91  		},
    92  	}
    93  	for _, tc := range testCases {
    94  		t.Run(tc.name, func(t *testing.T) {
    95  			var args []string
    96  			if tc.command != "" {
    97  				args = append(args, tc.command)
    98  			}
    99  			if tc.flags != nil {
   100  				args = append(args, tc.flags...)
   101  			}
   102  
   103  			ubuntuImageCommand := &commands.UbuntuImageCommand{}
   104  			_, err := flags.ParseArgs(ubuntuImageCommand, args)
   105  			if err != nil {
   106  				t.Error("Did not expect an error but got", err)
   107  			}
   108  
   109  			got := tc.field(ubuntuImageCommand)
   110  			if tc.want != got {
   111  				t.Errorf("Unexpected parsed value \"%s\". Expected \"%s\"",
   112  					got, tc.want)
   113  			}
   114  		})
   115  	}
   116  }
   117  
   118  // TestInvalidCommands tests invalid commands argument/flag combinations
   119  func TestInvalidCommands(t *testing.T) {
   120  	t.Parallel()
   121  	testCases := []struct {
   122  		name          string
   123  		command       []string
   124  		flags         []string
   125  		expectedError string
   126  	}{
   127  		{"invalid_command", []string{"test"}, nil, "Unknown command `test'. Please specify one command of: classic or snap"},
   128  		{"no_model_assertion", []string{"snap"}, nil, "the required argument `model_assertion` was not provided"},
   129  		{"no_gadget_tree", []string{"classic"}, nil, "the required argument `image_definition` was not provided"},
   130  		{"invalid_flag", []string{"classic"}, []string{"--nonexistent"}, "unknown flag `nonexistent'"},
   131  		{"invalid_validation", []string{"snap"}, []string{"--validation=test"}, "unknown flag `validation'"},
   132  		{"invalid_sector_size", []string{"snap"}, []string{"--sector_size=123"}, "unknown flag `sector_size'"},
   133  		{"missing_one_flag", []string{"pack"}, []string{"--artifact-type=raw"}, "the required flags `--gadget-dir' and `--rootfs-dir' were not specified"},
   134  		{"missing_flags", []string{"pack"}, []string{"--artifact-type=raw", "--gadget-dir=./test"}, "the required flag `--rootfs-dir' was not specified"},
   135  	}
   136  	for _, tc := range testCases {
   137  		tc := tc // capture range variable for parallel execution
   138  		t.Run(tc.name, func(t *testing.T) {
   139  			asserter := helper.Asserter{T: t}
   140  
   141  			var args []string
   142  			if tc.command != nil {
   143  				args = append(args, tc.command...)
   144  			}
   145  			if tc.flags != nil {
   146  				args = append(args, tc.flags...)
   147  			}
   148  
   149  			// finally, execute the command and check output
   150  			ubuntuImageCommand := &commands.UbuntuImageCommand{}
   151  			_, gotErr := flags.ParseArgs(ubuntuImageCommand, args)
   152  			asserter.AssertErrContains(gotErr, tc.expectedError)
   153  
   154  		})
   155  	}
   156  }
   157  
   158  // TestExit code runs a number of commands, both valid and invalid, and ensures that the
   159  // program exits with the correct exit code
   160  func TestExitCode(t *testing.T) {
   161  	testCases := []struct {
   162  		name     string
   163  		flags    []string
   164  		expected int
   165  	}{
   166  		{"help_exit_0", []string{"--help"}, 0},
   167  		{"invalid_flag_exit_1", []string{"--help-me"}, 1},
   168  		{"bad_state_machine_args_classic", []string{"classic", "gadget_tree.yaml", "-u", "5", "-t", "6"}, 1},
   169  		{"bad_state_machine_args_snap", []string{"snap", "model_assertion.yaml", "-u", "5", "-t", "6"}, 1},
   170  		{"bad_state_machine_args_pack", []string{"pack", "--artifact-type", "raw", "--gadget-dir", "./test-gadget-dir", "--rootfs-dir", "./test", "-u", "5", "-t", "6"}, 1},
   171  		{"no_command_given", []string{}, 1},
   172  		{"resume_without_workdir", []string{"--resume"}, 1},
   173  		{"invalid_sector_size", []string{"--sector-size", "128", "--help"}, 1}, // Cheap trick with the --help to make the test work
   174  	}
   175  	for _, tc := range testCases {
   176  		t.Run(tc.name, func(t *testing.T) {
   177  			restoreCWD := testhelper.SaveCWD()
   178  			defer restoreCWD()
   179  			// Override os.Exit temporarily
   180  			oldOsExit := osExit
   181  			t.Cleanup(func() {
   182  				osExit = oldOsExit
   183  			})
   184  
   185  			var got int
   186  			tmpExit := func(code int) {
   187  				got = code
   188  			}
   189  
   190  			osExit = tmpExit
   191  
   192  			// set up the flags for the test cases
   193  			flag.CommandLine = flag.NewFlagSet(tc.name, flag.ExitOnError)
   194  			os.Args = append([]string{tc.name}, tc.flags...)
   195  
   196  			// os.Exit will be captured. Run the command with no flags to trigger an error
   197  			main()
   198  			if got != tc.expected {
   199  				t.Errorf("Expected exit code: %d, got: %d", tc.expected, got)
   200  			}
   201  			os.RemoveAll("/tmp/ubuntu-image-0615c8dd-d3af-4074-bfcb-c3d3c8392b06")
   202  		})
   203  	}
   204  }
   205  
   206  // TestVersion code runs ubuntu-image --version and checks if the resulting
   207  // version makes sense
   208  func TestVersion(t *testing.T) {
   209  	testCases := []struct {
   210  		name      string
   211  		hardcoded string
   212  		snapEnv   string
   213  		expected  string
   214  	}{
   215  		{"hardcoded_version", "2.0ubuntu1", "", "2.0ubuntu1"},
   216  		{"snap_version", "", "2.0+snap1", "2.0+snap1"},
   217  		{"both_hardcoded_and_snap", "2.0ubuntu1", "2.0+snap1", "2.0ubuntu1"},
   218  	}
   219  	for _, tc := range testCases {
   220  		t.Run(tc.name, func(t *testing.T) {
   221  			restoreCWD := testhelper.SaveCWD()
   222  			defer restoreCWD()
   223  			// Override os.Exit temporarily
   224  			oldOsExit := osExit
   225  			defer func() {
   226  				osExit = oldOsExit
   227  			}()
   228  
   229  			var got int
   230  			tmpExit := func(code int) {
   231  				got = code
   232  			}
   233  			osExit = tmpExit
   234  
   235  			// set up the flags for the test cases
   236  			flag.CommandLine = flag.NewFlagSet(tc.name, flag.ExitOnError)
   237  			os.Args = append([]string{tc.name}, "--version")
   238  
   239  			// pre-set the test-case environment
   240  			Version = tc.hardcoded
   241  			os.Setenv("SNAP_VERSION", tc.snapEnv)
   242  
   243  			main()
   244  			if got != 0 {
   245  				t.Errorf("Expected exit code: 0, got: %d", got)
   246  			}
   247  			os.Unsetenv("SNAP_VERSION")
   248  
   249  			// since we're printing the Version variable, no need to capture
   250  			// and analyze the output
   251  			if Version != tc.expected {
   252  				t.Errorf("Expected version string: '%s', got: '%s'", tc.expected, Version)
   253  			}
   254  		})
   255  	}
   256  }
   257  
   258  // TestFailedStdoutStderrCapture tests that scenarios involving failed stdout
   259  // and stderr captures and reads fail gracefully
   260  func TestFailedStdoutStderrCapture(t *testing.T) {
   261  	testCases := []struct {
   262  		name     string
   263  		stdCap   *os.File
   264  		readFrom *os.File
   265  		flags    []string
   266  	}{
   267  		{"error_capture_stdout", os.Stdout, os.Stdout, []string{}},
   268  		{"error_capture_stderr", os.Stderr, os.Stderr, []string{}},
   269  		{"error_read_stdout", os.Stdout, nil, []string{"--help"}},
   270  		{"error_read_stderr", os.Stderr, nil, []string{}},
   271  	}
   272  	for _, tc := range testCases {
   273  		t.Run(tc.name, func(t *testing.T) {
   274  			// Override os.Exit temporarily
   275  			oldOsExit := osExit
   276  			defer func() {
   277  				osExit = oldOsExit
   278  			}()
   279  
   280  			var got int
   281  			tmpExit := func(code int) {
   282  				got = code
   283  			}
   284  
   285  			osExit = tmpExit
   286  
   287  			// os.Exit will be captured. set the captureStd function
   288  			captureStd = func(toCap **os.File) (io.Reader, func(), error) {
   289  				var err error
   290  				if *toCap == tc.readFrom {
   291  					err = errors.New("Testing Error")
   292  				} else {
   293  					err = nil
   294  				}
   295  				return tc.readFrom, func() {}, err
   296  			}
   297  
   298  			// set up the flags for the test cases
   299  			flag.CommandLine = flag.NewFlagSet(tc.name, flag.ExitOnError)
   300  			os.Args = append([]string{tc.name}, tc.flags...)
   301  
   302  			// run main and check the exit code
   303  			main()
   304  			if got != 1 {
   305  				t.Errorf("Expected error code on exit, got: %d", got)
   306  			}
   307  
   308  		})
   309  	}
   310  }
   311  
   312  // TestExecuteStateMachine tests fails for all implemented functions to ensure
   313  // that main fails gracefully
   314  func TestExecuteStateMachine(t *testing.T) {
   315  	testCases := []struct {
   316  		name          string
   317  		whenToFail    string
   318  		expectedError string
   319  	}{
   320  		{
   321  			name:          "error_statemachine_setup",
   322  			whenToFail:    "Setup",
   323  			expectedError: ErrAtSetup.Error(),
   324  		},
   325  		{
   326  			name:          "error_statemachine_run",
   327  			whenToFail:    "Run",
   328  			expectedError: ErrAtRun.Error(),
   329  		},
   330  		{
   331  			name:          "error_statemachine_teardown",
   332  			whenToFail:    "Teardown",
   333  			expectedError: ErrAtTeardown.Error(),
   334  		},
   335  	}
   336  	for _, tc := range testCases {
   337  		t.Run(tc.name, func(t *testing.T) {
   338  			asserter := helper.Asserter{T: t}
   339  
   340  			flags := []string{"snap", "model_assertion"}
   341  			// set up the flags for the test cases
   342  			flag.CommandLine = flag.NewFlagSet("failed_state_machine", flag.ExitOnError)
   343  			os.Args = flags
   344  
   345  			gotErr := executeStateMachine(&mockedStateMachine{
   346  				whenToFail: tc.whenToFail,
   347  			})
   348  			asserter.AssertErrContains(gotErr, tc.expectedError)
   349  		})
   350  	}
   351  }
   352  
   353  func Test_initStateMachine(t *testing.T) {
   354  	asserter := helper.Asserter{T: t}
   355  	type args struct {
   356  		imageType          string
   357  		commonOpts         *commands.CommonOpts
   358  		stateMachineOpts   *commands.StateMachineOpts
   359  		ubuntuImageCommand *commands.UbuntuImageCommand
   360  	}
   361  
   362  	cmpOpts := []cmp.Option{
   363  		cmpopts.IgnoreUnexported(
   364  			statemachine.SnapStateMachine{},
   365  			statemachine.StateMachine{},
   366  			gadget.Info{},
   367  		),
   368  	}
   369  
   370  	tests := []struct {
   371  		name        string
   372  		args        args
   373  		want        statemachine.SmInterface
   374  		expectedErr string
   375  	}{
   376  		{
   377  			name: "init a snap state machine",
   378  			args: args{
   379  				imageType:        "snap",
   380  				commonOpts:       &commands.CommonOpts{},
   381  				stateMachineOpts: &commands.StateMachineOpts{},
   382  				ubuntuImageCommand: &commands.UbuntuImageCommand{
   383  					Snap: commands.SnapCommand{
   384  						SnapOptsPassed: commands.SnapOpts{},
   385  						SnapArgsPassed: commands.SnapArgs{},
   386  					},
   387  				},
   388  			},
   389  			want: &statemachine.SnapStateMachine{
   390  				StateMachine: statemachine.StateMachine{},
   391  				Opts:         commands.SnapOpts{},
   392  				Args:         commands.SnapArgs{},
   393  			},
   394  		},
   395  		{
   396  			name: "init a classic state machine",
   397  			args: args{
   398  				imageType:        "classic",
   399  				commonOpts:       &commands.CommonOpts{},
   400  				stateMachineOpts: &commands.StateMachineOpts{},
   401  				ubuntuImageCommand: &commands.UbuntuImageCommand{
   402  					Classic: commands.ClassicCommand{
   403  						ClassicArgsPassed: commands.ClassicArgs{},
   404  					},
   405  				},
   406  			},
   407  			want: &statemachine.ClassicStateMachine{
   408  				Args: commands.ClassicArgs{},
   409  			},
   410  		},
   411  		{
   412  			name: "fail to init an unknown statemachine",
   413  			args: args{
   414  				imageType:          "unknown",
   415  				commonOpts:         &commands.CommonOpts{},
   416  				stateMachineOpts:   &commands.StateMachineOpts{},
   417  				ubuntuImageCommand: &commands.UbuntuImageCommand{},
   418  			},
   419  			want:        nil,
   420  			expectedErr: "unsupported command",
   421  		},
   422  	}
   423  	for _, tc := range tests {
   424  		t.Run(tc.name, func(t *testing.T) {
   425  			got, err := initStateMachine(tc.args.imageType, tc.args.commonOpts, tc.args.stateMachineOpts, tc.args.ubuntuImageCommand)
   426  
   427  			if err != nil || len(tc.expectedErr) != 0 {
   428  				asserter.AssertErrContains(err, tc.expectedErr)
   429  			}
   430  
   431  			asserter.AssertEqual(tc.want, got, cmpOpts...)
   432  
   433  		})
   434  	}
   435  }