github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/api/client.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/url"
    11  	"regexp"
    12  	"strings"
    13  
    14  	"github.com/cli/cli/internal/ghinstance"
    15  	"github.com/henvic/httpretty"
    16  	"github.com/shurcooL/graphql"
    17  )
    18  
    19  // ClientOption represents an argument to NewClient
    20  type ClientOption = func(http.RoundTripper) http.RoundTripper
    21  
    22  // NewHTTPClient initializes an http.Client
    23  func NewHTTPClient(opts ...ClientOption) *http.Client {
    24  	tr := http.DefaultTransport
    25  	for _, opt := range opts {
    26  		tr = opt(tr)
    27  	}
    28  	return &http.Client{Transport: tr}
    29  }
    30  
    31  // NewClient initializes a Client
    32  func NewClient(opts ...ClientOption) *Client {
    33  	client := &Client{http: NewHTTPClient(opts...)}
    34  	return client
    35  }
    36  
    37  // NewClientFromHTTP takes in an http.Client instance
    38  func NewClientFromHTTP(httpClient *http.Client) *Client {
    39  	client := &Client{http: httpClient}
    40  	return client
    41  }
    42  
    43  // AddHeader turns a RoundTripper into one that adds a request header
    44  func AddHeader(name, value string) ClientOption {
    45  	return func(tr http.RoundTripper) http.RoundTripper {
    46  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    47  			if req.Header.Get(name) == "" {
    48  				req.Header.Add(name, value)
    49  			}
    50  			return tr.RoundTrip(req)
    51  		}}
    52  	}
    53  }
    54  
    55  // AddHeaderFunc is an AddHeader that gets the string value from a function
    56  func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption {
    57  	return func(tr http.RoundTripper) http.RoundTripper {
    58  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    59  			if req.Header.Get(name) != "" {
    60  				return tr.RoundTrip(req)
    61  			}
    62  			value, err := getValue(req)
    63  			if err != nil {
    64  				return nil, err
    65  			}
    66  			if value != "" {
    67  				req.Header.Add(name, value)
    68  			}
    69  			return tr.RoundTrip(req)
    70  		}}
    71  	}
    72  }
    73  
    74  // VerboseLog enables request/response logging within a RoundTripper
    75  func VerboseLog(out io.Writer, logTraffic bool, colorize bool) ClientOption {
    76  	logger := &httpretty.Logger{
    77  		Time:            true,
    78  		TLS:             false,
    79  		Colors:          colorize,
    80  		RequestHeader:   logTraffic,
    81  		RequestBody:     logTraffic,
    82  		ResponseHeader:  logTraffic,
    83  		ResponseBody:    logTraffic,
    84  		Formatters:      []httpretty.Formatter{&httpretty.JSONFormatter{}},
    85  		MaxResponseBody: 10000,
    86  	}
    87  	logger.SetOutput(out)
    88  	logger.SetBodyFilter(func(h http.Header) (skip bool, err error) {
    89  		return !inspectableMIMEType(h.Get("Content-Type")), nil
    90  	})
    91  	return logger.RoundTripper
    92  }
    93  
    94  // ReplaceTripper substitutes the underlying RoundTripper with a custom one
    95  func ReplaceTripper(tr http.RoundTripper) ClientOption {
    96  	return func(http.RoundTripper) http.RoundTripper {
    97  		return tr
    98  	}
    99  }
   100  
   101  type funcTripper struct {
   102  	roundTrip func(*http.Request) (*http.Response, error)
   103  }
   104  
   105  func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   106  	return tr.roundTrip(req)
   107  }
   108  
   109  // Client facilitates making HTTP requests to the GitHub API
   110  type Client struct {
   111  	http *http.Client
   112  }
   113  
   114  func (c *Client) HTTP() *http.Client {
   115  	return c.http
   116  }
   117  
   118  type graphQLResponse struct {
   119  	Data   interface{}
   120  	Errors []GraphQLError
   121  }
   122  
   123  // GraphQLError is a single error returned in a GraphQL response
   124  type GraphQLError struct {
   125  	Type    string
   126  	Message string
   127  	// Path []interface // mixed strings and numbers
   128  }
   129  
   130  // GraphQLErrorResponse contains errors returned in a GraphQL response
   131  type GraphQLErrorResponse struct {
   132  	Errors []GraphQLError
   133  }
   134  
   135  func (gr GraphQLErrorResponse) Error() string {
   136  	errorMessages := make([]string, 0, len(gr.Errors))
   137  	for _, e := range gr.Errors {
   138  		errorMessages = append(errorMessages, e.Message)
   139  	}
   140  	return fmt.Sprintf("GraphQL error: %s", strings.Join(errorMessages, "\n"))
   141  }
   142  
   143  // HTTPError is an error returned by a failed API call
   144  type HTTPError struct {
   145  	StatusCode  int
   146  	RequestURL  *url.URL
   147  	Message     string
   148  	OAuthScopes string
   149  	Errors      []HTTPErrorItem
   150  }
   151  
   152  type HTTPErrorItem struct {
   153  	Message  string
   154  	Resource string
   155  	Field    string
   156  	Code     string
   157  }
   158  
   159  func (err HTTPError) Error() string {
   160  	if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 {
   161  		return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1])
   162  	} else if err.Message != "" {
   163  		return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL)
   164  	}
   165  	return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL)
   166  }
   167  
   168  // GraphQL performs a GraphQL request and parses the response
   169  func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error {
   170  	reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables})
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody))
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	req.Header.Set("Content-Type", "application/json; charset=utf-8")
   181  
   182  	resp, err := c.http.Do(req)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	defer resp.Body.Close()
   187  
   188  	return handleResponse(resp, data)
   189  }
   190  
   191  func graphQLClient(h *http.Client, hostname string) *graphql.Client {
   192  	return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h)
   193  }
   194  
   195  // REST performs a REST request and parses the response.
   196  func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error {
   197  	req, err := http.NewRequest(method, restURL(hostname, p), body)
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	req.Header.Set("Content-Type", "application/json; charset=utf-8")
   203  
   204  	resp, err := c.http.Do(req)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	defer resp.Body.Close()
   209  
   210  	success := resp.StatusCode >= 200 && resp.StatusCode < 300
   211  	if !success {
   212  		return HandleHTTPError(resp)
   213  	}
   214  
   215  	if resp.StatusCode == http.StatusNoContent {
   216  		return nil
   217  	}
   218  
   219  	b, err := ioutil.ReadAll(resp.Body)
   220  	if err != nil {
   221  		return err
   222  	}
   223  	err = json.Unmarshal(b, &data)
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  func restURL(hostname string, pathOrURL string) string {
   232  	if strings.HasPrefix(pathOrURL, "https://") || strings.HasPrefix(pathOrURL, "http://") {
   233  		return pathOrURL
   234  	}
   235  	return ghinstance.RESTPrefix(hostname) + pathOrURL
   236  }
   237  
   238  func handleResponse(resp *http.Response, data interface{}) error {
   239  	success := resp.StatusCode >= 200 && resp.StatusCode < 300
   240  
   241  	if !success {
   242  		return HandleHTTPError(resp)
   243  	}
   244  
   245  	body, err := ioutil.ReadAll(resp.Body)
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	gr := &graphQLResponse{Data: data}
   251  	err = json.Unmarshal(body, &gr)
   252  	if err != nil {
   253  		return err
   254  	}
   255  
   256  	if len(gr.Errors) > 0 {
   257  		return &GraphQLErrorResponse{Errors: gr.Errors}
   258  	}
   259  	return nil
   260  }
   261  
   262  func HandleHTTPError(resp *http.Response) error {
   263  	httpError := HTTPError{
   264  		StatusCode:  resp.StatusCode,
   265  		RequestURL:  resp.Request.URL,
   266  		OAuthScopes: resp.Header.Get("X-Oauth-Scopes"),
   267  	}
   268  
   269  	if !jsonTypeRE.MatchString(resp.Header.Get("Content-Type")) {
   270  		httpError.Message = resp.Status
   271  		return httpError
   272  	}
   273  
   274  	body, err := ioutil.ReadAll(resp.Body)
   275  	if err != nil {
   276  		httpError.Message = err.Error()
   277  		return httpError
   278  	}
   279  
   280  	var parsedBody struct {
   281  		Message string `json:"message"`
   282  		Errors  []json.RawMessage
   283  	}
   284  	if err := json.Unmarshal(body, &parsedBody); err != nil {
   285  		return httpError
   286  	}
   287  
   288  	var messages []string
   289  	if parsedBody.Message != "" {
   290  		messages = append(messages, parsedBody.Message)
   291  	}
   292  	for _, raw := range parsedBody.Errors {
   293  		switch raw[0] {
   294  		case '"':
   295  			var errString string
   296  			_ = json.Unmarshal(raw, &errString)
   297  			messages = append(messages, errString)
   298  			httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString})
   299  		case '{':
   300  			var errInfo HTTPErrorItem
   301  			_ = json.Unmarshal(raw, &errInfo)
   302  			msg := errInfo.Message
   303  			if errInfo.Code != "" && errInfo.Code != "custom" {
   304  				msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code))
   305  			}
   306  			if msg != "" {
   307  				messages = append(messages, msg)
   308  			}
   309  			httpError.Errors = append(httpError.Errors, errInfo)
   310  		}
   311  	}
   312  	httpError.Message = strings.Join(messages, "\n")
   313  
   314  	return httpError
   315  }
   316  
   317  func errorCodeToMessage(code string) string {
   318  	// https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors
   319  	switch code {
   320  	case "missing", "missing_field":
   321  		return "is missing"
   322  	case "invalid", "unprocessable":
   323  		return "is invalid"
   324  	case "already_exists":
   325  		return "already exists"
   326  	default:
   327  		return code
   328  	}
   329  }
   330  
   331  var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
   332  
   333  func inspectableMIMEType(t string) bool {
   334  	return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t)
   335  }