sigs.k8s.io/cluster-api-provider-azure@v1.14.3/azure/defaults_test.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package azure
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"testing"
    25  
    26  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
    27  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    28  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
    29  	. "github.com/onsi/gomega"
    30  	"go.uber.org/mock/gomock"
    31  	"sigs.k8s.io/cluster-api-provider-azure/azure/mock_azure"
    32  	"sigs.k8s.io/cluster-api-provider-azure/util/tele"
    33  )
    34  
    35  // TestARMClientOptions tests the `ARMClientOptions()` factory function.
    36  func TestARMClientOptions(t *testing.T) {
    37  	tests := []struct {
    38  		name          string
    39  		cloudName     string
    40  		expectedCloud cloud.Configuration
    41  		expectError   bool
    42  	}{
    43  		{
    44  			name:          "should return default client options if cloudName is empty",
    45  			cloudName:     "",
    46  			expectedCloud: cloud.Configuration{},
    47  		},
    48  		{
    49  			name:          "should return Azure public cloud client options",
    50  			cloudName:     PublicCloudName,
    51  			expectedCloud: cloud.AzurePublic,
    52  		},
    53  		{
    54  			name:          "should return Azure China cloud client options",
    55  			cloudName:     ChinaCloudName,
    56  			expectedCloud: cloud.AzureChina,
    57  		},
    58  		{
    59  			name:          "should return Azure government cloud client options",
    60  			cloudName:     USGovernmentCloudName,
    61  			expectedCloud: cloud.AzureGovernment,
    62  		},
    63  		{
    64  			name:        "should return error if cloudName is unrecognized",
    65  			cloudName:   "AzureUnrecognizedCloud",
    66  			expectError: true,
    67  		},
    68  	}
    69  	for _, tc := range tests {
    70  		tc := tc
    71  		t.Run(tc.name, func(t *testing.T) {
    72  			t.Parallel()
    73  			g := NewWithT(t)
    74  
    75  			opts, err := ARMClientOptions(tc.cloudName)
    76  			if tc.expectError {
    77  				g.Expect(err).To(HaveOccurred())
    78  				return
    79  			}
    80  			g.Expect(err).NotTo(HaveOccurred())
    81  			g.Expect(opts.Cloud).To(Equal(tc.expectedCloud))
    82  			g.Expect(opts.Retry.MaxRetries).To(BeNumerically("==", -1))
    83  			g.Expect(opts.PerCallPolicies).To(HaveLen(2))
    84  		})
    85  	}
    86  }
    87  
    88  // TestPerCallPolicies tests the per-call policies returned by `ARMClientOptions()`.
    89  func TestPerCallPolicies(t *testing.T) {
    90  	g := NewWithT(t)
    91  
    92  	corrID := "test-1234abcd-5678efgh"
    93  	// This server will check that the correlation ID and user-agent are set correctly.
    94  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    95  		g.Expect(r.Header.Get("User-Agent")).To(ContainSubstring("cluster-api-provider-azure/"))
    96  		g.Expect(r.Header.Get(string(tele.CorrIDKeyVal))).To(Equal(corrID))
    97  		fmt.Fprintf(w, "Hello, %s", r.Proto)
    98  	}))
    99  	defer server.Close()
   100  
   101  	// Call the factory function and ensure it has both PerCallPolicies.
   102  	opts, err := ARMClientOptions("")
   103  	g.Expect(err).NotTo(HaveOccurred())
   104  	g.Expect(opts.PerCallPolicies).To(HaveLen(2))
   105  	g.Expect(opts.PerCallPolicies).To(ContainElement(BeAssignableToTypeOf(correlationIDPolicy{})))
   106  	g.Expect(opts.PerCallPolicies).To(ContainElement(BeAssignableToTypeOf(userAgentPolicy{})))
   107  
   108  	// Create a request with a correlation ID.
   109  	ctx := context.WithValue(context.Background(), tele.CorrIDKeyVal, tele.CorrID(corrID))
   110  	req, err := runtime.NewRequest(ctx, http.MethodGet, server.URL)
   111  	g.Expect(err).NotTo(HaveOccurred())
   112  
   113  	// Create a pipeline and send the request, where it will be checked by the server.
   114  	pipeline := defaultTestPipeline(opts.PerCallPolicies)
   115  	resp, err := pipeline.Do(req)
   116  	g.Expect(err).NotTo(HaveOccurred())
   117  	defer resp.Body.Close()
   118  	g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
   119  }
   120  
   121  func TestCustomPutPatchHeaderPolicy(t *testing.T) {
   122  	testHeaders := map[string]string{
   123  		"X-Test-Header":  "test-value",
   124  		"X-Test-Header2": "test-value2",
   125  	}
   126  	testcases := []struct {
   127  		name     string
   128  		method   string
   129  		headers  map[string]string
   130  		expected map[string]string
   131  	}{
   132  		{
   133  			name:     "should add custom headers to PUT request",
   134  			method:   http.MethodPut,
   135  			headers:  testHeaders,
   136  			expected: testHeaders,
   137  		},
   138  		{
   139  			name:     "should add custom headers to PATCH request",
   140  			method:   http.MethodPatch,
   141  			headers:  testHeaders,
   142  			expected: testHeaders,
   143  		},
   144  		{
   145  			name:   "should skip empty custom headers for PUT request",
   146  			method: http.MethodPut,
   147  		},
   148  		{
   149  			name:   "should skip empty custom headers for PATCH request",
   150  			method: http.MethodPatch,
   151  		},
   152  		{
   153  			name:   "should skip empty custom headers for GET request",
   154  			method: http.MethodGet,
   155  		},
   156  		{
   157  			name:    "should not add custom headers to GET request",
   158  			method:  http.MethodGet,
   159  			headers: testHeaders,
   160  		},
   161  		{
   162  			name:    "should not add custom headers to POST request",
   163  			method:  http.MethodPost,
   164  			headers: testHeaders,
   165  		},
   166  	}
   167  	for _, tc := range testcases {
   168  		tc := tc
   169  		t.Run(tc.name, func(t *testing.T) {
   170  			t.Parallel()
   171  			g := NewWithT(t)
   172  
   173  			mockCtrl := gomock.NewController(t)
   174  			defer mockCtrl.Finish()
   175  
   176  			// This server will check that custom headers are set correctly.
   177  			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   178  				for k, v := range tc.expected {
   179  					g.Expect(r.Header.Get(k)).To(Equal(v))
   180  				}
   181  				fmt.Fprintf(w, "Hello, %s", r.Proto)
   182  			}))
   183  			defer server.Close()
   184  
   185  			// Create options with a custom PUT/PATCH header per-call policy
   186  			getterMock := mock_azure.NewMockResourceSpecGetterWithHeaders(mockCtrl)
   187  			getterMock.EXPECT().CustomHeaders().Return(tc.headers).AnyTimes()
   188  			opts, err := ARMClientOptions("", CustomPutPatchHeaderPolicy{Headers: tc.headers})
   189  			g.Expect(err).NotTo(HaveOccurred())
   190  
   191  			// Create a request
   192  			req, err := runtime.NewRequest(context.Background(), tc.method, server.URL)
   193  			g.Expect(err).NotTo(HaveOccurred())
   194  
   195  			// Create a pipeline and send the request to the test server for validation.
   196  			pipeline := defaultTestPipeline(opts.PerCallPolicies)
   197  			resp, err := pipeline.Do(req)
   198  			g.Expect(err).NotTo(HaveOccurred())
   199  			defer resp.Body.Close()
   200  			g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
   201  		})
   202  	}
   203  }
   204  
   205  func defaultTestPipeline(policies []policy.Policy) runtime.Pipeline {
   206  	return runtime.NewPipeline(
   207  		"testmodule",
   208  		"v0.1.0",
   209  		runtime.PipelineOptions{},
   210  		&policy.ClientOptions{PerCallPolicies: policies},
   211  	)
   212  }
   213  
   214  func TestGetBootstrappingVMExtension(t *testing.T) {
   215  	testCases := []struct {
   216  		name            string
   217  		osType          string
   218  		cloud           string
   219  		vmName          string
   220  		cpuArchitecture string
   221  		expectedVersion string
   222  		expectNil       bool
   223  	}{
   224  		{
   225  			name:            "Linux OS, Public Cloud, x64 CPU Architecture",
   226  			osType:          LinuxOS,
   227  			cloud:           PublicCloudName,
   228  			vmName:          "test-vm",
   229  			cpuArchitecture: "x64",
   230  			expectedVersion: "1.0",
   231  		},
   232  		{
   233  			name:            "Linux OS, Public Cloud, ARM64 CPU Architecture",
   234  			osType:          LinuxOS,
   235  			cloud:           PublicCloudName,
   236  			vmName:          "test-vm",
   237  			cpuArchitecture: "Arm64",
   238  			expectedVersion: "1.1",
   239  		},
   240  		{
   241  			name:            "Windows OS, Public Cloud",
   242  			osType:          WindowsOS,
   243  			cloud:           PublicCloudName,
   244  			vmName:          "test-vm",
   245  			cpuArchitecture: "x64",
   246  			expectedVersion: "1.0",
   247  		},
   248  		{
   249  			name:            "Invalid OS Type",
   250  			osType:          "invalid",
   251  			cloud:           PublicCloudName,
   252  			vmName:          "test-vm",
   253  			cpuArchitecture: "x64",
   254  			expectedVersion: "1.0",
   255  			expectNil:       true,
   256  		},
   257  		{
   258  			name:            "Invalid Cloud",
   259  			osType:          LinuxOS,
   260  			cloud:           "invalid",
   261  			vmName:          "test-vm",
   262  			cpuArchitecture: "x64",
   263  			expectedVersion: "1.0",
   264  			expectNil:       true,
   265  		},
   266  	}
   267  
   268  	for _, tc := range testCases {
   269  		t.Run(tc.name, func(t *testing.T) {
   270  			g := NewWithT(t)
   271  			actualExtension := GetBootstrappingVMExtension(tc.osType, tc.cloud, tc.vmName, tc.cpuArchitecture)
   272  			if tc.expectNil {
   273  				g.Expect(actualExtension).To(BeNil())
   274  			} else {
   275  				g.Expect(actualExtension.Version).To(Equal(tc.expectedVersion))
   276  			}
   277  		})
   278  	}
   279  }
   280  
   281  func TestNormalizeAzureName(t *testing.T) {
   282  	testCases := []struct {
   283  		name     string
   284  		input    string
   285  		expected string
   286  	}{
   287  		{
   288  			name:     "should return lower case",
   289  			input:    "Test",
   290  			expected: "test",
   291  		},
   292  		{
   293  			name:     "should return lower case with spaces replaced by hyphens",
   294  			input:    "Test Name",
   295  			expected: "test-name",
   296  		},
   297  		{
   298  			name:     "should return lower case with spaces replaced by hyphens and non-alphanumeric characters removed",
   299  			input:    "Test Name 1",
   300  			expected: "test-name-1",
   301  		},
   302  		{
   303  			name:     "should return lower case with spaces replaced by hyphens and non-alphanumeric characters removed",
   304  			input:    "Test-Name-1-",
   305  			expected: "test-name-1",
   306  		},
   307  		{
   308  			name:     "should return lower case with spaces replaced by hyphens and non-alphanumeric characters removed",
   309  			input:    "Test-Name-1-@",
   310  			expected: "test-name-1",
   311  		},
   312  		{
   313  			name:     "should return lower case with spaces replaced by hyphens and non-alphanumeric characters removed",
   314  			input:    "Test-Name-1-@-",
   315  			expected: "test-name-1",
   316  		},
   317  		{
   318  			name:     "should return lower case with spaces replaced by hyphens and non-alphanumeric characters removed",
   319  			input:    "Test-Name-1-@-@",
   320  			expected: "test-name-1",
   321  		},
   322  		{
   323  			name:     "should return lower case with underscores replaced by hyphens and non-alphanumeric characters removed",
   324  			input:    "Test_Name_1-@-@",
   325  			expected: "test-name-1",
   326  		},
   327  		{
   328  			name:     "should return lower case with underscores replaced by hyphens and non-alphanumeric characters removed",
   329  			input:    "0_Test_Name_1-@-@",
   330  			expected: "0-test-name-1",
   331  		},
   332  		{
   333  			name:     "should return lower case with underscores replaced by hyphens and non-alphanumeric characters removed",
   334  			input:    "_Test_Name_1-@-@",
   335  			expected: "test-name-1",
   336  		},
   337  		{
   338  			name:     "should return lower case with name without hyphens",
   339  			input:    "_Test_Name_1---",
   340  			expected: "test-name-1",
   341  		},
   342  		{
   343  			name:     "should not change the input since input is valid k8s name",
   344  			input:    "test-name-1",
   345  			expected: "test-name-1",
   346  		},
   347  	}
   348  	for _, tc := range testCases {
   349  		t.Run(tc.name, func(t *testing.T) {
   350  			g := NewWithT(t)
   351  			normalizedNamed := GetNormalizedKubernetesName(tc.input)
   352  			g.Expect(normalizedNamed).To(Equal(tc.expected))
   353  		})
   354  	}
   355  }