golang.org/x/oauth2@v0.18.0/google/jwt_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package google
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/json"
    14  	"encoding/pem"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	"golang.org/x/oauth2/jws"
    21  )
    22  
    23  var (
    24  	privateKey *rsa.PrivateKey
    25  	jsonKey    []byte
    26  	once       sync.Once
    27  )
    28  
    29  func TestJWTAccessTokenSourceFromJSON(t *testing.T) {
    30  	setupDummyKey(t)
    31  
    32  	ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience")
    33  	if err != nil {
    34  		t.Fatalf("JWTAccessTokenSourceFromJSON: %v\nJSON: %s", err, string(jsonKey))
    35  	}
    36  
    37  	tok, err := ts.Token()
    38  	if err != nil {
    39  		t.Fatalf("Token: %v", err)
    40  	}
    41  
    42  	if got, want := tok.TokenType, "Bearer"; got != want {
    43  		t.Errorf("TokenType = %q, want %q", got, want)
    44  	}
    45  	if got := tok.Expiry; tok.Expiry.Before(time.Now()) {
    46  		t.Errorf("Expiry = %v, should not be expired", got)
    47  	}
    48  
    49  	err = jws.Verify(tok.AccessToken, &privateKey.PublicKey)
    50  	if err != nil {
    51  		t.Errorf("jws.Verify on AccessToken: %v", err)
    52  	}
    53  
    54  	claim, err := jws.Decode(tok.AccessToken)
    55  	if err != nil {
    56  		t.Fatalf("jws.Decode on AccessToken: %v", err)
    57  	}
    58  
    59  	if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want {
    60  		t.Errorf("Iss = %q, want %q", got, want)
    61  	}
    62  	if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want {
    63  		t.Errorf("Sub = %q, want %q", got, want)
    64  	}
    65  	if got, want := claim.Aud, "audience"; got != want {
    66  		t.Errorf("Aud = %q, want %q", got, want)
    67  	}
    68  
    69  	// Finally, check the header private key.
    70  	parts := strings.Split(tok.AccessToken, ".")
    71  	hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
    72  	if err != nil {
    73  		t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0])
    74  	}
    75  	var hdr jws.Header
    76  	if err := json.Unmarshal(hdrJSON, &hdr); err != nil {
    77  		t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON)
    78  	}
    79  
    80  	if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want {
    81  		t.Errorf("Header KeyID = %q, want %q", got, want)
    82  	}
    83  }
    84  
    85  func TestJWTAccessTokenSourceWithScope(t *testing.T) {
    86  	setupDummyKey(t)
    87  
    88  	ts, err := JWTAccessTokenSourceWithScope(jsonKey, "scope1", "scope2")
    89  	if err != nil {
    90  		t.Fatalf("JWTAccessTokenSourceWithScope: %v\nJSON: %s", err, string(jsonKey))
    91  	}
    92  
    93  	tok, err := ts.Token()
    94  	if err != nil {
    95  		t.Fatalf("Token: %v", err)
    96  	}
    97  
    98  	if got, want := tok.TokenType, "Bearer"; got != want {
    99  		t.Errorf("TokenType = %q, want %q", got, want)
   100  	}
   101  	if got := tok.Expiry; tok.Expiry.Before(time.Now()) {
   102  		t.Errorf("Expiry = %v, should not be expired", got)
   103  	}
   104  
   105  	err = jws.Verify(tok.AccessToken, &privateKey.PublicKey)
   106  	if err != nil {
   107  		t.Errorf("jws.Verify on AccessToken: %v", err)
   108  	}
   109  
   110  	claim, err := jws.Decode(tok.AccessToken)
   111  	if err != nil {
   112  		t.Fatalf("jws.Decode on AccessToken: %v", err)
   113  	}
   114  
   115  	if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want {
   116  		t.Errorf("Iss = %q, want %q", got, want)
   117  	}
   118  	if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want {
   119  		t.Errorf("Sub = %q, want %q", got, want)
   120  	}
   121  	if got, want := claim.Scope, "scope1 scope2"; got != want {
   122  		t.Errorf("Aud = %q, want %q", got, want)
   123  	}
   124  
   125  	// Finally, check the header private key.
   126  	parts := strings.Split(tok.AccessToken, ".")
   127  	hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
   128  	if err != nil {
   129  		t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0])
   130  	}
   131  	var hdr jws.Header
   132  	if err := json.Unmarshal(hdrJSON, &hdr); err != nil {
   133  		t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON)
   134  	}
   135  
   136  	if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want {
   137  		t.Errorf("Header KeyID = %q, want %q", got, want)
   138  	}
   139  }
   140  
   141  func setupDummyKey(t *testing.T) {
   142  	once.Do(func() {
   143  		// Generate a key we can use in the test data.
   144  		pk, err := rsa.GenerateKey(rand.Reader, 2048)
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		privateKey = pk
   149  		// Encode the key and substitute into our example JSON.
   150  		enc := pem.EncodeToMemory(&pem.Block{
   151  			Type:  "PRIVATE KEY",
   152  			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
   153  		})
   154  		enc, err = json.Marshal(string(enc))
   155  		if err != nil {
   156  			t.Fatalf("json.Marshal: %v", err)
   157  		}
   158  		jsonKey = bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1)
   159  	})
   160  }