github.com/google/osv-scalibr@v0.4.1/veles/secrets/hcp/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hcp_test
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"net/url"
    24  	"strings"
    25  	"testing"
    26  
    27  	"github.com/google/go-cmp/cmp"
    28  	"github.com/google/go-cmp/cmp/cmpopts"
    29  	"github.com/google/osv-scalibr/veles"
    30  	"github.com/google/osv-scalibr/veles/secrets/hcp"
    31  )
    32  
    33  const (
    34  	validatorTestClientID     = "53au9oDSqR8SBzIy6QJASHnyC1SMQxE3"
    35  	validatorTestClientSecret = "x2Nyv_C0NiJLEheDO5LuAmJj7v_SrY5cpWWCi38WCcmohTFzAl48zoiEFivQBU2n"
    36  	validatorTestAccessToken  = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2F1dGguaWRwLmhhc2hpY29ycC5jb20vIiwiYXVkIjpbImh0dHBzOi8vYXBpLmhhc2hpY29ycC5jbG91ZCJdLCJndHkiOiJjbGllbnQtY3JlZGVudGlhbHMiLCJodHRwczovL2Nsb3VkLmhhc2hpY29ycC5jb20vcHJpbmNpcGFsLXR5cGUiOiJzZXJ2aWNlIn0.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
    37  )
    38  
    39  // mockTokenServer returns a server that emulates the HCP token endpoint behavior.
    40  func mockTokenServer(t *testing.T, expectID, expectSecret string, success bool) *httptest.Server {
    41  	t.Helper()
    42  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    43  		if r.Method != http.MethodPost || r.URL.Path != "/oauth2/token" {
    44  			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
    45  			http.Error(w, "not found", http.StatusNotFound)
    46  			return
    47  		}
    48  		if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
    49  			t.Errorf("unexpected content-type: %q", ct)
    50  			http.Error(w, "bad request", http.StatusBadRequest)
    51  			return
    52  		}
    53  
    54  		body, err := io.ReadAll(r.Body)
    55  		if err != nil {
    56  			t.Fatalf("unable to read request body: %v", err)
    57  		}
    58  		_ = r.Body.Close()
    59  		vals, err := url.ParseQuery(string(body))
    60  		if err != nil {
    61  			t.Fatalf("unable to parse form: %v", err)
    62  		}
    63  		if vals.Get("grant_type") != "client_credentials" {
    64  			t.Errorf("unexpected grant_type: %q", vals.Get("grant_type"))
    65  		}
    66  
    67  		if success && (vals.Get("client_id") != expectID || vals.Get("client_secret") != expectSecret) {
    68  			w.WriteHeader(http.StatusUnauthorized)
    69  			_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
    70  			return
    71  		}
    72  		if !success {
    73  			w.WriteHeader(http.StatusUnauthorized)
    74  			_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
    75  			return
    76  		}
    77  
    78  		w.Header().Set("Content-Type", "application/json")
    79  		w.WriteHeader(http.StatusOK)
    80  		_ = json.NewEncoder(w).Encode(map[string]any{
    81  			"access_token": validatorTestAccessToken,
    82  			"token_type":   "Bearer",
    83  			"expires_in":   3600,
    84  		})
    85  	}))
    86  }
    87  
    88  func TestClientCredentialsValidator(t *testing.T) {
    89  	cases := []struct {
    90  		name   string
    91  		id     string
    92  		secret string
    93  		ok     bool
    94  		want   veles.ValidationStatus
    95  	}{
    96  		{name: "valid_pair", id: validatorTestClientID, secret: validatorTestClientSecret, ok: true, want: veles.ValidationValid},
    97  		{name: "invalid_pair", id: validatorTestClientID, secret: "wrong_secret", ok: false, want: veles.ValidationInvalid},
    98  		{name: "missing_id", id: "", secret: validatorTestClientSecret, ok: true, want: veles.ValidationInvalid},
    99  		{name: "missing_secret", id: validatorTestClientID, secret: "", ok: true, want: veles.ValidationInvalid},
   100  	}
   101  
   102  	for _, tc := range cases {
   103  		t.Run(tc.name, func(t *testing.T) {
   104  			srv := mockTokenServer(t, validatorTestClientID, validatorTestClientSecret, tc.ok)
   105  			defer srv.Close()
   106  
   107  			v := hcp.NewClientCredentialsValidator(hcp.WithTokenURL(srv.URL + "/oauth2/token"))
   108  
   109  			got, err := v.Validate(context.Background(), hcp.ClientCredentials{ClientID: tc.id, ClientSecret: tc.secret})
   110  			if err != nil && (tc.want == veles.ValidationValid || tc.want == veles.ValidationInvalid || tc.want == veles.ValidationUnsupported) {
   111  				t.Fatalf("Validate() unexpected error: %v", err)
   112  			}
   113  			if got != tc.want {
   114  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   115  			}
   116  		})
   117  	}
   118  }
   119  
   120  func TestClientCredentialsValidator_Errors(t *testing.T) {
   121  	t.Run("connection error returns failed", func(t *testing.T) {
   122  		// Start and immediately close to simulate server down
   123  		srv := mockTokenServer(t, validatorTestClientID, validatorTestClientSecret, true)
   124  		base := srv.URL
   125  		srv.Close()
   126  
   127  		v := hcp.NewClientCredentialsValidator(hcp.WithTokenURL(base + "/oauth2/token"))
   128  
   129  		got, err := v.Validate(context.Background(), hcp.ClientCredentials{ClientID: validatorTestClientID, ClientSecret: validatorTestClientSecret})
   130  		if err == nil {
   131  			t.Fatalf("expected error due to connection failure, got nil")
   132  		}
   133  		if got != veles.ValidationFailed {
   134  			t.Errorf("Status = %v, want %v", got, veles.ValidationFailed)
   135  		}
   136  	})
   137  
   138  	t.Run("server down returns failed", func(t *testing.T) {
   139  		// Create a server that always returns 500 on /oauth2/token
   140  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   141  			if r.Method != http.MethodPost || r.URL.Path != "/oauth2/token" {
   142  				http.Error(w, "not found", http.StatusNotFound)
   143  				return
   144  			}
   145  			w.WriteHeader(http.StatusInternalServerError)
   146  			_, _ = w.Write([]byte(`{"error":"internal"}`))
   147  		}))
   148  		defer srv.Close()
   149  
   150  		v := hcp.NewClientCredentialsValidator(hcp.WithTokenURL(srv.URL + "/oauth2/token"))
   151  
   152  		got, err := v.Validate(context.Background(), hcp.ClientCredentials{ClientID: validatorTestClientID, ClientSecret: validatorTestClientSecret})
   153  		if err == nil {
   154  			t.Fatalf("expected error for 500 response, got nil")
   155  		}
   156  		if got != veles.ValidationFailed {
   157  			t.Errorf("Status = %v, want %v", got, veles.ValidationFailed)
   158  		}
   159  	})
   160  }
   161  
   162  // mockAPIBaseServer returns a server that emulates a minimal HCP API base for token validation.
   163  func mockAPIBaseServer(t *testing.T, status int) *httptest.Server {
   164  	t.Helper()
   165  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   166  		if r.Method != http.MethodGet || r.URL.Path != "/iam/2019-12-10/caller-identity" {
   167  			t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
   168  			http.Error(w, "not found", http.StatusNotFound)
   169  			return
   170  		}
   171  		if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
   172  			http.Error(w, "unauthorized", http.StatusUnauthorized)
   173  			return
   174  		}
   175  		w.WriteHeader(status)
   176  	}))
   177  }
   178  
   179  func TestAccessTokenValidator(t *testing.T) {
   180  	cases := []struct {
   181  		name  string
   182  		httpS int
   183  		want  veles.ValidationStatus
   184  	}{
   185  		{name: "ok_200", httpS: http.StatusOK, want: veles.ValidationValid},
   186  		{name: "unauthorized_401", httpS: http.StatusUnauthorized, want: veles.ValidationInvalid},
   187  	}
   188  	for _, tc := range cases {
   189  		t.Run(tc.name, func(t *testing.T) {
   190  			srv := mockAPIBaseServer(t, tc.httpS)
   191  			defer srv.Close()
   192  
   193  			v := hcp.NewAccessTokenValidator(hcp.WithAPIBase(srv.URL))
   194  
   195  			got, err := v.Validate(context.Background(), hcp.AccessToken{Token: validatorTestAccessToken})
   196  			if !cmp.Equal(err, nil, cmpopts.EquateErrors()) {
   197  				t.Fatalf("Validate() unexpected error: %v", err)
   198  			}
   199  			if got != tc.want {
   200  				t.Errorf("Validate() = %v, want %v", got, tc.want)
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  func TestAccessTokenValidator_Errors(t *testing.T) {
   207  	t.Run("server down returns failed", func(t *testing.T) {
   208  		// Start and immediately close to simulate server down
   209  		srv := mockAPIBaseServer(t, http.StatusOK)
   210  		base := srv.URL
   211  		srv.Close()
   212  
   213  		v := hcp.NewAccessTokenValidator(hcp.WithAPIBase(base))
   214  		got, err := v.Validate(context.Background(), hcp.AccessToken{Token: validatorTestAccessToken})
   215  		if err == nil {
   216  			t.Fatalf("expected error due to connection failure, got nil")
   217  		}
   218  		if got != veles.ValidationFailed {
   219  			t.Errorf("Status = %v, want %v", got, veles.ValidationFailed)
   220  		}
   221  	})
   222  
   223  	t.Run("server error returns failed", func(t *testing.T) {
   224  		srv := mockAPIBaseServer(t, http.StatusInternalServerError)
   225  		defer srv.Close()
   226  
   227  		v := hcp.NewAccessTokenValidator(hcp.WithAPIBase(srv.URL))
   228  		got, err := v.Validate(context.Background(), hcp.AccessToken{Token: validatorTestAccessToken})
   229  		if err == nil {
   230  			t.Fatalf("expected error for 500 response, got nil")
   231  		}
   232  		if got != veles.ValidationFailed {
   233  			t.Errorf("Status = %v, want %v", got, veles.ValidationFailed)
   234  		}
   235  	})
   236  }