github.com/google/osv-scalibr@v0.4.1/veles/secrets/perplexityapikey/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package perplexityapikey_test
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"net/url"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/google/osv-scalibr/veles"
    26  	"github.com/google/osv-scalibr/veles/secrets/perplexityapikey"
    27  )
    28  
    29  const validatorTestKey = "pplx-test123456789012345678901234567890123456789012345678"
    30  
    31  // mockTransport redirects requests to the test server
    32  type mockTransport struct {
    33  	testServer *httptest.Server
    34  }
    35  
    36  func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    37  	// Replace the original URL with our test server URL
    38  	if req.URL.Host == "api.perplexity.ai" {
    39  		testURL, _ := url.Parse(m.testServer.URL)
    40  		req.URL.Scheme = testURL.Scheme
    41  		req.URL.Host = testURL.Host
    42  	}
    43  	return http.DefaultTransport.RoundTrip(req)
    44  }
    45  
    46  // mockPerplexityServer creates a mock Perplexity API server for testing
    47  func mockPerplexityServer(t *testing.T, expectedKey string, statusCode int) *httptest.Server {
    48  	t.Helper()
    49  
    50  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    51  		// Check if it's a GET request to the expected endpoint
    52  		if r.Method != http.MethodGet || r.URL.Path != "/async/chat/completions" {
    53  			t.Errorf("unexpected request: %s %s, expected: GET /async/chat/completions", r.Method, r.URL.Path)
    54  			http.Error(w, "not found", http.StatusNotFound)
    55  			return
    56  		}
    57  
    58  		// Check Authorization header
    59  		authHeader := r.Header.Get("Authorization")
    60  		if !strings.HasSuffix(authHeader, expectedKey) {
    61  			t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader)
    62  		}
    63  
    64  		// Set response
    65  		w.Header().Set("Content-Type", "application/json")
    66  		w.WriteHeader(statusCode)
    67  	}))
    68  }
    69  
    70  func TestValidator(t *testing.T) {
    71  	cases := []struct {
    72  		name        string
    73  		statusCode  int
    74  		want        veles.ValidationStatus
    75  		expectError bool
    76  	}{
    77  		{
    78  			name:       "valid_key",
    79  			statusCode: http.StatusOK,
    80  			want:       veles.ValidationValid,
    81  		},
    82  		{
    83  			name:       "invalid_key_unauthorized",
    84  			statusCode: http.StatusUnauthorized,
    85  			want:       veles.ValidationInvalid,
    86  		},
    87  		{
    88  			name:        "server_error",
    89  			statusCode:  http.StatusInternalServerError,
    90  			want:        veles.ValidationFailed,
    91  			expectError: true,
    92  		},
    93  		{
    94  			name:        "bad_gateway",
    95  			statusCode:  http.StatusBadGateway,
    96  			want:        veles.ValidationFailed,
    97  			expectError: true,
    98  		},
    99  	}
   100  
   101  	for _, tc := range cases {
   102  		t.Run(tc.name, func(t *testing.T) {
   103  			// Create mock server
   104  			server := mockPerplexityServer(t, validatorTestKey, tc.statusCode)
   105  			defer server.Close()
   106  
   107  			// Create client with custom transport
   108  			client := &http.Client{
   109  				Transport: &mockTransport{testServer: server},
   110  			}
   111  
   112  			// Create validator with mock client
   113  			validator := perplexityapikey.NewValidator()
   114  			validator.HTTPC = client
   115  
   116  			// Create test key
   117  			key := perplexityapikey.PerplexityAPIKey{Key: validatorTestKey}
   118  
   119  			// Test validation
   120  			got, err := validator.Validate(t.Context(), key)
   121  
   122  			// Check error expectation
   123  			if tc.expectError {
   124  				if err == nil {
   125  					t.Errorf("Validate() expected error, got nil")
   126  				}
   127  			} else {
   128  				if err != nil {
   129  					t.Errorf("Validate() unexpected error: %v", err)
   130  				}
   131  			}
   132  
   133  			// Check validation status
   134  			if got != tc.want {
   135  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   136  			}
   137  		})
   138  	}
   139  }
   140  
   141  func TestValidator_ContextCancellation(t *testing.T) {
   142  	// Create a server that delays response
   143  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   144  		w.WriteHeader(http.StatusOK)
   145  	}))
   146  	defer server.Close()
   147  
   148  	// Create client with custom transport
   149  	client := &http.Client{
   150  		Transport: &mockTransport{testServer: server},
   151  	}
   152  
   153  	validator := perplexityapikey.NewValidator()
   154  	validator.HTTPC = client
   155  
   156  	key := perplexityapikey.PerplexityAPIKey{Key: validatorTestKey}
   157  
   158  	// Create a cancelled context
   159  	ctx, cancel := context.WithCancel(t.Context())
   160  	cancel()
   161  
   162  	// Test validation with cancelled context
   163  	got, err := validator.Validate(ctx, key)
   164  
   165  	if err == nil {
   166  		t.Errorf("Validate() expected error due to context cancellation, got nil")
   167  	}
   168  	if got != veles.ValidationFailed {
   169  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   170  	}
   171  }
   172  
   173  func TestValidator_InvalidRequest(t *testing.T) {
   174  	// Create mock server that returns 401 Unauthorized
   175  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   176  		w.WriteHeader(http.StatusUnauthorized)
   177  	}))
   178  	defer server.Close()
   179  
   180  	// Create client with custom transport
   181  	client := &http.Client{
   182  		Transport: &mockTransport{testServer: server},
   183  	}
   184  
   185  	validator := perplexityapikey.NewValidator()
   186  	validator.HTTPC = client
   187  
   188  	testCases := []struct {
   189  		name     string
   190  		key      string
   191  		expected veles.ValidationStatus
   192  	}{
   193  		{
   194  			name:     "empty_key",
   195  			key:      "",
   196  			expected: veles.ValidationInvalid,
   197  		},
   198  		{
   199  			name:     "invalid_key_format",
   200  			key:      "invalid-key-format",
   201  			expected: veles.ValidationInvalid,
   202  		},
   203  	}
   204  
   205  	for _, tc := range testCases {
   206  		t.Run(tc.name, func(t *testing.T) {
   207  			key := perplexityapikey.PerplexityAPIKey{Key: tc.key}
   208  
   209  			got, err := validator.Validate(t.Context(), key)
   210  
   211  			if err != nil {
   212  				t.Errorf("Validate() unexpected error for %s: %v", tc.name, err)
   213  			}
   214  			if got != tc.expected {
   215  				t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name)
   216  			}
   217  		})
   218  	}
   219  }