github.com/google/osv-scalibr@v0.4.1/veles/secrets/gcpoauth2access/validator_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gcpoauth2access_test
    16  
    17  import (
    18  	"encoding/json"
    19  	"errors"
    20  	"io"
    21  	"net/http"
    22  	"net/url"
    23  	"strconv"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/google/go-cmp/cmp"
    29  	"github.com/google/go-cmp/cmp/cmpopts"
    30  	"github.com/google/osv-scalibr/veles"
    31  	"github.com/google/osv-scalibr/veles/secrets/gcpoauth2access"
    32  )
    33  
    34  const (
    35  	endpoint = "https://www.googleapis.com/oauth2/v3/tokeninfo"
    36  )
    37  
    38  type mockRoundTripper struct {
    39  	want *http.Request
    40  	resp *http.Response
    41  	err  error
    42  	t    *testing.T
    43  }
    44  
    45  func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    46  	// Set empty header and host for convenience.
    47  	// We care if these fields are drastically different, but not if they are nil vs empty.
    48  	if m.want.Header == nil {
    49  		m.want.Header = make(http.Header)
    50  	}
    51  	if m.want.Host == "" {
    52  		m.want.Host = req.Host
    53  	}
    54  	opts := []cmp.Option{
    55  		cmpopts.IgnoreUnexported(http.Request{}),
    56  		cmpopts.IgnoreFields(http.Request{}, "Proto", "ProtoMajor", "ProtoMinor"),
    57  	}
    58  	if diff := cmp.Diff(m.want, req, opts...); diff != "" {
    59  		m.t.Fatalf("Received unexpected request (-want +got):\n%s", diff)
    60  	}
    61  	return m.resp, m.err
    62  }
    63  
    64  // response represents the response from Google's OAuth2 token endpoint.
    65  // https://developers.google.com/identity/protocols/oauth2
    66  type response struct {
    67  	// Expiry is the expiration time of the token in Unix time.
    68  	Expiry string `json:"exp"`
    69  	// ExpiresIn is the number of seconds until the token expires.
    70  	ExpiresIn string `json:"expires_in"`
    71  	// Scope is a space-delimited list that identify the resources that your application could access
    72  	// https://developers.google.com/identity/protocols/oauth2/scopes
    73  	Scope string `json:"scope"`
    74  }
    75  
    76  func TestValidator_Validate(t *testing.T) {
    77  	realTokenURL := mustURLWithParams(t, endpoint, map[string]string{"access_token": realToken})
    78  
    79  	tests := []struct {
    80  		name         string
    81  		roundTripper *mockRoundTripper
    82  		token        gcpoauth2access.Token
    83  		want         veles.ValidationStatus
    84  		wantErr      bool
    85  	}{
    86  		{
    87  			name: "empty",
    88  			token: gcpoauth2access.Token{
    89  				Token: "",
    90  			},
    91  			want:    veles.ValidationFailed,
    92  			wantErr: true,
    93  		},
    94  		{
    95  			name: "request_error",
    96  			roundTripper: &mockRoundTripper{
    97  				want: &http.Request{
    98  					Method: http.MethodGet,
    99  					URL:    realTokenURL,
   100  				},
   101  				err: errors.New("request error"),
   102  			},
   103  			token: gcpoauth2access.Token{
   104  				Token: realToken,
   105  			},
   106  			want:    veles.ValidationFailed,
   107  			wantErr: true,
   108  		},
   109  		{
   110  			name: "bad_request",
   111  			roundTripper: &mockRoundTripper{
   112  				want: &http.Request{
   113  					Method: http.MethodGet,
   114  					URL:    realTokenURL,
   115  				},
   116  				resp: &http.Response{
   117  					StatusCode: http.StatusBadRequest,
   118  				},
   119  			},
   120  			token: gcpoauth2access.Token{
   121  				Token: realToken,
   122  			},
   123  			want: veles.ValidationInvalid,
   124  		},
   125  		{
   126  			name: "server_error",
   127  			roundTripper: &mockRoundTripper{
   128  				want: &http.Request{
   129  					Method: http.MethodGet,
   130  					URL:    realTokenURL,
   131  				},
   132  				resp: &http.Response{
   133  					StatusCode: http.StatusInternalServerError,
   134  				},
   135  			},
   136  			token: gcpoauth2access.Token{
   137  				Token: realToken,
   138  			},
   139  			want:    veles.ValidationFailed,
   140  			wantErr: true,
   141  		},
   142  		{
   143  			name: "unexpected_json",
   144  			roundTripper: &mockRoundTripper{
   145  				want: &http.Request{
   146  					Method: http.MethodGet,
   147  					URL:    realTokenURL,
   148  				},
   149  				resp: &http.Response{
   150  					StatusCode: http.StatusOK,
   151  					Body:       io.NopCloser(strings.NewReader("unexpected json")),
   152  				},
   153  			},
   154  			token: gcpoauth2access.Token{
   155  				Token: realToken,
   156  			},
   157  			want:    veles.ValidationFailed,
   158  			wantErr: true,
   159  		},
   160  		{
   161  			name: "valid_token",
   162  			roundTripper: &mockRoundTripper{
   163  				want: &http.Request{
   164  					Method: http.MethodGet,
   165  					URL:    realTokenURL,
   166  				},
   167  				resp: &http.Response{
   168  					StatusCode: http.StatusOK,
   169  					Body: mustJSONReadCloser(t, response{
   170  						Expiry:    "1743465600",
   171  						ExpiresIn: "3600",
   172  						Scope:     "https://www.googleapis.com/auth/cloud-platform",
   173  					}),
   174  				},
   175  			},
   176  			token: gcpoauth2access.Token{
   177  				Token: realToken,
   178  			},
   179  			want: veles.ValidationValid,
   180  		},
   181  		{
   182  			name: "expired_based_on_expires_in",
   183  			roundTripper: &mockRoundTripper{
   184  				want: &http.Request{
   185  					Method: http.MethodGet,
   186  					URL:    realTokenURL,
   187  				},
   188  				resp: &http.Response{
   189  					StatusCode: http.StatusOK,
   190  					Body: mustJSONReadCloser(t, response{
   191  						Expiry:    strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10),
   192  						ExpiresIn: "0",
   193  						Scope:     "https://www.googleapis.com/auth/cloud-platform",
   194  					}),
   195  				},
   196  			},
   197  			token: gcpoauth2access.Token{
   198  				Token: realToken,
   199  			},
   200  			want: veles.ValidationInvalid,
   201  		},
   202  		{
   203  			name: "expired_based_on_expiry",
   204  			roundTripper: &mockRoundTripper{
   205  				want: &http.Request{
   206  					Method: http.MethodGet,
   207  					URL:    realTokenURL,
   208  				},
   209  				resp: &http.Response{
   210  					StatusCode: http.StatusOK,
   211  					Body: mustJSONReadCloser(t, response{
   212  						Expiry:    strconv.FormatInt(time.Now().Add(-time.Hour).Unix(), 10),
   213  						ExpiresIn: "unparsable",
   214  						Scope:     "https://www.googleapis.com/auth/cloud-platform",
   215  					}),
   216  				},
   217  			},
   218  			token: gcpoauth2access.Token{
   219  				Token: realToken,
   220  			},
   221  			want: veles.ValidationInvalid,
   222  		},
   223  	}
   224  
   225  	for _, tc := range tests {
   226  		t.Run(tc.name, func(t *testing.T) {
   227  			if tc.roundTripper != nil {
   228  				tc.roundTripper.t = t
   229  			}
   230  			v := gcpoauth2access.NewValidator()
   231  			v.HTTPC = &http.Client{Transport: tc.roundTripper}
   232  
   233  			got, err := v.Validate(t.Context(), tc.token)
   234  			if tc.wantErr {
   235  				if err == nil {
   236  					t.Errorf("Validate() error: %v, want error: %t", err, tc.wantErr)
   237  				}
   238  			} else {
   239  				if err != nil {
   240  					t.Errorf("Validate() error: %v, want nil", err)
   241  				}
   242  			}
   243  			if got != tc.want {
   244  				t.Errorf("Validate() = %q, want %q", got, tc.want)
   245  			}
   246  		})
   247  	}
   248  }
   249  
   250  func mustURLWithParams(t *testing.T, endpoint string, params map[string]string) *url.URL {
   251  	t.Helper()
   252  	endpointURL, err := url.Parse(endpoint)
   253  	if err != nil {
   254  		t.Fatalf("Failed to parse endpoint: %v", err)
   255  	}
   256  
   257  	paramsURL := url.Values{}
   258  	for k, v := range params {
   259  		paramsURL.Set(k, v)
   260  	}
   261  	endpointURL.RawQuery = paramsURL.Encode()
   262  	return endpointURL
   263  }
   264  
   265  // mustJSONReadCloser marshals a struct into JSON, converts it to an io.Reader,
   266  // and wraps it in an io.ReadCloser, failing the test if marshaling fails.
   267  func mustJSONReadCloser(t *testing.T, data any) io.ReadCloser {
   268  	t.Helper()
   269  	b, err := json.Marshal(data)
   270  	if err != nil {
   271  		t.Fatalf("Failed to marshal struct to JSON: %v", err)
   272  	}
   273  	return io.NopCloser(strings.NewReader(string(b)))
   274  }