github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/actas/actas_test.go (about)

     1  package actas
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  	"time"
    11  
    12  	log "github.com/golang/glog"
    13  	"github.com/google/go-cmp/cmp"
    14  	"github.com/google/go-cmp/cmp/cmpopts"
    15  	"golang.org/x/oauth2"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/credentials"
    18  	"google.golang.org/grpc/status"
    19  )
    20  
    21  const (
    22  	account = "fake-account@fake-consumer.iam.gserviceaccount.com"
    23  	scope   = "https://www.googleapis.com/auth/cloud-platform"
    24  )
    25  
    26  // defaultHandler is a default HTTP request handler that returns a nil body.
    27  func defaultHandler(req *http.Request) (interface{}, error) {
    28  	return nil, nil
    29  }
    30  
    31  // fakeHTTP provides an HTTP server routing calls to its handler and a client connected to it.
    32  type fakeHTTP struct {
    33  	// Handler is handler for HTTP requests.
    34  	Handler func(req *http.Request) (body interface{}, err error)
    35  	// Server is the HTTP server. Uses handler for handling requests.
    36  	Server *httptest.Server
    37  	// Client is the HTTP client connected to the HTTP server.
    38  	Client *http.Client
    39  }
    40  
    41  // newFakeHTTP creates a new HTTP server and client.
    42  func newFakeHTTP() (*fakeHTTP, func() error) {
    43  	f := &fakeHTTP{}
    44  	f.Handler = defaultHandler
    45  	h := func(w http.ResponseWriter, req *http.Request) {
    46  		log.Infof("HTTP Request: %+v", req)
    47  		body, err := f.Handler(req)
    48  		if err != nil {
    49  			// This is not strictly the correct HTTP status (different gRPC codes should lead to different
    50  			// HTTP statuses), but for the purposes of this test it doesn't matter.
    51  			http.Error(w, err.Error(), http.StatusInternalServerError)
    52  			return
    53  		}
    54  		if err := json.NewEncoder(w).Encode(body); err != nil {
    55  			log.Errorf("json.NewEncoder(%v).Encode(%v) failed: %v", w, body, err)
    56  			http.Error(w, "encoding the response failed", http.StatusInternalServerError)
    57  			return
    58  		}
    59  	}
    60  
    61  	f.Server = httptest.NewServer(http.HandlerFunc(h))
    62  	f.Client = f.Server.Client()
    63  	cleanup := func() error {
    64  		f.Server.Close()
    65  		return nil
    66  	}
    67  	return f, cleanup
    68  }
    69  
    70  type stubDefaultCredentials struct {
    71  	credentials.PerRPCCredentials
    72  	err error
    73  }
    74  
    75  func (s *stubDefaultCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
    76  	return map[string]string{"key": "value"}, s.err
    77  }
    78  
    79  func TestTokenSource_Token(t *testing.T) {
    80  	ctx := context.Background()
    81  
    82  	h, cleanup := newFakeHTTP()
    83  	defer cleanup()
    84  
    85  	// Override URLs for test.
    86  	newSignJWTURL = func(string) string {
    87  		return h.Server.URL + "/sign"
    88  	}
    89  	audienceURL = h.Server.URL + "/token"
    90  	// duration of token
    91  	duration := 10 * time.Minute
    92  	token := 0
    93  	// Populate the response in the http request handler based on the request URL.
    94  	h.Handler = func(req *http.Request) (body interface{}, err error) {
    95  		log.Infof("HTTP Request: %+v", req)
    96  		if req.URL.Path == "/sign" {
    97  			return &signaturePayload{
    98  				KeyID:     "fake-key-id",
    99  				SignedJwt: "fake-signed-jwt",
   100  			}, nil
   101  		}
   102  		if req.URL.Path == "/token" {
   103  			// Each time a call is made for obtaining a token we return a new one.
   104  			// This is to check that calls are made only when needed.
   105  			token++
   106  			return &tokenPayload{
   107  				AccessToken: fmt.Sprintf("fake-access-token-%v", token),
   108  				TokenType:   "fake-token-type",
   109  				ExpiresIn:   int64(duration.Seconds()),
   110  			}, nil
   111  		}
   112  		return nil, nil
   113  	}
   114  	d := &stubDefaultCredentials{}
   115  
   116  	before := time.Now()
   117  	s := NewTokenSource(ctx, d, h.Client, account, []string{scope})
   118  
   119  	// First token.
   120  	got, err := s.Token()
   121  	if err != nil {
   122  		t.Fatalf("Token() failed: %v", err)
   123  	}
   124  	want := &oauth2.Token{
   125  		AccessToken: "fake-access-token-1",
   126  		TokenType:   "fake-token-type",
   127  	}
   128  	opts := cmp.Options{cmpopts.IgnoreUnexported(oauth2.Token{}), cmpopts.IgnoreFields(oauth2.Token{}, "Expiry")}
   129  	if diff := cmp.Diff(want, got, opts); diff != "" {
   130  		t.Errorf("Token() returned diff:\n%s\n", diff)
   131  	}
   132  
   133  	// Check if the expiry is in expected range (want-2,want+2)
   134  	w := 2 * time.Minute
   135  	wantExp := before.Add(duration)
   136  	if got.Expiry.Before(wantExp.Add(-w)) || got.Expiry.After(wantExp.Add(w)) {
   137  		t.Errorf("Token().Expiry = %+v, want in [%v,%v])", got.Expiry, wantExp.Add(-w), wantExp.Add(w))
   138  	}
   139  
   140  	// Second token.
   141  	// Expiry is after now, but actas.TokenSource should not do its own caching, so we expect a new
   142  	// token.
   143  	want.AccessToken = "fake-access-token-2"
   144  	got, err = s.Token()
   145  	if err != nil {
   146  		t.Fatalf("Token() failed: %v", err)
   147  	}
   148  	if diff := cmp.Diff(want, got, opts); diff != "" {
   149  		t.Errorf("Token() returned diff:\n%s\n", diff)
   150  	}
   151  }
   152  
   153  func TestNewTokenSource_CredGetRequestMetadataFails(t *testing.T) {
   154  	ctx := context.Background()
   155  
   156  	h, cleanup := newFakeHTTP()
   157  	defer cleanup()
   158  
   159  	// Override URLs for test.
   160  	newSignJWTURL = func(string) string {
   161  		return h.Server.URL + "/sign"
   162  	}
   163  	audienceURL = h.Server.URL + "/token"
   164  
   165  	d := &stubDefaultCredentials{err: status.Error(codes.Unknown, "some error")}
   166  
   167  	s := NewTokenSource(ctx, d, h.Client, account, []string{scope})
   168  	if _, err := s.Token(); err == nil {
   169  		t.Fatal("Token() should fail when GetRequestMetadata fails.")
   170  	}
   171  }
   172  
   173  func TestTokenSource_Token_GettingSignatureFails(t *testing.T) {
   174  	ctx := context.Background()
   175  
   176  	h, cleanup := newFakeHTTP()
   177  	defer cleanup()
   178  
   179  	// Override URLs for test.
   180  	newSignJWTURL = func(string) string {
   181  		return h.Server.URL + "/sign"
   182  	}
   183  	audienceURL = h.Server.URL + "/token"
   184  	// duration of token
   185  	duration := 10 * time.Minute
   186  	token := 0
   187  	// Populate the response in the http request handler based on the request URL.
   188  	h.Handler = func(req *http.Request) (body interface{}, err error) {
   189  		log.Infof("HTTP Request: %+v", req)
   190  		if req.URL.Path == "/sign" {
   191  			return nil, status.Error(codes.Unknown, "some error")
   192  		}
   193  		if req.URL.Path == "/token" {
   194  			token++
   195  			return &tokenPayload{
   196  				AccessToken: fmt.Sprintf("fake-access-token-%v", token),
   197  				TokenType:   "fake-token-type",
   198  				ExpiresIn:   int64(duration.Nanoseconds()),
   199  			}, nil
   200  		}
   201  		return nil, nil
   202  	}
   203  	d := &stubDefaultCredentials{}
   204  
   205  	s := NewTokenSource(ctx, d, h.Client, account, []string{scope})
   206  	if _, err := s.Token(); err == nil {
   207  		t.Fatalf("Token() should fail when cannot get signature.")
   208  	}
   209  }
   210  
   211  func TestTokenSource_Token_GettingTokenFails(t *testing.T) {
   212  	ctx := context.Background()
   213  
   214  	h, cleanup := newFakeHTTP()
   215  	defer cleanup()
   216  
   217  	// Override URLs for test.
   218  	newSignJWTURL = func(string) string {
   219  		return h.Server.URL + "/sign"
   220  	}
   221  	audienceURL = h.Server.URL + "/token"
   222  
   223  	// Populate the response in the http request handler based on the request URL.
   224  	h.Handler = func(req *http.Request) (body interface{}, err error) {
   225  		log.Infof("HTTP Request: %+v", req)
   226  		if req.URL.Path == "/sign" {
   227  			return &signaturePayload{
   228  				KeyID:     "fake-key-id",
   229  				SignedJwt: "fake-signed-jwt",
   230  			}, nil
   231  		}
   232  		if req.URL.Path == "/token" {
   233  			return nil, status.Error(codes.Unknown, "some error")
   234  		}
   235  		return nil, nil
   236  	}
   237  	d := &stubDefaultCredentials{}
   238  
   239  	s := NewTokenSource(ctx, d, h.Client, account, []string{scope})
   240  	if _, err := s.Token(); err == nil {
   241  		t.Fatalf("Token() should fail when cannot get token.")
   242  	}
   243  }
   244  
   245  func Test_isOK(t *testing.T) {
   246  	tests := []struct {
   247  		code int
   248  		want bool
   249  	}{
   250  		{
   251  			code: 199,
   252  			want: false,
   253  		},
   254  		{
   255  			code: 200,
   256  			want: true,
   257  		},
   258  		{
   259  			code: 299,
   260  			want: true,
   261  		},
   262  		{
   263  			code: 300,
   264  			want: false,
   265  		},
   266  	}
   267  	for _, tc := range tests {
   268  		if got := isOK(tc.code); got != tc.want {
   269  			t.Errorf("isOK(%v) = %v, want %v", tc.code, got, tc.want)
   270  		}
   271  	}
   272  }