github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/api/http_client.go (about)

     1  package api
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/ungtb10d/cli/v2/internal/ghinstance"
    11  	"github.com/ungtb10d/cli/v2/utils"
    12  	"github.com/cli/go-gh"
    13  	ghAPI "github.com/cli/go-gh/pkg/api"
    14  )
    15  
    16  type tokenGetter interface {
    17  	AuthToken(string) (string, string)
    18  }
    19  
// HTTPClientOptions configures the client built by NewHTTPClient.
type HTTPClientOptions struct {
	// AppVersion is embedded in the User-Agent header as "GitHub CLI <AppVersion>".
	AppVersion string

	// CacheTTL is forwarded to the underlying client options, but only when
	// EnableCache is true; otherwise it is ignored.
	CacheTTL time.Duration

	// Config supplies per-host auth tokens. When nil, NewHTTPClient does not
	// wrap the transport with AddAuthTokenHeader.
	Config tokenGetter

	// EnableCache turns on caching in the underlying go-gh client and causes
	// CacheTTL to be applied.
	EnableCache bool

	// Log receives HTTP debug logging. It is only wired up when
	// utils.IsDebugEnabled reports that debug mode is on.
	Log io.Writer

	// LogColorize is forwarded alongside Log, and likewise only when debug
	// mode is on.
	LogColorize bool

	// SkipAcceptHeaders sets the Accept entry of the default headers to "".
	// NOTE(review): presumably this makes go-gh omit its default Accept
	// header — confirm against go-gh's ClientOptions.Headers semantics.
	SkipAcceptHeaders bool
}
    29  
    30  func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
    31  	// Provide invalid host, and token values so gh.HTTPClient will not automatically resolve them.
    32  	// The real host and token are inserted at request time.
    33  	clientOpts := ghAPI.ClientOptions{
    34  		Host:         "none",
    35  		AuthToken:    "none",
    36  		LogIgnoreEnv: true,
    37  	}
    38  
    39  	if debugEnabled, debugValue := utils.IsDebugEnabled(); debugEnabled {
    40  		clientOpts.Log = opts.Log
    41  		clientOpts.LogColorize = opts.LogColorize
    42  		clientOpts.LogVerboseHTTP = strings.Contains(debugValue, "api")
    43  	}
    44  
    45  	headers := map[string]string{
    46  		userAgent: fmt.Sprintf("GitHub CLI %s", opts.AppVersion),
    47  	}
    48  	if opts.SkipAcceptHeaders {
    49  		headers[accept] = ""
    50  	}
    51  	clientOpts.Headers = headers
    52  
    53  	if opts.EnableCache {
    54  		clientOpts.EnableCache = opts.EnableCache
    55  		clientOpts.CacheTTL = opts.CacheTTL
    56  	}
    57  
    58  	client, err := gh.HTTPClient(&clientOpts)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	if opts.Config != nil {
    64  		client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
    65  	}
    66  
    67  	return client, nil
    68  }
    69  
    70  func NewCachedHTTPClient(httpClient *http.Client, ttl time.Duration) *http.Client {
    71  	newClient := *httpClient
    72  	newClient.Transport = AddCacheTTLHeader(httpClient.Transport, ttl)
    73  	return &newClient
    74  }
    75  
    76  // AddCacheTTLHeader adds an header to the request telling the cache that the request
    77  // should be cached for a specified amount of time.
    78  func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTripper {
    79  	return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    80  		// If the header is already set in the request, don't overwrite it.
    81  		if req.Header.Get(cacheTTL) == "" {
    82  			req.Header.Set(cacheTTL, ttl.String())
    83  		}
    84  		return rt.RoundTrip(req)
    85  	}}
    86  }
    87  
    88  // AddAuthToken adds an authentication token header for the host specified by the request.
    89  func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper {
    90  	return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    91  		// If the header is already set in the request, don't overwrite it.
    92  		if req.Header.Get(authorization) == "" {
    93  			hostname := ghinstance.NormalizeHostname(getHost(req))
    94  			if token, _ := cfg.AuthToken(hostname); token != "" {
    95  				req.Header.Set(authorization, fmt.Sprintf("token %s", token))
    96  			}
    97  		}
    98  		return rt.RoundTrip(req)
    99  	}}
   100  }
   101  
   102  // ExtractHeader extracts a named header from any response received by this client and,
   103  // if non-blank, saves it to dest.
   104  func ExtractHeader(name string, dest *string) func(http.RoundTripper) http.RoundTripper {
   105  	return func(tr http.RoundTripper) http.RoundTripper {
   106  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
   107  			res, err := tr.RoundTrip(req)
   108  			if err == nil {
   109  				if value := res.Header.Get(name); value != "" {
   110  					*dest = value
   111  				}
   112  			}
   113  			return res, err
   114  		}}
   115  	}
   116  }
   117  
   118  type funcTripper struct {
   119  	roundTrip func(*http.Request) (*http.Response, error)
   120  }
   121  
   122  func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   123  	return tr.roundTrip(req)
   124  }
   125  
   126  func getHost(r *http.Request) string {
   127  	if r.Host != "" {
   128  		return r.Host
   129  	}
   130  	return r.URL.Hostname()
   131  }