github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/internal_util.go (about)

     1  package providers
     2  
     3  import (
     4  	"io/ioutil"
     5  	"log"
     6  	"net/http"
     7  	"net/url"
     8  
     9  	"github.com/pusher/oauth2_proxy/api"
    10  )
    11  
    12  // stripToken is a helper function to obfuscate "access_token"
    13  // query parameters
    14  func stripToken(endpoint string) string {
    15  	return stripParam("access_token", endpoint)
    16  }
    17  
    18  // stripParam generalizes the obfuscation of a particular
    19  // query parameter - typically 'access_token' or 'client_secret'
    20  // The parameter's second half is replaced by '...' and returned
    21  // as part of the encoded query parameters.
    22  // If the target parameter isn't found, the endpoint is returned
    23  // unmodified.
    24  func stripParam(param, endpoint string) string {
    25  	u, err := url.Parse(endpoint)
    26  	if err != nil {
    27  		log.Printf("error attempting to strip %s: %s", param, err)
    28  		return endpoint
    29  	}
    30  
    31  	if u.RawQuery != "" {
    32  		values, err := url.ParseQuery(u.RawQuery)
    33  		if err != nil {
    34  			log.Printf("error attempting to strip %s: %s", param, err)
    35  			return u.String()
    36  		}
    37  
    38  		if val := values.Get(param); val != "" {
    39  			values.Set(param, val[:(len(val)/2)]+"...")
    40  			u.RawQuery = values.Encode()
    41  			return u.String()
    42  		}
    43  	}
    44  
    45  	return endpoint
    46  }
    47  
    48  // validateToken returns true if token is valid
    49  func validateToken(p Provider, accessToken string, header http.Header) bool {
    50  	if accessToken == "" || p.Data().ValidateURL == nil {
    51  		return false
    52  	}
    53  	endpoint := p.Data().ValidateURL.String()
    54  	if len(header) == 0 {
    55  		params := url.Values{"access_token": {accessToken}}
    56  		endpoint = endpoint + "?" + params.Encode()
    57  	}
    58  	resp, err := api.RequestUnparsedResponse(endpoint, header)
    59  	if err != nil {
    60  		log.Printf("GET %s", stripToken(endpoint))
    61  		log.Printf("token validation request failed: %s", err)
    62  		return false
    63  	}
    64  
    65  	body, _ := ioutil.ReadAll(resp.Body)
    66  	resp.Body.Close()
    67  	log.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
    68  
    69  	if resp.StatusCode == 200 {
    70  		return true
    71  	}
    72  	log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
    73  	return false
    74  }