github.com/google/osv-scalibr@v0.4.1/veles/secrets/huggingfaceapikey/validator_test.go (about) 1 // Copyright 2025 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package huggingfaceapikey_test 16 17 import ( 18 "context" 19 "net/http" 20 "net/http/httptest" 21 "net/url" 22 "strings" 23 "testing" 24 25 "github.com/google/go-cmp/cmp" 26 "github.com/google/go-cmp/cmp/cmpopts" 27 "github.com/google/osv-scalibr/veles" 28 "github.com/google/osv-scalibr/veles/secrets/huggingfaceapikey" 29 ) 30 31 const validatorTestKey = "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" 32 33 // mockTransport redirects requests to the test server 34 type mockTransport struct { 35 testServer *httptest.Server 36 } 37 38 func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { 39 // Replace the original URL with our test server URL 40 if req.URL.Host == "huggingface.co" { 41 testURL, _ := url.Parse(m.testServer.URL) 42 req.URL.Scheme = testURL.Scheme 43 req.URL.Host = testURL.Host 44 } 45 return http.DefaultTransport.RoundTrip(req) 46 } 47 48 // mockHuggingfaceServer creates a mock Huggingface API server for testing 49 func mockHuggingfaceServer(t *testing.T, expectedKey string, statusCode int) *httptest.Server { 50 t.Helper() 51 52 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 // Check if it's a GET request to the expected endpoint 54 if r.Method != http.MethodGet || r.URL.Path != "/api/whoami-v2" { 55 t.Errorf("unexpected request: %s %s, expected: GET /api/whoami-v2", r.Method, r.URL.Path) 56 http.Error(w, "not found", http.StatusNotFound) 57 return 58 } 59 60 // Check Authorization header 61 authHeader := r.Header.Get("Authorization") 62 if !strings.HasSuffix(authHeader, expectedKey) { 63 t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader) 64 } 65 66 // Set response 67 w.Header().Set("Content-Type", "application/json") 68 w.WriteHeader(statusCode) 69 })) 70 } 71 72 func TestValidator(t *testing.T) { 73 cases := []struct { 74 name string 75 statusCode int 76 want veles.ValidationStatus 77 wantErr error 78 }{ 79 { 80 name: "valid_key", 81 statusCode: http.StatusOK, 82 want: veles.ValidationValid, 83 }, 84 { 85 name: "invalid_key_unauthorized", 86 statusCode: http.StatusUnauthorized, 87 want: veles.ValidationInvalid, 88 }, 89 { 90 name: "server_error", 91 statusCode: http.StatusInternalServerError, 92 want: veles.ValidationFailed, 93 wantErr: cmpopts.AnyError, 94 }, 95 { 96 name: "bad_gateway", 97 statusCode: http.StatusBadGateway, 98 want: veles.ValidationFailed, 99 wantErr: cmpopts.AnyError, 100 }, 101 } 102 103 for _, tc := range cases { 104 t.Run(tc.name, func(t *testing.T) { 105 // Create a mock server 106 server := mockHuggingfaceServer(t, validatorTestKey, tc.statusCode) 107 defer server.Close() 108 109 // Create a client with custom transport 110 client := &http.Client{ 111 Transport: &mockTransport{testServer: server}, 112 } 113 114 // Create a validator with a mock client 115 validator := huggingfaceapikey.NewValidator() 116 validator.HTTPC = client 117 118 // Create a test key 119 key := huggingfaceapikey.HuggingfaceAPIKey{Key: validatorTestKey} 120 121 // Test validation 122 got, err := validator.Validate(t.Context(), key) 123 124 // Check error expectation 125 if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) { 126 t.Fatalf("Validate() error: got %v, want %v\n", err, tc.wantErr) 127 } 128 129 // Check validation status 130 if got != tc.want { 131 t.Errorf("Validate() = %v, want %v", got, tc.want) 132 } 133 }) 134 } 135 } 136 137 func TestValidator_ContextCancellation(t *testing.T) { 138 // Create a server that delays response 139 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 140 w.WriteHeader(http.StatusOK) 141 })) 142 defer server.Close() 143 // Create a client with custom transport 144 client := &http.Client{ 145 Transport: &mockTransport{testServer: server}, 146 } 147 148 validator := huggingfaceapikey.NewValidator() 149 validator.HTTPC = client 150 ctx, cancel := context.WithCancel(t.Context()) 151 cancel() 152 153 key := huggingfaceapikey.HuggingfaceAPIKey{Key: validatorTestKey} 154 155 // Test validation with cancelled context 156 got, err := validator.Validate(ctx, key) 157 158 if err == nil { 159 t.Errorf("Validate() expected error due to context cancellation, got nil") 160 } 161 if got != veles.ValidationFailed { 162 t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed) 163 } 164 } 165 166 func TestValidator_InvalidRequest(t *testing.T) { 167 // Create a mock server that returns 401 Unauthorized 168 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 169 w.WriteHeader(http.StatusUnauthorized) 170 })) 171 defer server.Close() 172 173 // Create a client with custom transport 174 client := &http.Client{ 175 Transport: &mockTransport{testServer: server}, 176 } 177 178 validator := huggingfaceapikey.NewValidator() 179 validator.HTTPC = client 180 181 testCases := []struct { 182 name string 183 key string 184 expected veles.ValidationStatus 185 }{ 186 { 187 name: "empty_key", 188 key: "", 189 expected: veles.ValidationInvalid, 190 }, 191 { 192 name: "invalid_key_format", 193 key: "invalid-key-format", 194 expected: veles.ValidationInvalid, 195 }, 196 } 197 198 for _, tc := range testCases { 199 t.Run(tc.name, func(t *testing.T) { 200 key := huggingfaceapikey.HuggingfaceAPIKey{Key: tc.key} 201 202 got, err := validator.Validate(t.Context(), key) 203 204 if err != nil { 205 t.Errorf("Validate() unexpected error for %s: %v", tc.name, err) 206 } 207 if got != tc.expected { 208 t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name) 209 } 210 }) 211 } 212 }