golang.org/x/oauth2@v0.18.0/transport_test.go (about) 1 package oauth2 2 3 import ( 4 "errors" 5 "io" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 "time" 10 ) 11 12 type tokenSource struct{ token *Token } 13 14 func (t *tokenSource) Token() (*Token, error) { 15 return t.token, nil 16 } 17 18 func TestTransportNilTokenSource(t *testing.T) { 19 tr := &Transport{} 20 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) 21 defer server.Close() 22 client := &http.Client{Transport: tr} 23 resp, err := client.Get(server.URL) 24 if err == nil { 25 t.Errorf("got no errors, want an error with nil token source") 26 } 27 if resp != nil { 28 t.Errorf("Response = %v; want nil", resp) 29 } 30 } 31 32 type readCloseCounter struct { 33 CloseCount int 34 ReadErr error 35 } 36 37 func (r *readCloseCounter) Read(b []byte) (int, error) { 38 return 0, r.ReadErr 39 } 40 41 func (r *readCloseCounter) Close() error { 42 r.CloseCount++ 43 return nil 44 } 45 46 func TestTransportCloseRequestBody(t *testing.T) { 47 tr := &Transport{} 48 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) 49 defer server.Close() 50 client := &http.Client{Transport: tr} 51 body := &readCloseCounter{ 52 ReadErr: errors.New("readCloseCounter.Read not implemented"), 53 } 54 resp, err := client.Post(server.URL, "application/json", body) 55 if err == nil { 56 t.Errorf("got no errors, want an error with nil token source") 57 } 58 if resp != nil { 59 t.Errorf("Response = %v; want nil", resp) 60 } 61 if expected := 1; body.CloseCount != expected { 62 t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected) 63 } 64 } 65 66 func TestTransportCloseRequestBodySuccess(t *testing.T) { 67 tr := &Transport{ 68 Source: StaticTokenSource(&Token{ 69 AccessToken: "abc", 70 }), 71 } 72 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) 73 defer server.Close() 74 client := &http.Client{Transport: tr} 75 body := &readCloseCounter{ 76 ReadErr: io.EOF, 77 } 78 resp, err := client.Post(server.URL, "application/json", body) 79 if err != nil { 80 t.Errorf("got error %v; expected none", err) 81 } 82 if resp == nil { 83 t.Errorf("Response is nil; expected non-nil") 84 } 85 if expected := 1; body.CloseCount != expected { 86 t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected) 87 } 88 } 89 90 func TestTransportTokenSource(t *testing.T) { 91 ts := &tokenSource{ 92 token: &Token{ 93 AccessToken: "abc", 94 }, 95 } 96 tr := &Transport{ 97 Source: ts, 98 } 99 server := newMockServer(func(w http.ResponseWriter, r *http.Request) { 100 if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want { 101 t.Errorf("Authorization header = %q; want %q", got, want) 102 } 103 }) 104 defer server.Close() 105 client := &http.Client{Transport: tr} 106 res, err := client.Get(server.URL) 107 if err != nil { 108 t.Fatal(err) 109 } 110 res.Body.Close() 111 } 112 113 // Test for case-sensitive token types, per https://github.com/golang/oauth2/issues/113 114 func TestTransportTokenSourceTypes(t *testing.T) { 115 const val = "abc" 116 tests := []struct { 117 key string 118 val string 119 want string 120 }{ 121 {key: "bearer", val: val, want: "Bearer abc"}, 122 {key: "mac", val: val, want: "MAC abc"}, 123 {key: "basic", val: val, want: "Basic abc"}, 124 } 125 for _, tc := range tests { 126 ts := &tokenSource{ 127 token: &Token{ 128 AccessToken: tc.val, 129 TokenType: tc.key, 130 }, 131 } 132 tr := &Transport{ 133 Source: ts, 134 } 135 server := newMockServer(func(w http.ResponseWriter, r *http.Request) { 136 if got, want := r.Header.Get("Authorization"), tc.want; got != want { 137 t.Errorf("Authorization header (%q) = %q; want %q", val, got, want) 138 } 139 }) 140 defer server.Close() 141 client := &http.Client{Transport: tr} 142 res, err := client.Get(server.URL) 143 if err != nil { 144 t.Fatal(err) 145 } 146 res.Body.Close() 147 } 148 } 149 150 func TestTokenValidNoAccessToken(t *testing.T) { 151 token := &Token{} 152 if token.Valid() { 153 t.Errorf("got valid with no access token; want invalid") 154 } 155 } 156 157 func TestExpiredWithExpiry(t *testing.T) { 158 token := &Token{ 159 Expiry: time.Now().Add(-5 * time.Hour), 160 } 161 if token.Valid() { 162 t.Errorf("got valid with expired token; want invalid") 163 } 164 } 165 166 func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { 167 return httptest.NewServer(http.HandlerFunc(handler)) 168 }