github.com/google/osv-scalibr@v0.4.1/veles/secrets/huggingfaceapikey/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package huggingfaceapikey_test
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"net/url"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/google/go-cmp/cmp/cmpopts"
    27  	"github.com/google/osv-scalibr/veles"
    28  	"github.com/google/osv-scalibr/veles/secrets/huggingfaceapikey"
    29  )
    30  
    31  const validatorTestKey = "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"
    32  
    33  // mockTransport redirects requests to the test server
    34  type mockTransport struct {
    35  	testServer *httptest.Server
    36  }
    37  
    38  func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    39  	// Replace the original URL with our test server URL
    40  	if req.URL.Host == "huggingface.co" {
    41  		testURL, _ := url.Parse(m.testServer.URL)
    42  		req.URL.Scheme = testURL.Scheme
    43  		req.URL.Host = testURL.Host
    44  	}
    45  	return http.DefaultTransport.RoundTrip(req)
    46  }
    47  
    48  // mockHuggingfaceServer creates a mock Huggingface API server for testing
    49  func mockHuggingfaceServer(t *testing.T, expectedKey string, statusCode int) *httptest.Server {
    50  	t.Helper()
    51  
    52  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    53  		// Check if it's a GET request to the expected endpoint
    54  		if r.Method != http.MethodGet || r.URL.Path != "/api/whoami-v2" {
    55  			t.Errorf("unexpected request: %s %s, expected: GET /api/whoami-v2", r.Method, r.URL.Path)
    56  			http.Error(w, "not found", http.StatusNotFound)
    57  			return
    58  		}
    59  
    60  		// Check Authorization header
    61  		authHeader := r.Header.Get("Authorization")
    62  		if !strings.HasSuffix(authHeader, expectedKey) {
    63  			t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader)
    64  		}
    65  
    66  		// Set response
    67  		w.Header().Set("Content-Type", "application/json")
    68  		w.WriteHeader(statusCode)
    69  	}))
    70  }
    71  
    72  func TestValidator(t *testing.T) {
    73  	cases := []struct {
    74  		name       string
    75  		statusCode int
    76  		want       veles.ValidationStatus
    77  		wantErr    error
    78  	}{
    79  		{
    80  			name:       "valid_key",
    81  			statusCode: http.StatusOK,
    82  			want:       veles.ValidationValid,
    83  		},
    84  		{
    85  			name:       "invalid_key_unauthorized",
    86  			statusCode: http.StatusUnauthorized,
    87  			want:       veles.ValidationInvalid,
    88  		},
    89  		{
    90  			name:       "server_error",
    91  			statusCode: http.StatusInternalServerError,
    92  			want:       veles.ValidationFailed,
    93  			wantErr:    cmpopts.AnyError,
    94  		},
    95  		{
    96  			name:       "bad_gateway",
    97  			statusCode: http.StatusBadGateway,
    98  			want:       veles.ValidationFailed,
    99  			wantErr:    cmpopts.AnyError,
   100  		},
   101  	}
   102  
   103  	for _, tc := range cases {
   104  		t.Run(tc.name, func(t *testing.T) {
   105  			// Create a mock server
   106  			server := mockHuggingfaceServer(t, validatorTestKey, tc.statusCode)
   107  			defer server.Close()
   108  
   109  			// Create a client with custom transport
   110  			client := &http.Client{
   111  				Transport: &mockTransport{testServer: server},
   112  			}
   113  
   114  			// Create a validator with a mock client
   115  			validator := huggingfaceapikey.NewValidator()
   116  			validator.HTTPC = client
   117  
   118  			// Create a test key
   119  			key := huggingfaceapikey.HuggingfaceAPIKey{Key: validatorTestKey}
   120  
   121  			// Test validation
   122  			got, err := validator.Validate(t.Context(), key)
   123  
   124  			// Check error expectation
   125  			if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
   126  				t.Fatalf("Validate() error: got %v, want %v\n", err, tc.wantErr)
   127  			}
   128  
   129  			// Check validation status
   130  			if got != tc.want {
   131  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   132  			}
   133  		})
   134  	}
   135  }
   136  
   137  func TestValidator_ContextCancellation(t *testing.T) {
   138  	// Create a server that delays response
   139  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   140  		w.WriteHeader(http.StatusOK)
   141  	}))
   142  	defer server.Close()
   143  	// Create a client with custom transport
   144  	client := &http.Client{
   145  		Transport: &mockTransport{testServer: server},
   146  	}
   147  
   148  	validator := huggingfaceapikey.NewValidator()
   149  	validator.HTTPC = client
   150  	ctx, cancel := context.WithCancel(t.Context())
   151  	cancel()
   152  
   153  	key := huggingfaceapikey.HuggingfaceAPIKey{Key: validatorTestKey}
   154  
   155  	// Test validation with cancelled context
   156  	got, err := validator.Validate(ctx, key)
   157  
   158  	if err == nil {
   159  		t.Errorf("Validate() expected error due to context cancellation, got nil")
   160  	}
   161  	if got != veles.ValidationFailed {
   162  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   163  	}
   164  }
   165  
   166  func TestValidator_InvalidRequest(t *testing.T) {
   167  	// Create a mock server that returns 401 Unauthorized
   168  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   169  		w.WriteHeader(http.StatusUnauthorized)
   170  	}))
   171  	defer server.Close()
   172  
   173  	// Create a client with custom transport
   174  	client := &http.Client{
   175  		Transport: &mockTransport{testServer: server},
   176  	}
   177  
   178  	validator := huggingfaceapikey.NewValidator()
   179  	validator.HTTPC = client
   180  
   181  	testCases := []struct {
   182  		name     string
   183  		key      string
   184  		expected veles.ValidationStatus
   185  	}{
   186  		{
   187  			name:     "empty_key",
   188  			key:      "",
   189  			expected: veles.ValidationInvalid,
   190  		},
   191  		{
   192  			name:     "invalid_key_format",
   193  			key:      "invalid-key-format",
   194  			expected: veles.ValidationInvalid,
   195  		},
   196  	}
   197  
   198  	for _, tc := range testCases {
   199  		t.Run(tc.name, func(t *testing.T) {
   200  			key := huggingfaceapikey.HuggingfaceAPIKey{Key: tc.key}
   201  
   202  			got, err := validator.Validate(t.Context(), key)
   203  
   204  			if err != nil {
   205  				t.Errorf("Validate() unexpected error for %s: %v", tc.name, err)
   206  			}
   207  			if got != tc.expected {
   208  				t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name)
   209  			}
   210  		})
   211  	}
   212  }