github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/http.go (about)

     1  package provider
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"runtime"
    10  	"time"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"github.com/versent/saml2aws/pkg/cookiejar"
    14  	"github.com/versent/saml2aws/pkg/dump"
    15  
    16  	"github.com/briandowns/spinner"
    17  	"github.com/mattn/go-isatty"
    18  	"github.com/pkg/errors"
    19  	"golang.org/x/net/publicsuffix"
    20  )
    21  
// HTTPClient is the saml2aws HTTP client. It embeds http.Client, so all of
// that type's fields (Transport, Jar, CheckRedirect, ...) and methods are
// promoted, and it adds an optional response-status hook used by Do.
type HTTPClient struct {
	http.Client
	// CheckResponseStatus, when non-nil, is called by Do with the request and
	// its response after the round trip succeeded. A non-nil error it returns
	// is handed back to Do's caller together with the (still open) response.
	CheckResponseStatus func(*http.Request, *http.Response) error
}
    27  
    28  // NewDefaultTransport configure a transport with the TLS skip verify option
    29  func NewDefaultTransport(skipVerify bool) *http.Transport {
    30  	return &http.Transport{
    31  		Proxy: http.ProxyFromEnvironment,
    32  		DialContext: (&net.Dialer{
    33  			Timeout:   30 * time.Second,
    34  			KeepAlive: 30 * time.Second,
    35  			DualStack: true,
    36  		}).DialContext,
    37  		MaxIdleConns:          100,
    38  		IdleConnTimeout:       90 * time.Second,
    39  		TLSHandshakeTimeout:   10 * time.Second,
    40  		ExpectContinueTimeout: 1 * time.Second,
    41  		TLSClientConfig:       &tls.Config{InsecureSkipVerify: skipVerify},
    42  	}
    43  }
    44  
    45  // NewHTTPClient configure the default http client used by the providers
    46  func NewHTTPClient(tr http.RoundTripper) (*HTTPClient, error) {
    47  
    48  	options := &cookiejar.Options{
    49  		PublicSuffixList: publicsuffix.List,
    50  	}
    51  
    52  	jar, err := cookiejar.New(options)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	client := http.Client{Transport: tr, Jar: jar}
    58  
    59  	return &HTTPClient{client, nil}, nil
    60  }
    61  
    62  // Do do the request
    63  func (hc *HTTPClient) Do(req *http.Request) (*http.Response, error) {
    64  
    65  	if isatty.IsTerminal(os.Stdout.Fd()) {
    66  		cs := spinner.CharSets[14]
    67  
    68  		// use a NON unicode spinner for windows
    69  		if runtime.GOOS == "windows" {
    70  			cs = spinner.CharSets[26]
    71  		}
    72  
    73  		if logrus.GetLevel() != logrus.DebugLevel {
    74  			s := spinner.New(cs, 100*time.Millisecond)
    75  			defer func() {
    76  				s.Stop()
    77  			}()
    78  			s.Start()
    79  		}
    80  	}
    81  
    82  	req.Header.Set("User-Agent", fmt.Sprintf("saml2aws/1.0 (%s %s) Versent", runtime.GOOS, runtime.GOARCH))
    83  
    84  	hc.logHTTPRequest(req)
    85  
    86  	resp, err := hc.Client.Do(req)
    87  	if err != nil {
    88  		return resp, err
    89  	}
    90  
    91  	// if a response check has been configured
    92  	if hc.CheckResponseStatus != nil {
    93  		err = hc.CheckResponseStatus(req, resp)
    94  		if err != nil {
    95  			return resp, err
    96  		}
    97  	}
    98  
    99  	hc.logHTTPResponse(resp)
   100  
   101  	return resp, err
   102  }
   103  
   104  // DisableFollowRedirect disable redirects
   105  func (hc *HTTPClient) DisableFollowRedirect() {
   106  	hc.CheckRedirect = func(req *http.Request, via []*http.Request) error {
   107  		return http.ErrUseLastResponse
   108  	}
   109  }
   110  
   111  // EnableFollowRedirect enable redirects
   112  func (hc *HTTPClient) EnableFollowRedirect() {
   113  	hc.CheckRedirect = nil
   114  }
   115  
   116  // SuccessOrRedirectResponseValidator this validates the response code is within range of 200 - 399
   117  func SuccessOrRedirectResponseValidator(req *http.Request, resp *http.Response) error {
   118  	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
   119  		return nil
   120  	}
   121  
   122  	return errors.Errorf("request for url: %s failed status: %s", req.URL.String(), resp.Status)
   123  }
   124  
   125  func (hc *HTTPClient) logHTTPRequest(req *http.Request) {
   126  
   127  	if dump.ContentEnable() {
   128  		fmt.Println(dump.RequestString(req))
   129  		return
   130  	}
   131  
   132  	logrus.WithField("http", "client").WithFields(logrus.Fields{
   133  		"URL":    req.URL.String(),
   134  		"method": req.Method,
   135  	}).Debug("HTTP Req")
   136  }
   137  
   138  func (hc *HTTPClient) logHTTPResponse(resp *http.Response) {
   139  
   140  	if dump.ContentEnable() {
   141  		fmt.Println(dump.ResponseString(resp))
   142  		return
   143  	}
   144  
   145  	logrus.WithField("http", "client").WithFields(logrus.Fields{
   146  		"Status": resp.Status,
   147  	}).Debug("HTTP Res")
   148  }