github.com/GoogleCloudPlatform/compute-image-tools/cli_tools@v0.0.0-20240516224744-de2dabc4ed1b/gce_windows_upgrade/upgrader/validators_test.go (about)

     1  //  Copyright 2020 Google Inc. All Rights Reserved.
     2  //
     3  //  Licensed under the Apache License, Version 2.0 (the "License");
     4  //  you may not use this file except in compliance with the License.
     5  //  You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //  Unless required by applicable law or agreed to in writing, software
    10  //  distributed under the License is distributed on an "AS IS" BASIS,
    11  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //  See the License for the specific language governing permissions and
    13  //  limitations under the License.
    14  
    15  package upgrader
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"testing"
    21  
    22  	"github.com/golang/mock/gomock"
    23  	"github.com/stretchr/testify/assert"
    24  	"google.golang.org/api/compute/v1"
    25  
    26  	"github.com/GoogleCloudPlatform/compute-image-tools/cli_tools/common/domain"
    27  	"github.com/GoogleCloudPlatform/compute-image-tools/cli_tools/common/utils/daisyutils"
    28  	"github.com/GoogleCloudPlatform/compute-image-tools/mocks"
    29  )
    30  
    31  func init() {
    32  	initTest()
    33  }
    34  
    35  func TestValidateParams(t *testing.T) {
    36  	type testCase struct {
    37  		testName        string
    38  		u               *upgrader
    39  		expectedError   string
    40  		expectedTimeout string
    41  	}
    42  
    43  	var u *upgrader
    44  	var tcs []testCase
    45  
    46  	tcs = append(tcs, testCase{"Normal case", newTestUpgrader().upgrader, "", DefaultTimeout})
    47  
    48  	u = newTestUpgrader().upgrader
    49  	u.ClientID = ""
    50  	tcs = append(tcs, testCase{"clientID is optional", u, "", DefaultTimeout})
    51  
    52  	u = newTestUpgrader().upgrader
    53  	u.SourceOS = "android"
    54  	tcs = append(tcs, testCase{"validateOSVersion failure", u,
    55  		fmt.Sprintf("Flag -source-os value 'android' unsupported. Please choose a supported version from {%v}.", strings.Join(SupportedVersions, ", ")), DefaultTimeout})
    56  
    57  	u = newTestUpgrader().upgrader
    58  	u.Instance = "bad/url"
    59  	tcs = append(tcs, testCase{"validateAndDeriveInstanceURI failure", u,
    60  		"Please provide the instance flag either with the name of the instance or in the form of 'projects/<project>/zones/<zone>/instances/<instance>', not bad/url", DefaultTimeout})
    61  
    62  	u = newTestUpgrader().upgrader
    63  	u.Instance = daisyutils.GetInstanceURI(testProject, testZone, testInstanceNoLicense)
    64  	tcs = append(tcs, testCase{"validateAndDeriveInstance failure", u,
    65  		"No valid Windows Server PayG license can be found. Any of the following licenses are required: [projects/windows-cloud/global/licenses/windows-server-2008-r2-dc]", DefaultTimeout})
    66  
    67  	u = newTestUpgrader().upgrader
    68  	u.Timeout = "1m"
    69  	tcs = append(tcs, testCase{"override timeout", u, "", "1m"})
    70  
    71  	for _, tc := range tcs {
    72  		u = tc.u
    73  		err := u.validateAndDeriveParams()
    74  		if tc.expectedError != "" {
    75  			assert.EqualErrorf(t, err, tc.expectedError, "[test name: %v] Unexpected error.", tc.testName)
    76  		} else {
    77  			assert.NoError(t, err, "[test name: %v] Unexpected error.", tc.testName)
    78  		}
    79  		if err != nil {
    80  			continue
    81  		}
    82  
    83  		assert.Equalf(t, tc.expectedTimeout, u.Timeout, "[test name: %v] Unexpected Timeout.", tc.testName)
    84  		assert.NotEmptyf(t, u.machineImageBackupName, "[test name: %v] Unexpected machineImageBackupName.", tc.testName)
    85  		assert.NotEmptyf(t, u.osDiskSnapshotName, "[test name: %v] Unexpected osDiskSnapshotName.", tc.testName)
    86  		assert.NotEmptyf(t, u.newOSDiskName, "[test name: %v] Unexpected newOSDiskName.", tc.testName)
    87  		assert.NotEmptyf(t, u.installMediaDiskName, "[test name: %v] Unexpected installMediaDiskName.", tc.testName)
    88  		assert.Equalf(t, testProject, *u.ProjectPtr, "[test name: %v] Unexpected ProjectPtr value.", tc.testName)
    89  	}
    90  }
    91  
    92  func TestValidateOSVersion(t *testing.T) {
    93  	type testCase struct {
    94  		testName      string
    95  		sourceOS      string
    96  		targetOS      string
    97  		expectedError string
    98  	}
    99  
   100  	tcs := []testCase{
   101  		{"Unsupported source OS", "windows-2008", "windows-2008r2", fmt.Sprintf("Flag -source-os value 'windows-2008' unsupported. Please choose a supported version from {%v}.", strings.Join(SupportedVersions, ", "))},
   102  		{"Unsupported target OS", "windows-2008r2", "windows-2012", fmt.Sprintf("Flag -target-os value 'windows-2012' unsupported. Please choose a supported version from {%v}.", strings.Join(SupportedVersions, ", "))},
   103  		{"Source OS not provided", "", versionWindows2012r2, fmt.Sprintf("Flag -source-os must be provided. Please choose a supported version from {%v}.", strings.Join(SupportedVersions, ", "))},
   104  		{"Target OS not provided", versionWindows2008r2, "", fmt.Sprintf("Flag -target-os must be provided. Please choose a supported version from {%v}.", strings.Join(SupportedVersions, ", "))},
   105  	}
   106  	for _, supportedSourceOS := range SupportedVersions {
   107  		for _, supportedTargetOS := range SupportedVersions {
   108  			expectedError := ""
   109  			if !isSupportedUpgradePath(supportedSourceOS, supportedTargetOS) {
   110  				expectedError = "Can't upgrade from"
   111  			}
   112  			tcs = append(tcs, testCase{
   113  				fmt.Sprintf("From %v to %v", supportedSourceOS, supportedTargetOS),
   114  				supportedSourceOS,
   115  				supportedTargetOS,
   116  				expectedError,
   117  			})
   118  		}
   119  	}
   120  
   121  	for _, tc := range tcs {
   122  		err := validateOSVersion(tc.sourceOS, tc.targetOS)
   123  		if tc.expectedError != "" {
   124  			assert.Truef(t, strings.HasPrefix(err.Error(), tc.expectedError), "[test name: %v] expected error: %v, actual: %v", tc.testName, tc.expectedError, err)
   125  		} else {
   126  			assert.NoError(t, err, "[test name: %v]", tc.testName)
   127  		}
   128  	}
   129  }
   130  
   131  func TestValidateInstance(t *testing.T) {
   132  	type testCase struct {
   133  		testName             string
   134  		instance             string
   135  		expectedURIError     string
   136  		expectedError        string
   137  		inputProject         string
   138  		inputZone            string
   139  		expectedProject      string
   140  		expectedZone         string
   141  		expectedInstanceName string
   142  		mgce                 domain.MetadataGCEInterface
   143  	}
   144  
   145  	mockCtrl := gomock.NewController(t)
   146  	defer mockCtrl.Finish()
   147  	mockMetadataGce := mocks.NewMockMetadataGCEInterface(mockCtrl)
   148  	mockMetadataGce.EXPECT().OnGCE().Return(true)
   149  	mockMetadataGce.EXPECT().ProjectID().Return(testProject2, nil)
   150  
   151  	mockMetadataGceFail := mocks.NewMockMetadataGCEInterface(mockCtrl)
   152  	mockMetadataGceFail.EXPECT().OnGCE().Return(false)
   153  
   154  	tcs := []testCase{
   155  		{
   156  			"Normal case without original startup script",
   157  			daisyutils.GetInstanceURI(testProject, testZone, testInstance),
   158  			"",
   159  			"",
   160  			"",
   161  			"",
   162  			testProject, testZone, testInstance,
   163  			mockMetadataGce,
   164  		},
   165  		{
   166  			"Normal case with original startup script",
   167  			daisyutils.GetInstanceURI(testProject, testZone, testInstanceWithStartupScript),
   168  			"",
   169  			"",
   170  			"",
   171  			"",
   172  			testProject, testZone, testInstanceWithStartupScript,
   173  			mockMetadataGce,
   174  		},
   175  		{
   176  			"Normal case with existing startup script backup",
   177  			daisyutils.GetInstanceURI(testProject, testZone, testInstanceWithExistingStartupScriptBackup),
   178  			"",
   179  			"",
   180  			"",
   181  			"",
   182  			testProject, testZone, testInstanceWithExistingStartupScriptBackup,
   183  			mockMetadataGce,
   184  		},
   185  		{
   186  			"No disk error",
   187  			daisyutils.GetInstanceURI(testProject, testZone, testInstanceNoDisk),
   188  			"",
   189  			"No disks attached to the instance.",
   190  			"",
   191  			"",
   192  			testProject, testZone, testInstanceNoDisk,
   193  			mockMetadataGce,
   194  		},
   195  		{
   196  			"License error",
   197  			daisyutils.GetInstanceURI(testProject, testZone, testInstanceNoLicense),
   198  			"",
   199  			"No valid Windows Server PayG license can be found. Any of the following licenses are required: [projects/windows-cloud/global/licenses/windows-server-2008-r2-dc]",
   200  			"",
   201  			"",
   202  			testProject, testZone, testInstanceNoLicense,
   203  			mockMetadataGce,
   204  		},
   205  		{
   206  			"OS disk error",
   207  			daisyutils.GetInstanceURI(testProject, testZone, testInstanceNoBootDisk),
   208  			"",
   209  			"The instance has no boot disk.",
   210  			"",
   211  			"",
   212  			testProject, testZone, testInstanceNoBootDisk,
   213  			mockMetadataGce,
   214  		},
   215  		{
   216  			"Instance doesn't exist",
   217  			daisyutils.GetInstanceURI(testProject, testZone, DNE),
   218  			"",
   219  			"Failed to get instance: googleapi: got HTTP response code 404 with body: ",
   220  			"",
   221  			"",
   222  			testProject, testZone, DNE,
   223  			mockMetadataGce,
   224  		},
   225  		{
   226  			"Bad instance flag error",
   227  			"bad/url",
   228  			"Please provide the instance flag either with the name of the instance or in the form of 'projects/<project>/zones/<zone>/instances/<instance>', not bad/url",
   229  			"",
   230  			"",
   231  			"",
   232  			testProject, testZone, "bad/url",
   233  			mockMetadataGce,
   234  		},
   235  		{
   236  			"No instance flag",
   237  			"",
   238  			"Flag -instance must be provided",
   239  			"",
   240  			"",
   241  			"",
   242  			testProject, testZone, "",
   243  			mockMetadataGce,
   244  		},
   245  		{
   246  			"Instance name without project",
   247  			testInstance,
   248  			"project cannot be determined because build is not running on GCE",
   249  			"",
   250  			"",
   251  			testZone2,
   252  			"", testZone2, testInstance,
   253  			mockMetadataGceFail,
   254  		},
   255  		{
   256  			"Instance name with fallback project (on GCE)",
   257  			testInstance,
   258  			"",
   259  			"",
   260  			"",
   261  			testZone2,
   262  			testProject2, testZone2, testInstance,
   263  			mockMetadataGce,
   264  		},
   265  		{
   266  			"Instance name without input zone",
   267  			testInstance,
   268  			"--zone must be provided when --instance is not a URI with zone info.",
   269  			"",
   270  			testProject2,
   271  			"",
   272  			testProject2, testZone2, testInstance,
   273  			mockMetadataGce,
   274  		},
   275  		{
   276  			"Instance name with input project and zone",
   277  			testInstance,
   278  			"",
   279  			"",
   280  			testProject2,
   281  			testZone2,
   282  			testProject2, testZone2, testInstance,
   283  			mockMetadataGce,
   284  		},
   285  		{
   286  			"Override input project and zone",
   287  			daisyutils.GetInstanceURI(testProject, testZone, testInstance),
   288  			"",
   289  			"",
   290  			testProject2,
   291  			testZone2,
   292  			testProject, testZone, testInstance,
   293  			mockMetadataGce,
   294  		},
   295  	}
   296  
   297  	originalMGCE := mgce
   298  	defer func() {
   299  		mgce = originalMGCE
   300  	}()
   301  
   302  	for _, tc := range tcs {
   303  		derivedVars := derivedVars{}
   304  		mgce = tc.mgce
   305  
   306  		err := validateAndDeriveInstanceURI(tc.instance, &tc.inputProject, tc.inputZone, &derivedVars)
   307  		if tc.expectedURIError != "" {
   308  			assert.EqualErrorf(t, err, tc.expectedURIError, "[test name: %v] Unexpected error from validateAndDeriveInstanceURI.", tc.testName)
   309  			continue
   310  		} else {
   311  			assert.NoErrorf(t, err, "[test name: %v] Unexpected error from validateAndDeriveInstanceURI.", tc.testName)
   312  			if err != nil {
   313  				continue
   314  			}
   315  		}
   316  		if !instanceURLRgx.Match([]byte(derivedVars.instanceURI)) {
   317  			t.Errorf("[%v]: Expect correct derivedVars.instanceURI format error but it's bad format %v.", tc.testName, derivedVars.instanceURI)
   318  			continue
   319  		}
   320  
   321  		if tc.expectedProject != derivedVars.instanceProject || tc.expectedZone != derivedVars.instanceZone || tc.expectedInstanceName != derivedVars.instanceName {
   322  			t.Errorf("[%v]: Unexpected breakdown of instance URI. Actual project, zone, instanceName are %v, %v, %v while expect %v, %v, %v.",
   323  				tc.testName, derivedVars.instanceProject, derivedVars.instanceZone, derivedVars.instanceName,
   324  				tc.expectedProject, tc.expectedZone, tc.expectedInstanceName)
   325  		}
   326  		expectedURI := daisyutils.GetInstanceURI(tc.expectedProject, tc.expectedZone, tc.expectedInstanceName)
   327  		if expectedURI != derivedVars.instanceURI {
   328  			t.Errorf("[%v]: Unexpected instance URI. Actual: %v, while expect: %v.",
   329  				tc.testName, derivedVars.instanceURI, expectedURI)
   330  		}
   331  
   332  		err = validateAndDeriveInstance(&derivedVars, testSourceOS, testTargetOS)
   333  		if tc.expectedError == "" {
   334  			if err != nil {
   335  				t.Errorf("[%v]: Unexpected error: %v", tc.testName, err)
   336  			} else {
   337  				if derivedVars.instanceName == testInstance {
   338  					assert.Nil(t, derivedVars.originalWindowsStartupScriptURL, "[test name: %v] Unexpected derivedVars.originalWindowsStartupScriptURL.", tc.testName)
   339  				} else if derivedVars.instanceName == testInstanceWithStartupScript ||
   340  					derivedVars.instanceName == testInstanceWithExistingStartupScriptBackup {
   341  					if derivedVars.originalWindowsStartupScriptURL == nil || *derivedVars.originalWindowsStartupScriptURL != testOriginalStartupScript {
   342  						t.Errorf("[%v]: Unexpected originalWindowsStartupScriptURL: %v, expect: %v", tc.testName, derivedVars.originalWindowsStartupScriptURL, testOriginalStartupScript)
   343  					}
   344  				}
   345  			}
   346  		} else {
   347  			assert.EqualErrorf(t, err, tc.expectedError, "[test name: %v]: Unexpected error from validateAndDeriveInstance.", tc.testName)
   348  		}
   349  	}
   350  }
   351  
   352  func TestValidateOSDisk(t *testing.T) {
   353  	type testCase struct {
   354  		testName      string
   355  		osDisk        *compute.AttachedDisk
   356  		expectedError string
   357  	}
   358  
   359  	tcs := []testCase{
   360  		{
   361  			"Disk exists",
   362  			&compute.AttachedDisk{Source: testDiskURI, DeviceName: testDiskDeviceName,
   363  				AutoDelete: testDiskAutoDelete, Boot: true},
   364  			"",
   365  		},
   366  		{
   367  			"Disk not exist",
   368  			&compute.AttachedDisk{Source: daisyutils.GetDiskURI(testProject, testZone, DNE),
   369  				DeviceName: testDiskDeviceName, AutoDelete: testDiskAutoDelete, Boot: true},
   370  			"Failed to get boot disk info: googleapi: got HTTP response code 404 with body: ",
   371  		},
   372  		{
   373  			"Disk not boot",
   374  			&compute.AttachedDisk{Source: testDiskURI, DeviceName: testDiskDeviceName,
   375  				AutoDelete: testDiskAutoDelete, Boot: false},
   376  			"The instance has no boot disk.",
   377  		},
   378  	}
   379  
   380  	for _, tc := range tcs {
   381  		derivedVars := derivedVars{}
   382  		err := validateAndDeriveOSDisk(tc.osDisk, &derivedVars)
   383  		if tc.expectedError == "" {
   384  			if err != nil {
   385  				t.Errorf("[%v]: Unexpected error: %v", tc.testName, err)
   386  			} else {
   387  				assert.Equalf(t, testDiskURI, derivedVars.osDiskURI, "[%v]: Unexpected derivedVars.osDiskURI", tc.testName)
   388  				assert.Equalf(t, testDiskDeviceName, derivedVars.osDiskDeviceName, "[%v]: Unexpected derivedVars.osDiskDeviceName", tc.testName)
   389  				assert.Equalf(t, testDiskAutoDelete, derivedVars.osDiskAutoDelete, "[%v]: Unexpected derivedVars.osDiskAutoDelete", tc.testName)
   390  				assert.Equalf(t, testDiskType, derivedVars.osDiskType, "[%v]: Unexpected derivedVars.osDiskType", tc.testName)
   391  			}
   392  		} else {
   393  			assert.EqualErrorf(t, err, tc.expectedError, "[test name: %v] Unexpected error.", tc.testName)
   394  		}
   395  	}
   396  }
   397  
   398  func TestValidateLicense(t *testing.T) {
   399  	type testCase struct {
   400  		testName      string
   401  		osDisk        *compute.AttachedDisk
   402  		expectedError string
   403  	}
   404  
   405  	tcs := []testCase{
   406  		{
   407  			"No license",
   408  			&compute.AttachedDisk{},
   409  			"No valid Windows Server PayG license can be found. Any of the following licenses are required: [projects/windows-cloud/global/licenses/windows-server-2008-r2-dc]",
   410  		},
   411  		{
   412  			"No expected license",
   413  			&compute.AttachedDisk{
   414  				Licenses: []string{
   415  					"random-license",
   416  				}},
   417  			"No valid Windows Server PayG license can be found. Any of the following licenses are required: [projects/windows-cloud/global/licenses/windows-server-2008-r2-dc]",
   418  		},
   419  		{
   420  			"Expected license",
   421  			&compute.AttachedDisk{
   422  				Licenses: []string{
   423  					upgradePaths[testSourceOS][testTargetOS].expectedCurrentLicense[0],
   424  				}},
   425  			"",
   426  		},
   427  		{
   428  			"Expected license with some other license",
   429  			&compute.AttachedDisk{
   430  				Licenses: []string{
   431  					"random-1",
   432  					upgradePaths[testSourceOS][testTargetOS].expectedCurrentLicense[0],
   433  					"random-2",
   434  				}},
   435  			"",
   436  		},
   437  		{
   438  			"Upgraded",
   439  			&compute.AttachedDisk{
   440  				Licenses: []string{
   441  					upgradePaths[testSourceOS][testTargetOS].expectedCurrentLicense[0],
   442  					upgradePaths[testSourceOS][testTargetOS].licenseToAdd,
   443  				}},
   444  			"The GCE instance has the projects/windows-cloud/global/licenses/windows-server-2012-r2-dc-in-place-upgrade license attached. This likely means the instance either has been upgraded or has started an upgrade in the past.",
   445  		},
   446  	}
   447  
   448  	for _, tc := range tcs {
   449  		err := validateLicense(tc.osDisk, testSourceOS, testTargetOS)
   450  		if tc.expectedError != "" {
   451  			assert.EqualErrorf(t, err, tc.expectedError, "[test name: %v]", tc.testName)
   452  		} else {
   453  			assert.NoError(t, err, "[test name: %v]", tc.testName)
   454  		}
   455  	}
   456  }