github.com/google/osv-scalibr@v0.4.1/veles/secrets/grokxaiapikey/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grokxaiapikey_test
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/google/osv-scalibr/veles"
    26  	grokxaiapikey "github.com/google/osv-scalibr/veles/secrets/grokxaiapikey"
    27  )
    28  
    29  const validatorTestKey = "grokx-test12345678901234567890123456789012345678901234567890"
    30  
    31  // mockAPIServer creates a mock x.ai /v1/api-key endpoint for testing API validator.
    32  func mockAPIServer(t *testing.T, expectedKey string, statusCode int, body any) *httptest.Server {
    33  	t.Helper()
    34  
    35  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    36  		// Expect a GET to /v1/api-key
    37  		if r.Method != http.MethodGet || r.URL.Path != "/v1/api-key" {
    38  			t.Errorf("unexpected request: %s %s, expected: GET /v1/api-key", r.Method, r.URL.Path)
    39  			http.Error(w, "not found", http.StatusNotFound)
    40  			return
    41  		}
    42  
    43  		// Check Authorization header contains the key (ends with key)
    44  		authHeader := r.Header.Get("Authorization")
    45  		if !strings.HasSuffix(authHeader, expectedKey) {
    46  			t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader)
    47  		}
    48  
    49  		w.Header().Set("Content-Type", "application/json")
    50  		w.WriteHeader(statusCode)
    51  		if body != nil {
    52  			_ = json.NewEncoder(w).Encode(body)
    53  		}
    54  	}))
    55  }
    56  
    57  // mockManagementServer creates a mock management endpoint for testing management validator.
    58  func mockManagementServer(t *testing.T, expectedKey string, statusCode int, body any) *httptest.Server {
    59  	t.Helper()
    60  
    61  	// The managementEndpoint path in the validator is:
    62  	// /auth/teams/ffffffff-ffff-ffff-ffff-ffffffffffff/api-keys
    63  	expectedPath := "/auth/teams/ffffffff-ffff-ffff-ffff-ffffffffffff/api-keys"
    64  
    65  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    66  		// Expect a GET to the management path
    67  		if r.Method != http.MethodGet || r.URL.Path != expectedPath {
    68  			t.Errorf("unexpected request: %s %s, expected: GET %s", r.Method, r.URL.Path, expectedPath)
    69  			http.Error(w, "not found", http.StatusNotFound)
    70  			return
    71  		}
    72  
    73  		// Check Authorization header contains the key (ends with key)
    74  		authHeader := r.Header.Get("Authorization")
    75  		if !strings.HasSuffix(authHeader, expectedKey) {
    76  			t.Errorf("expected Authorization header to end with key %s, got: %s", expectedKey, authHeader)
    77  		}
    78  
    79  		w.Header().Set("Content-Type", "application/json")
    80  		w.WriteHeader(statusCode)
    81  		if body != nil {
    82  			_ = json.NewEncoder(w).Encode(body)
    83  		}
    84  	}))
    85  }
    86  
    87  func TestValidatorAPI(t *testing.T) {
    88  	cases := []struct {
    89  		name        string
    90  		statusCode  int
    91  		body        any
    92  		want        veles.ValidationStatus
    93  		expectError bool
    94  	}{
    95  		{
    96  			name:       "valid_key",
    97  			statusCode: http.StatusOK,
    98  			body: map[string]bool{
    99  				"api_key_blocked":  false,
   100  				"api_key_disabled": false,
   101  			},
   102  			want: veles.ValidationValid,
   103  		},
   104  		{
   105  			name:       "blocked_key",
   106  			statusCode: http.StatusOK,
   107  			body: map[string]bool{
   108  				"api_key_blocked":  true,
   109  				"api_key_disabled": false,
   110  			},
   111  			want: veles.ValidationInvalid,
   112  		},
   113  		{
   114  			name:       "disabled_key",
   115  			statusCode: http.StatusOK,
   116  			body: map[string]bool{
   117  				"api_key_blocked":  false,
   118  				"api_key_disabled": true,
   119  			},
   120  			want: veles.ValidationInvalid,
   121  		},
   122  		{
   123  			name:        "unauthorized_status",
   124  			statusCode:  http.StatusUnauthorized,
   125  			body:        nil,
   126  			want:        veles.ValidationFailed,
   127  			expectError: true,
   128  		},
   129  		{
   130  			name:        "server_error",
   131  			statusCode:  http.StatusInternalServerError,
   132  			body:        nil,
   133  			want:        veles.ValidationFailed,
   134  			expectError: true,
   135  		},
   136  	}
   137  
   138  	for _, tc := range cases {
   139  		t.Run(tc.name, func(t *testing.T) {
   140  			// Create mock server
   141  			server := mockAPIServer(t, validatorTestKey, tc.statusCode, tc.body)
   142  			defer server.Close()
   143  
   144  			// Create validator with mock client
   145  			validator := grokxaiapikey.NewAPIValidator()
   146  			validator.HTTPC = server.Client()
   147  			validator.Endpoint = server.URL + "/v1/api-key"
   148  
   149  			// Create test key
   150  			key := grokxaiapikey.GrokXAIAPIKey{Key: validatorTestKey}
   151  
   152  			// Test validation
   153  			got, err := validator.Validate(t.Context(), key)
   154  
   155  			// Check error expectation
   156  			if tc.expectError {
   157  				if err == nil {
   158  					t.Errorf("Validate() expected error, got nil")
   159  				}
   160  			} else {
   161  				if err != nil {
   162  					t.Errorf("Validate() unexpected error: %v", err)
   163  				}
   164  			}
   165  
   166  			// Check validation status
   167  			if got != tc.want {
   168  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func TestValidatorAPI_ContextCancellation(t *testing.T) {
   175  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   176  		w.WriteHeader(http.StatusOK)
   177  		_, _ = w.Write([]byte(`{"api_key_blocked": false, "api_key_disabled": false}`))
   178  	}))
   179  	defer server.Close()
   180  
   181  	validator := grokxaiapikey.NewAPIValidator()
   182  	validator.HTTPC = server.Client()
   183  	validator.Endpoint = server.URL + "/v1/api-key"
   184  
   185  	key := grokxaiapikey.GrokXAIAPIKey{Key: validatorTestKey}
   186  
   187  	// Create a cancelled context.
   188  	ctx, cancel := context.WithCancel(t.Context())
   189  	cancel()
   190  
   191  	// Test validation with cancelled context.
   192  	got, err := validator.Validate(ctx, key)
   193  
   194  	if err == nil {
   195  		t.Errorf("Validate() expected error due to context cancellation, got nil")
   196  	}
   197  	if got != veles.ValidationFailed {
   198  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   199  	}
   200  }
   201  
   202  func TestValidatorAPI_InvalidRequest(t *testing.T) {
   203  	// For API validator, an "invalid" key is communicated via the JSON flags.
   204  	// Create mock server that returns a 200 with api_key_blocked true for empty/invalid keys.
   205  	server := mockAPIServer(t, "", http.StatusOK, map[string]bool{
   206  		"api_key_blocked":  true,
   207  		"api_key_disabled": false,
   208  	})
   209  	defer server.Close()
   210  
   211  	validator := grokxaiapikey.NewAPIValidator()
   212  	validator.HTTPC = server.Client()
   213  	validator.Endpoint = server.URL + "/v1/api-key"
   214  
   215  	testCases := []struct {
   216  		name     string
   217  		key      string
   218  		expected veles.ValidationStatus
   219  	}{
   220  		{
   221  			name:     "empty_key",
   222  			key:      "",
   223  			expected: veles.ValidationInvalid,
   224  		},
   225  		{
   226  			name:     "invalid_key_format",
   227  			key:      "invalid-key-format",
   228  			expected: veles.ValidationInvalid,
   229  		},
   230  	}
   231  
   232  	for _, tc := range testCases {
   233  		t.Run(tc.name, func(t *testing.T) {
   234  			k := grokxaiapikey.GrokXAIAPIKey{Key: tc.key}
   235  
   236  			got, err := validator.Validate(t.Context(), k)
   237  
   238  			if err != nil {
   239  				t.Errorf("Validate() unexpected error for %s: %v", tc.name, err)
   240  			}
   241  			if got != tc.expected {
   242  				t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name)
   243  			}
   244  		})
   245  	}
   246  }
   247  
   248  func TestValidatorManagement(t *testing.T) {
   249  	cases := []struct {
   250  		name        string
   251  		statusCode  int
   252  		body        any
   253  		want        veles.ValidationStatus
   254  		expectError bool
   255  	}{
   256  		{
   257  			name:       "valid_key_status_ok",
   258  			statusCode: http.StatusOK,
   259  			body:       nil,
   260  			want:       veles.ValidationValid,
   261  		},
   262  		{
   263  			name:       "valid_key_team_mismatch",
   264  			statusCode: http.StatusForbidden,
   265  			body: map[string]any{
   266  				"code":    7,
   267  				"message": "team mismatch",
   268  			},
   269  			want: veles.ValidationValid,
   270  		},
   271  		{
   272  			name:       "invalid_key_unauthorized",
   273  			statusCode: http.StatusUnauthorized,
   274  			body:       nil,
   275  			want:       veles.ValidationInvalid,
   276  		},
   277  		{
   278  			name:       "forbidden_other_code",
   279  			statusCode: http.StatusForbidden,
   280  			body: map[string]any{
   281  				"code":    42,
   282  				"message": "other reason",
   283  			},
   284  			want: veles.ValidationInvalid,
   285  		},
   286  		{
   287  			name:        "server_error",
   288  			statusCode:  http.StatusInternalServerError,
   289  			body:        nil,
   290  			want:        veles.ValidationFailed,
   291  			expectError: true,
   292  		},
   293  		{
   294  			name:        "forbidden_bad_json",
   295  			statusCode:  http.StatusForbidden,
   296  			body:        "not-a-json", // this will be encoded as a string -> invalid JSON structure for decoding
   297  			expectError: true,
   298  			want:        veles.ValidationFailed,
   299  		},
   300  	}
   301  
   302  	for _, tc := range cases {
   303  		t.Run(tc.name, func(t *testing.T) {
   304  			// Create mock management server
   305  			server := mockManagementServer(t, validatorTestKey, tc.statusCode, tc.body)
   306  			defer server.Close()
   307  
   308  			// Create validator with mock client
   309  			validator := grokxaiapikey.NewManagementAPIValidator()
   310  			validator.HTTPC = server.Client()
   311  			validator.Endpoint = server.URL + "/auth/teams/ffffffff-ffff-ffff-ffff-ffffffffffff/api-keys"
   312  
   313  			// Create test key
   314  			key := grokxaiapikey.GrokXAIManagementKey{Key: validatorTestKey}
   315  
   316  			// Test validation
   317  			got, err := validator.Validate(t.Context(), key)
   318  
   319  			// Check error expectation
   320  			if tc.expectError {
   321  				if err == nil {
   322  					t.Errorf("Validate() expected error, got nil")
   323  				}
   324  			} else {
   325  				if err != nil {
   326  					t.Errorf("Validate() unexpected error: %v", err)
   327  				}
   328  			}
   329  
   330  			// Check validation status
   331  			if got != tc.want {
   332  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   333  			}
   334  		})
   335  	}
   336  }
   337  
   338  func TestValidatorManagement_ContextCancellation(t *testing.T) {
   339  	// Create a server that delays response
   340  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   341  		w.WriteHeader(http.StatusForbidden)
   342  		_, _ = w.Write([]byte(`{"code":7,"message":"team mismatch"}`))
   343  	}))
   344  	defer server.Close()
   345  
   346  	validator := grokxaiapikey.NewManagementAPIValidator()
   347  	validator.HTTPC = server.Client()
   348  	validator.Endpoint = server.URL + "/auth/teams/ffffffff-ffff-ffff-ffff-ffffffffffff/api-keys"
   349  
   350  	key := grokxaiapikey.GrokXAIManagementKey{Key: validatorTestKey}
   351  
   352  	// Create a cancelled context
   353  	ctx, cancel := context.WithCancel(t.Context())
   354  	cancel()
   355  
   356  	// Test validation with cancelled context
   357  	got, err := validator.Validate(ctx, key)
   358  
   359  	if err == nil {
   360  		t.Errorf("Validate() expected error due to context cancellation, got nil")
   361  	}
   362  	if got != veles.ValidationFailed {
   363  		t.Errorf("Validate() = %v, want %v", got, veles.ValidationFailed)
   364  	}
   365  }
   366  
   367  func TestValidatorManagement_InvalidRequest(t *testing.T) {
   368  	// For management validator, a 401 indicates invalid token (no error returned).
   369  	server := mockManagementServer(t, "", http.StatusUnauthorized, nil)
   370  	defer server.Close()
   371  
   372  	validator := grokxaiapikey.NewManagementAPIValidator()
   373  	validator.HTTPC = server.Client()
   374  	validator.Endpoint = server.URL + "/auth/teams/ffffffff-ffff-ffff-ffff-ffffffffffff/api-keys"
   375  
   376  	testCases := []struct {
   377  		name     string
   378  		key      string
   379  		expected veles.ValidationStatus
   380  	}{
   381  		{
   382  			name:     "empty_key",
   383  			key:      "",
   384  			expected: veles.ValidationInvalid,
   385  		},
   386  		{
   387  			name:     "invalid_key_format",
   388  			key:      "invalid-management-key",
   389  			expected: veles.ValidationInvalid,
   390  		},
   391  	}
   392  
   393  	for _, tc := range testCases {
   394  		t.Run(tc.name, func(t *testing.T) {
   395  			k := grokxaiapikey.GrokXAIManagementKey{Key: tc.key}
   396  
   397  			got, err := validator.Validate(t.Context(), k)
   398  
   399  			if err != nil {
   400  				t.Errorf("Validate() unexpected error for %s: %v", tc.name, err)
   401  			}
   402  			if got != tc.expected {
   403  				t.Errorf("Validate() = %v, want %v for %s", got, tc.expected, tc.name)
   404  			}
   405  		})
   406  	}
   407  }