github.com/google/osv-scalibr@v0.4.1/veles/secrets/digitaloceanapikey/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package digitaloceanapikey_test
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"net/url"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/google/osv-scalibr/veles"
    26  	"github.com/google/osv-scalibr/veles/secrets/digitaloceanapikey"
    27  )
    28  
// validatorTestKey is the token value shared by the validator tests: it is
// submitted as the key under test and is also the value the mock server looks
// for in the Authorization header.
const validatorTestKey = "dop_v1_4c6aeb9deed0fb897e585f8ecafa555dd0a9b46087b1e354bcab59b0483edfaf"
    30  
// mockTransport is an http.RoundTripper that redirects requests addressed to
// api.digitalocean.com to a local test server.
type mockTransport struct {
	// testServer is the server whose URL supplies the replacement scheme and host.
	testServer *httptest.Server
}
    35  
    36  func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    37  	// Replace the original URL with our test server URL
    38  	if req.URL.Host == "api.digitalocean.com" {
    39  		testURL, _ := url.Parse(m.testServer.URL)
    40  		req.URL.Scheme = testURL.Scheme
    41  		req.URL.Host = testURL.Host
    42  	}
    43  	return http.DefaultTransport.RoundTrip(req)
    44  }
    45  
    46  // mockDigitaloceanServer creates a mock DigitalOcean API server for testing
    47  func mockDigitaloceanServer(t *testing.T, expectedKey string, serverResponseCode int) *httptest.Server {
    48  	t.Helper()
    49  
    50  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    51  		// Check if it's a GET request to the expected endpoint
    52  		if r.Method != http.MethodGet || r.URL.Path != "/v2/account" {
    53  			t.Errorf("unexpected request: %s %s, expected: GET /v2/account", r.Method, r.URL.Path)
    54  			http.Error(w, "not found", http.StatusNotFound)
    55  			return
    56  		}
    57  
    58  		// Check Authorization header
    59  		authHeader := r.Header.Get("Authorization")
    60  		if len(expectedKey) > 0 && !strings.Contains(authHeader, expectedKey) {
    61  			w.Header().Set("Content-Type", "application/json")
    62  			w.WriteHeader(http.StatusUnauthorized)
    63  			return
    64  		}
    65  
    66  		// Set response
    67  		w.Header().Set("Content-Type", "application/json")
    68  		w.WriteHeader(serverResponseCode)
    69  	}))
    70  }
    71  
    72  func TestValidator(t *testing.T) {
    73  	cases := []struct {
    74  		name               string
    75  		key                string
    76  		serverExpectedKey  string
    77  		serverResponseCode int
    78  		want               veles.ValidationStatus
    79  		expectError        bool
    80  	}{
    81  		{
    82  			name:               "valid_key",
    83  			key:                validatorTestKey,
    84  			serverExpectedKey:  validatorTestKey,
    85  			serverResponseCode: http.StatusOK,
    86  			want:               veles.ValidationValid,
    87  		},
    88  		{
    89  			name:               "valid_key_custom_scope",
    90  			key:                validatorTestKey,
    91  			serverExpectedKey:  validatorTestKey,
    92  			serverResponseCode: http.StatusForbidden,
    93  			want:               veles.ValidationValid,
    94  		},
    95  		{
    96  			name:               "invalid_key_unauthorized",
    97  			key:                "random_string",
    98  			serverExpectedKey:  validatorTestKey,
    99  			serverResponseCode: http.StatusUnauthorized,
   100  			want:               veles.ValidationInvalid,
   101  		},
   102  		{
   103  			name:               "server_error",
   104  			serverResponseCode: http.StatusInternalServerError,
   105  			want:               veles.ValidationFailed,
   106  			expectError:        true,
   107  		},
   108  		{
   109  			name:               "bad_gateway",
   110  			serverResponseCode: http.StatusBadGateway,
   111  			want:               veles.ValidationFailed,
   112  			expectError:        true,
   113  		},
   114  	}
   115  
   116  	for _, tc := range cases {
   117  		t.Run(tc.name, func(t *testing.T) {
   118  			// Create a mock server
   119  			server := mockDigitaloceanServer(t, tc.serverExpectedKey, tc.serverResponseCode)
   120  			defer server.Close()
   121  
   122  			// Create a client with custom transport
   123  			client := &http.Client{
   124  				Transport: &mockTransport{testServer: server},
   125  			}
   126  
   127  			// Create a validator with a mock client
   128  			validator := digitaloceanapikey.NewValidator()
   129  			validator.HTTPC = client
   130  
   131  			// Create a test key
   132  			key := digitaloceanapikey.DigitaloceanAPIToken{Key: tc.key}
   133  
   134  			// Test validation
   135  			got, err := validator.Validate(t.Context(), key)
   136  
   137  			// Check error expectation
   138  			if tc.expectError {
   139  				if err == nil {
   140  					t.Errorf("Validate() expected error, got nil")
   141  				}
   142  			} else {
   143  				if err != nil {
   144  					t.Errorf("Validate() unexpected error: %v", err)
   145  				}
   146  			}
   147  
   148  			// Check validation status
   149  			if got != tc.want {
   150  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   151  			}
   152  		})
   153  	}
   154  }
   155  
   156  func TestValidator_ContextCancellation(t *testing.T) {
   157  	// Create a server that delays response
   158  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   159  		w.WriteHeader(http.StatusOK)
   160  	}))
   161  	defer server.Close()
   162  
   163  	// Create a client with custom transport
   164  	client := &http.Client{
   165  		Transport: &mockTransport{testServer: server},
   166  	}
   167  
   168  	validator := digitaloceanapikey.NewValidator()
   169  	validator.HTTPC = client
   170  
   171  	key := digitaloceanapikey.DigitaloceanAPIToken{Key: validatorTestKey}
   172  
   173  	// Create a cancelled context
   174  	ctx, cancel := context.WithCancel(t.Context())
   175  	cancel()
   176  
   177  	// Test validation with cancelled context
   178  	got, err := validator.Validate(ctx, key)
   179  
   180  	if err == nil {
   181  		t.Errorf("Validate() expected error due to context cancellation, got nil")
   182  	}
   183  	if got != veles.ValidationFailed {
   184  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   185  	}
   186  }
   187  
   188  func TestValidator_InvalidRequest(t *testing.T) {
   189  	// Create a mock server that returns 401 Unauthorized
   190  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   191  		w.WriteHeader(http.StatusUnauthorized)
   192  	}))
   193  	defer server.Close()
   194  
   195  	// Create a client with custom transport
   196  	client := &http.Client{
   197  		Transport: &mockTransport{testServer: server},
   198  	}
   199  
   200  	validator := digitaloceanapikey.NewValidator()
   201  	validator.HTTPC = client
   202  
   203  	testCases := []struct {
   204  		name     string
   205  		key      string
   206  		expected veles.ValidationStatus
   207  	}{
   208  		{
   209  			name:     "empty_key",
   210  			key:      "",
   211  			expected: veles.ValidationInvalid,
   212  		},
   213  		{
   214  			name:     "invalid_key_format",
   215  			key:      "invalid-key-format",
   216  			expected: veles.ValidationInvalid,
   217  		},
   218  	}
   219  
   220  	for _, tc := range testCases {
   221  		t.Run(tc.name, func(t *testing.T) {
   222  			key := digitaloceanapikey.DigitaloceanAPIToken{Key: tc.key}
   223  
   224  			got, err := validator.Validate(t.Context(), key)
   225  
   226  			if err != nil {
   227  				t.Errorf("Validate() unexpected error for %s: %v", tc.name, err)
   228  			}
   229  			if got != tc.expected {
   230  				t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name)
   231  			}
   232  		})
   233  	}
   234  }