github.com/google/osv-scalibr@v0.4.1/veles/secrets/perplexityapikey/validator_test.go (about) 1 // Copyright 2025 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package perplexityapikey_test 16 17 import ( 18 "context" 19 "net/http" 20 "net/http/httptest" 21 "net/url" 22 "strings" 23 "testing" 24 25 "github.com/google/osv-scalibr/veles" 26 "github.com/google/osv-scalibr/veles/secrets/perplexityapikey" 27 ) 28 29 const validatorTestKey = "pplx-test123456789012345678901234567890123456789012345678" 30 31 // mockTransport redirects requests to the test server 32 type mockTransport struct { 33 testServer *httptest.Server 34 } 35 36 func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { 37 // Replace the original URL with our test server URL 38 if req.URL.Host == "api.perplexity.ai" { 39 testURL, _ := url.Parse(m.testServer.URL) 40 req.URL.Scheme = testURL.Scheme 41 req.URL.Host = testURL.Host 42 } 43 return http.DefaultTransport.RoundTrip(req) 44 } 45 46 // mockPerplexityServer creates a mock Perplexity API server for testing 47 func mockPerplexityServer(t *testing.T, expectedKey string, statusCode int) *httptest.Server { 48 t.Helper() 49 50 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 51 // Check if it's a GET request to the expected endpoint 52 if r.Method != http.MethodGet || r.URL.Path != "/async/chat/completions" { 53 t.Errorf("unexpected request: %s %s, expected: GET /async/chat/completions", r.Method, r.URL.Path) 54 http.Error(w, "not found", http.StatusNotFound) 55 return 56 } 57 58 // Check Authorization header 59 authHeader := r.Header.Get("Authorization") 60 if !strings.HasSuffix(authHeader, expectedKey) { 61 t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader) 62 } 63 64 // Set response 65 w.Header().Set("Content-Type", "application/json") 66 w.WriteHeader(statusCode) 67 })) 68 } 69 70 func TestValidator(t *testing.T) { 71 cases := []struct { 72 name string 73 statusCode int 74 want veles.ValidationStatus 75 expectError bool 76 }{ 77 { 78 name: "valid_key", 79 statusCode: http.StatusOK, 80 want: veles.ValidationValid, 81 }, 82 { 83 name: "invalid_key_unauthorized", 84 statusCode: http.StatusUnauthorized, 85 want: veles.ValidationInvalid, 86 }, 87 { 88 name: "server_error", 89 statusCode: http.StatusInternalServerError, 90 want: veles.ValidationFailed, 91 expectError: true, 92 }, 93 { 94 name: "bad_gateway", 95 statusCode: http.StatusBadGateway, 96 want: veles.ValidationFailed, 97 expectError: true, 98 }, 99 } 100 101 for _, tc := range cases { 102 t.Run(tc.name, func(t *testing.T) { 103 // Create mock server 104 server := mockPerplexityServer(t, validatorTestKey, tc.statusCode) 105 defer server.Close() 106 107 // Create client with custom transport 108 client := &http.Client{ 109 Transport: &mockTransport{testServer: server}, 110 } 111 112 // Create validator with mock client 113 validator := perplexityapikey.NewValidator() 114 validator.HTTPC = client 115 116 // Create test key 117 key := perplexityapikey.PerplexityAPIKey{Key: validatorTestKey} 118 119 // Test validation 120 got, err := validator.Validate(t.Context(), key) 121 122 // Check error expectation 123 if tc.expectError { 124 if err == nil { 125 t.Errorf("Validate() expected error, got nil") 126 } 127 } else { 128 if err != nil { 129 t.Errorf("Validate() unexpected error: %v", err) 130 } 131 } 132 133 // Check validation status 134 if got != tc.want { 135 t.Errorf("Validate() = %v, want %v", got, tc.want) 136 } 137 }) 138 } 139 } 140 141 func TestValidator_ContextCancellation(t *testing.T) { 142 // Create a server that delays response 143 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 144 w.WriteHeader(http.StatusOK) 145 })) 146 defer server.Close() 147 148 // Create client with custom transport 149 client := &http.Client{ 150 Transport: &mockTransport{testServer: server}, 151 } 152 153 validator := perplexityapikey.NewValidator() 154 validator.HTTPC = client 155 156 key := perplexityapikey.PerplexityAPIKey{Key: validatorTestKey} 157 158 // Create a cancelled context 159 ctx, cancel := context.WithCancel(t.Context()) 160 cancel() 161 162 // Test validation with cancelled context 163 got, err := validator.Validate(ctx, key) 164 165 if err == nil { 166 t.Errorf("Validate() expected error due to context cancellation, got nil") 167 } 168 if got != veles.ValidationFailed { 169 t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed) 170 } 171 } 172 173 func TestValidator_InvalidRequest(t *testing.T) { 174 // Create mock server that returns 401 Unauthorized 175 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 176 w.WriteHeader(http.StatusUnauthorized) 177 })) 178 defer server.Close() 179 180 // Create client with custom transport 181 client := &http.Client{ 182 Transport: &mockTransport{testServer: server}, 183 } 184 185 validator := perplexityapikey.NewValidator() 186 validator.HTTPC = client 187 188 testCases := []struct { 189 name string 190 key string 191 expected veles.ValidationStatus 192 }{ 193 { 194 name: "empty_key", 195 key: "", 196 expected: veles.ValidationInvalid, 197 }, 198 { 199 name: "invalid_key_format", 200 key: "invalid-key-format", 201 expected: veles.ValidationInvalid, 202 }, 203 } 204 205 for _, tc := range testCases { 206 t.Run(tc.name, func(t *testing.T) { 207 key := perplexityapikey.PerplexityAPIKey{Key: tc.key} 208 209 got, err := validator.Validate(t.Context(), key) 210 211 if err != nil { 212 t.Errorf("Validate() unexpected error for %s: %v", tc.name, err) 213 } 214 if got != tc.expected { 215 t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name) 216 } 217 }) 218 } 219 }