github.com/google/osv-scalibr@v0.4.1/veles/secrets/common/simplevalidate/simplevalidate_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package simplevalidate_test
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"errors"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"testing"
    26  
    27  	"github.com/google/go-cmp/cmp"
    28  	"github.com/google/go-cmp/cmp/cmpopts"
    29  	"github.com/google/osv-scalibr/veles"
    30  	sv "github.com/google/osv-scalibr/veles/secrets/common/simplevalidate"
    31  	"github.com/google/osv-scalibr/veles/velestest"
    32  )
    33  
    34  type mockRoundTripper struct {
    35  	want           *http.Request
    36  	respStatusCode int
    37  	respBody       []byte
    38  	err            error
    39  	t              *testing.T
    40  }
    41  
    42  func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    43  	opts := []cmp.Option{
    44  		cmpopts.IgnoreUnexported(http.Request{}),
    45  		cmpopts.IgnoreFields(http.Request{}, "Proto", "ProtoMajor", "ProtoMinor", "GetBody"),
    46  	}
    47  	if diff := cmp.Diff(m.want, req, opts...); diff != "" {
    48  		m.t.Fatalf("Received unexpected request (-want +got):\n%s", diff)
    49  	}
    50  
    51  	return &http.Response{
    52  		StatusCode: m.respStatusCode,
    53  		Body:       io.NopCloser(bytes.NewReader(m.respBody)),
    54  	}, m.err
    55  }
    56  
    57  // mustParse is used for creating URLs in a "single-value" context.
    58  // It panics if the URL is invalid.
    59  func mustParse(s string) *url.URL {
    60  	u, err := url.Parse(s)
    61  	if err != nil {
    62  		panic("url.Parse: invalid URL string: " + s)
    63  	}
    64  	return u
    65  }
    66  
    67  func TestValidate(t *testing.T) {
    68  	testSecret := "TEST-SECRET"
    69  	testURLStr := "https://test"
    70  	testHost := "test"
    71  	testURL, err := url.Parse(testURLStr)
    72  	if err != nil {
    73  		t.Fatalf("url.Parse(%q): %v", testURLStr, err)
    74  	}
    75  
    76  	tests := []struct {
    77  		desc         string
    78  		validator    *sv.Validator[velestest.FakeStringSecret]
    79  		secret       string
    80  		roundTripper *mockRoundTripper
    81  		want         veles.ValidationStatus
    82  		wantErr      error
    83  	}{
    84  		{
    85  			desc: "valid_response",
    86  			validator: &sv.Validator[velestest.FakeStringSecret]{
    87  				Endpoint:   testURLStr,
    88  				HTTPMethod: http.MethodGet,
    89  				HTTPHeaders: func(s velestest.FakeStringSecret) map[string]string {
    90  					return map[string]string{"Authorization": "Bearer " + s.Value}
    91  				},
    92  				ValidResponseCodes: []int{http.StatusOK},
    93  			},
    94  			secret: testSecret,
    95  			roundTripper: &mockRoundTripper{
    96  				want: &http.Request{
    97  					Method: http.MethodGet,
    98  					URL:    testURL,
    99  					Host:   testHost,
   100  					Header: http.Header{"Authorization": []string{"Bearer " + testSecret}},
   101  				},
   102  				respStatusCode: http.StatusOK,
   103  				t:              t,
   104  			},
   105  			want: veles.ValidationValid,
   106  		},
   107  		{
   108  			desc: "invalid_response",
   109  			validator: &sv.Validator[velestest.FakeStringSecret]{
   110  				Endpoint:   testURLStr,
   111  				HTTPMethod: http.MethodGet,
   112  				HTTPHeaders: func(s velestest.FakeStringSecret) map[string]string {
   113  					return map[string]string{"Authorization": "Bearer " + s.Value}
   114  				},
   115  				ValidResponseCodes:   []int{http.StatusOK},
   116  				InvalidResponseCodes: []int{http.StatusUnauthorized},
   117  			},
   118  			secret: testSecret,
   119  			roundTripper: &mockRoundTripper{
   120  				want: &http.Request{
   121  					Method: http.MethodGet,
   122  					URL:    testURL,
   123  					Host:   testHost,
   124  					Header: http.Header{"Authorization": []string{"Bearer " + testSecret}},
   125  				},
   126  				respStatusCode: http.StatusUnauthorized,
   127  				t:              t,
   128  			},
   129  			want: veles.ValidationInvalid,
   130  		},
   131  		{
   132  			desc: "failed_response",
   133  			validator: &sv.Validator[velestest.FakeStringSecret]{
   134  				Endpoint:   testURLStr,
   135  				HTTPMethod: http.MethodGet,
   136  				HTTPHeaders: func(s velestest.FakeStringSecret) map[string]string {
   137  					return map[string]string{"Authorization": "Bearer " + s.Value}
   138  				},
   139  				ValidResponseCodes:   []int{http.StatusOK},
   140  				InvalidResponseCodes: []int{http.StatusUnauthorized},
   141  			},
   142  			secret: testSecret,
   143  			roundTripper: &mockRoundTripper{
   144  				want: &http.Request{
   145  					Method: http.MethodGet,
   146  					URL:    testURL,
   147  					Host:   testHost,
   148  					Header: http.Header{"Authorization": []string{"Bearer " + testSecret}},
   149  				},
   150  				respStatusCode: http.StatusInternalServerError,
   151  				t:              t,
   152  			},
   153  			want:    veles.ValidationFailed,
   154  			wantErr: cmpopts.AnyError,
   155  		},
   156  		{
   157  			desc: "custom_body_parsing_valid_response",
   158  			validator: &sv.Validator[velestest.FakeStringSecret]{
   159  				Endpoint: testURLStr,
   160  				HTTPHeaders: func(s velestest.FakeStringSecret) map[string]string {
   161  					return map[string]string{"Authorization": "Bearer " + s.Value}
   162  				},
   163  				StatusFromResponseBody: func(body io.Reader) (veles.ValidationStatus, error) {
   164  					content, err := io.ReadAll(body)
   165  					if err != nil {
   166  						return veles.ValidationFailed, err
   167  					}
   168  					if string(content) == "valid_secret" {
   169  						return veles.ValidationValid, nil
   170  					}
   171  					return veles.ValidationInvalid, nil
   172  				},
   173  			},
   174  			secret: testSecret,
   175  			roundTripper: &mockRoundTripper{
   176  				want: &http.Request{
   177  					Method: http.MethodGet,
   178  					URL:    testURL,
   179  					Host:   testHost,
   180  					Header: http.Header{"Authorization": []string{"Bearer " + testSecret}},
   181  				},
   182  				respStatusCode: http.StatusOK,
   183  				respBody:       []byte("valid_secret"),
   184  				t:              t,
   185  			},
   186  			want: veles.ValidationValid,
   187  		},
   188  		{
   189  			desc: "custom_body_parsing_invalid_response",
   190  			validator: &sv.Validator[velestest.FakeStringSecret]{
   191  				Endpoint: testURLStr,
   192  				HTTPHeaders: func(s velestest.FakeStringSecret) map[string]string {
   193  					return map[string]string{"Authorization": "Bearer " + s.Value}
   194  				},
   195  				StatusFromResponseBody: func(body io.Reader) (veles.ValidationStatus, error) {
   196  					content, err := io.ReadAll(body)
   197  					if err != nil {
   198  						return veles.ValidationFailed, err
   199  					}
   200  					if string(content) == "valid_secret" {
   201  						return veles.ValidationValid, nil
   202  					}
   203  					return veles.ValidationInvalid, nil
   204  				},
   205  			},
   206  			secret: testSecret,
   207  			roundTripper: &mockRoundTripper{
   208  				want: &http.Request{
   209  					Method: http.MethodGet,
   210  					URL:    testURL,
   211  					Host:   testHost,
   212  					Header: http.Header{"Authorization": []string{"Bearer " + testSecret}},
   213  				},
   214  				respStatusCode: http.StatusOK,
   215  				respBody:       []byte("not_a_valid_secret"),
   216  				t:              t,
   217  			},
   218  			want: veles.ValidationInvalid,
   219  		},
   220  		{
   221  			desc: "valid_response_with_endpointfunc",
   222  			validator: &sv.Validator[velestest.FakeStringSecret]{
   223  				EndpointFunc: func(s velestest.FakeStringSecret) (string, error) {
   224  					return testURLStr + "?token=" + s.Value, nil
   225  				},
   226  				HTTPMethod:         http.MethodGet,
   227  				ValidResponseCodes: []int{http.StatusOK},
   228  			},
   229  			secret: testSecret,
   230  			roundTripper: &mockRoundTripper{
   231  				want: &http.Request{
   232  					Method: http.MethodGet,
   233  					URL:    mustParse(testURLStr + "?token=" + testSecret),
   234  					Host:   testHost,
   235  					Header: http.Header{},
   236  				},
   237  				respStatusCode: http.StatusOK,
   238  				t:              t,
   239  			},
   240  			want: veles.ValidationValid,
   241  		},
   242  		{
   243  			desc: "endpoint_and_endpointfunc_provided",
   244  			validator: &sv.Validator[velestest.FakeStringSecret]{
   245  				Endpoint: testURLStr,
   246  				EndpointFunc: func(s velestest.FakeStringSecret) (string, error) {
   247  					return testURLStr, nil
   248  				},
   249  				HTTPMethod: http.MethodGet,
   250  			},
   251  			secret: testSecret,
   252  			roundTripper: &mockRoundTripper{
   253  				t: t,
   254  			},
   255  			want:    veles.ValidationFailed,
   256  			wantErr: cmpopts.AnyError,
   257  		},
   258  		{
   259  			desc: "no_endpoint_or_endpointfunc_provided",
   260  			validator: &sv.Validator[velestest.FakeStringSecret]{
   261  				HTTPMethod: http.MethodGet,
   262  			},
   263  			secret: testSecret,
   264  			roundTripper: &mockRoundTripper{
   265  				t: t,
   266  			},
   267  			want:    veles.ValidationFailed,
   268  			wantErr: cmpopts.AnyError,
   269  		},
   270  		{
   271  			desc: "body_func_returns_error",
   272  			validator: &sv.Validator[velestest.FakeStringSecret]{
   273  				Endpoint:   testURLStr,
   274  				HTTPMethod: http.MethodPost,
   275  				Body: func(s velestest.FakeStringSecret) (string, error) {
   276  					return "", errors.New("body construction failed")
   277  				},
   278  			},
   279  			secret: testSecret,
   280  			roundTripper: &mockRoundTripper{
   281  				t: t,
   282  			},
   283  			want:    veles.ValidationFailed,
   284  			wantErr: cmpopts.AnyError,
   285  		},
   286  		{
   287  			desc: "endpointfunc_returns_error",
   288  			validator: &sv.Validator[velestest.FakeStringSecret]{
   289  				EndpointFunc: func(s velestest.FakeStringSecret) (string, error) {
   290  					return "", errors.New("endpoint construction failed")
   291  				},
   292  				HTTPMethod: http.MethodGet,
   293  			},
   294  			secret: testSecret,
   295  			roundTripper: &mockRoundTripper{
   296  				t: t,
   297  			},
   298  			want:    veles.ValidationFailed,
   299  			wantErr: cmpopts.AnyError,
   300  		},
   301  	}
   302  
   303  	for _, tc := range tests {
   304  		t.Run(tc.desc, func(t *testing.T) {
   305  			tc.validator.HTTPC = &http.Client{Transport: tc.roundTripper}
   306  
   307  			secret := velestest.FakeStringSecret{Value: tc.secret}
   308  			got, err := tc.validator.Validate(t.Context(), secret)
   309  			if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
   310  				t.Fatalf("Validate() error: got %v, want %v\n", err, tc.wantErr)
   311  			}
   312  
   313  			if got != tc.want {
   314  				t.Errorf("Validate() = %q, want %q", got, tc.want)
   315  			}
   316  		})
   317  	}
   318  }
   319  
   320  func TestValidate_respectsContext(t *testing.T) {
   321  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   322  		w.WriteHeader(http.StatusOK)
   323  	}))
   324  	defer srv.Close()
   325  
   326  	validator := &sv.Validator[velestest.FakeStringSecret]{
   327  		HTTPC:      srv.Client(),
   328  		Endpoint:   "https://test",
   329  		HTTPMethod: http.MethodGet,
   330  	}
   331  	ctx, cancel := context.WithCancel(t.Context())
   332  	cancel()
   333  
   334  	secret := velestest.FakeStringSecret{Value: "abcd"}
   335  	if _, err := validator.Validate(ctx, secret); !errors.Is(err, context.Canceled) {
   336  		t.Errorf("Validate() error: %v, want context.Canceled", err)
   337  	}
   338  }