github.com/timstclair/heapster@v0.20.0-alpha1/Godeps/_workspace/src/golang.org/x/oauth2/transport.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package oauth2
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"sync"
    12  )
    13  
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied Source.
//
// Transport is a low-level mechanism. Most code will use the
// higher-level Config.Client method instead.
//
// Transport embeds a sync.Mutex, so a Transport must not be copied after
// first use; share it by pointer.
type Transport struct {
	// Source supplies the token to add to outgoing requests'
	// Authorization headers. RoundTrip returns an error if Source is nil.
	Source TokenSource

	// Base is the base RoundTripper used to make HTTP requests.
	// If nil, http.DefaultTransport is used.
	Base http.RoundTripper

	// modReq maps each in-flight caller request to the authorized clone
	// that was actually sent, so CancelRequest can cancel the clone.
	// The map is allocated lazily by setModReq, so it does not need to be
	// initialized when a Transport is constructed.
	mu     sync.Mutex                      // guards modReq
	modReq map[*http.Request]*http.Request // original -> modified
}
    32  
    33  // RoundTrip authorizes and authenticates the request with an
    34  // access token. If no token exists or token is expired,
    35  // tries to refresh/fetch a new token.
    36  func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
    37  	if t.Source == nil {
    38  		return nil, errors.New("oauth2: Transport's Source is nil")
    39  	}
    40  	token, err := t.Source.Token()
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	req2 := cloneRequest(req) // per RoundTripper contract
    46  	token.SetAuthHeader(req2)
    47  	t.setModReq(req, req2)
    48  	res, err := t.base().RoundTrip(req2)
    49  	if err != nil {
    50  		t.setModReq(req, nil)
    51  		return nil, err
    52  	}
    53  	res.Body = &onEOFReader{
    54  		rc: res.Body,
    55  		fn: func() { t.setModReq(req, nil) },
    56  	}
    57  	return res, nil
    58  }
    59  
    60  // CancelRequest cancels an in-flight request by closing its connection.
    61  func (t *Transport) CancelRequest(req *http.Request) {
    62  	type canceler interface {
    63  		CancelRequest(*http.Request)
    64  	}
    65  	if cr, ok := t.base().(canceler); ok {
    66  		t.mu.Lock()
    67  		modReq := t.modReq[req]
    68  		delete(t.modReq, req)
    69  		t.mu.Unlock()
    70  		cr.CancelRequest(modReq)
    71  	}
    72  }
    73  
    74  func (t *Transport) base() http.RoundTripper {
    75  	if t.Base != nil {
    76  		return t.Base
    77  	}
    78  	return http.DefaultTransport
    79  }
    80  
    81  func (t *Transport) setModReq(orig, mod *http.Request) {
    82  	t.mu.Lock()
    83  	defer t.mu.Unlock()
    84  	if t.modReq == nil {
    85  		t.modReq = make(map[*http.Request]*http.Request)
    86  	}
    87  	if mod == nil {
    88  		delete(t.modReq, orig)
    89  	} else {
    90  		t.modReq[orig] = mod
    91  	}
    92  }
    93  
    94  // cloneRequest returns a clone of the provided *http.Request.
    95  // The clone is a shallow copy of the struct and its Header map.
    96  func cloneRequest(r *http.Request) *http.Request {
    97  	// shallow copy of the struct
    98  	r2 := new(http.Request)
    99  	*r2 = *r
   100  	// deep copy of the Header
   101  	r2.Header = make(http.Header, len(r.Header))
   102  	for k, s := range r.Header {
   103  		r2.Header[k] = append([]string(nil), s...)
   104  	}
   105  	return r2
   106  }
   107  
   108  type onEOFReader struct {
   109  	rc io.ReadCloser
   110  	fn func()
   111  }
   112  
   113  func (r *onEOFReader) Read(p []byte) (n int, err error) {
   114  	n, err = r.rc.Read(p)
   115  	if err == io.EOF {
   116  		r.runFunc()
   117  	}
   118  	return
   119  }
   120  
   121  func (r *onEOFReader) Close() error {
   122  	err := r.rc.Close()
   123  	r.runFunc()
   124  	return err
   125  }
   126  
   127  func (r *onEOFReader) runFunc() {
   128  	if fn := r.fn; fn != nil {
   129  		fn()
   130  		r.fn = nil
   131  	}
   132  }