github.com/abdfnx/gh-api@v0.0.0-20210414084727-f5432eec23b8/api/client.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/url"
    11  	"regexp"
    12  	"strings"
    13  
    14  	"github.com/abdfnx/gh-api/internal/ghinstance"
    15  	"github.com/henvic/httpretty"
    16  	"github.com/shurcooL/graphql"
    17  )
    18  
    19  // ClientOption represents an argument to NewClient
    20  type ClientOption = func(http.RoundTripper) http.RoundTripper
    21  
    22  // NewHTTPClient initializes an http.Client
    23  func NewHTTPClient(opts ...ClientOption) *http.Client {
    24  	tr := http.DefaultTransport
    25  	for _, opt := range opts {
    26  		tr = opt(tr)
    27  	}
    28  	return &http.Client{Transport: tr}
    29  }
    30  
    31  // NewClient initializes a Client
    32  func NewClient(opts ...ClientOption) *Client {
    33  	client := &Client{http: NewHTTPClient(opts...)}
    34  	return client
    35  }
    36  
    37  // NewClientFromHTTP takes in an http.Client instance
    38  func NewClientFromHTTP(httpClient *http.Client) *Client {
    39  	client := &Client{http: httpClient}
    40  	return client
    41  }
    42  
    43  // AddHeader turns a RoundTripper into one that adds a request header
    44  func AddHeader(name, value string) ClientOption {
    45  	return func(tr http.RoundTripper) http.RoundTripper {
    46  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    47  			if req.Header.Get(name) == "" {
    48  				req.Header.Add(name, value)
    49  			}
    50  			return tr.RoundTrip(req)
    51  		}}
    52  	}
    53  }
    54  
    55  // AddHeaderFunc is an AddHeader that gets the string value from a function
    56  func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption {
    57  	return func(tr http.RoundTripper) http.RoundTripper {
    58  		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
    59  			if req.Header.Get(name) != "" {
    60  				return tr.RoundTrip(req)
    61  			}
    62  			value, err := getValue(req)
    63  			if err != nil {
    64  				return nil, err
    65  			}
    66  			if value != "" {
    67  				req.Header.Add(name, value)
    68  			}
    69  			return tr.RoundTrip(req)
    70  		}}
    71  	}
    72  }
    73  
    74  // VerboseLog enables request/response logging within a RoundTripper
    75  func VerboseLog(out io.Writer, logTraffic bool, colorize bool) ClientOption {
    76  	logger := &httpretty.Logger{
    77  		Time:            true,
    78  		TLS:             false,
    79  		Colors:          colorize,
    80  		RequestHeader:   logTraffic,
    81  		RequestBody:     logTraffic,
    82  		ResponseHeader:  logTraffic,
    83  		ResponseBody:    logTraffic,
    84  		Formatters:      []httpretty.Formatter{&httpretty.JSONFormatter{}},
    85  		MaxResponseBody: 10000,
    86  	}
    87  	logger.SetOutput(out)
    88  	logger.SetBodyFilter(func(h http.Header) (skip bool, err error) {
    89  		return !inspectableMIMEType(h.Get("Content-Type")), nil
    90  	})
    91  	return logger.RoundTripper
    92  }
    93  
    94  // ReplaceTripper substitutes the underlying RoundTripper with a custom one
    95  func ReplaceTripper(tr http.RoundTripper) ClientOption {
    96  	return func(http.RoundTripper) http.RoundTripper {
    97  		return tr
    98  	}
    99  }
   100  
   101  type funcTripper struct {
   102  	roundTrip func(*http.Request) (*http.Response, error)
   103  }
   104  
   105  func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   106  	return tr.roundTrip(req)
   107  }
   108  
   109  // Client facilitates making HTTP requests to the GitHub API
   110  type Client struct {
   111  	http *http.Client
   112  }
   113  
   114  func (c *Client) HTTP() *http.Client {
   115  	return c.http
   116  }
   117  
   118  type graphQLResponse struct {
   119  	Data   interface{}
   120  	Errors []GraphQLError
   121  }
   122  
   123  // GraphQLError is a single error returned in a GraphQL response
   124  type GraphQLError struct {
   125  	Type    string
   126  	Path    []string
   127  	Message string
   128  }
   129  
   130  // GraphQLErrorResponse contains errors returned in a GraphQL response
   131  type GraphQLErrorResponse struct {
   132  	Errors []GraphQLError
   133  }
   134  
   135  func (gr GraphQLErrorResponse) Error() string {
   136  	errorMessages := make([]string, 0, len(gr.Errors))
   137  	for _, e := range gr.Errors {
   138  		errorMessages = append(errorMessages, e.Message)
   139  	}
   140  	return fmt.Sprintf("GraphQL error: %s", strings.Join(errorMessages, "\n"))
   141  }
   142  
   143  // HTTPError is an error returned by a failed API call
   144  type HTTPError struct {
   145  	StatusCode  int
   146  	RequestURL  *url.URL
   147  	Message     string
   148  	OAuthScopes string
   149  	Errors      []HTTPErrorItem
   150  }
   151  
   152  type HTTPErrorItem struct {
   153  	Message  string
   154  	Resource string
   155  	Field    string
   156  	Code     string
   157  }
   158  
   159  func (err HTTPError) Error() string {
   160  	if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 {
   161  		return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1])
   162  	} else if err.Message != "" {
   163  		return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL)
   164  	}
   165  	return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL)
   166  }
   167  
   168  // GraphQL performs a GraphQL request and parses the response
   169  func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error {
   170  	reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables})
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody))
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	req.Header.Set("Content-Type", "application/json; charset=utf-8")
   181  
   182  	resp, err := c.http.Do(req)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	defer resp.Body.Close()
   187  
   188  	return handleResponse(resp, data)
   189  }
   190  
   191  func graphQLClient(h *http.Client, hostname string) *graphql.Client {
   192  	return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h)
   193  }
   194  
   195  // REST performs a REST request and parses the response.
   196  func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error {
   197  	url := ghinstance.RESTPrefix(hostname) + p
   198  	req, err := http.NewRequest(method, url, body)
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	req.Header.Set("Content-Type", "application/json; charset=utf-8")
   204  
   205  	resp, err := c.http.Do(req)
   206  	if err != nil {
   207  		return err
   208  	}
   209  	defer resp.Body.Close()
   210  
   211  	success := resp.StatusCode >= 200 && resp.StatusCode < 300
   212  	if !success {
   213  		return HandleHTTPError(resp)
   214  	}
   215  
   216  	if resp.StatusCode == http.StatusNoContent {
   217  		return nil
   218  	}
   219  
   220  	b, err := ioutil.ReadAll(resp.Body)
   221  	if err != nil {
   222  		return err
   223  	}
   224  
   225  	err = json.Unmarshal(b, &data)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	return nil
   231  }
   232  
   233  func handleResponse(resp *http.Response, data interface{}) error {
   234  	success := resp.StatusCode >= 200 && resp.StatusCode < 300
   235  
   236  	if !success {
   237  		return HandleHTTPError(resp)
   238  	}
   239  
   240  	body, err := ioutil.ReadAll(resp.Body)
   241  	if err != nil {
   242  		return err
   243  	}
   244  
   245  	gr := &graphQLResponse{Data: data}
   246  	err = json.Unmarshal(body, &gr)
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	if len(gr.Errors) > 0 {
   252  		return &GraphQLErrorResponse{Errors: gr.Errors}
   253  	}
   254  	return nil
   255  }
   256  
   257  func HandleHTTPError(resp *http.Response) error {
   258  	httpError := HTTPError{
   259  		StatusCode:  resp.StatusCode,
   260  		RequestURL:  resp.Request.URL,
   261  		OAuthScopes: resp.Header.Get("X-Oauth-Scopes"),
   262  	}
   263  
   264  	if !jsonTypeRE.MatchString(resp.Header.Get("Content-Type")) {
   265  		httpError.Message = resp.Status
   266  		return httpError
   267  	}
   268  
   269  	body, err := ioutil.ReadAll(resp.Body)
   270  	if err != nil {
   271  		httpError.Message = err.Error()
   272  		return httpError
   273  	}
   274  
   275  	var parsedBody struct {
   276  		Message string `json:"message"`
   277  		Errors  []json.RawMessage
   278  	}
   279  	if err := json.Unmarshal(body, &parsedBody); err != nil {
   280  		return httpError
   281  	}
   282  
   283  	messages := []string{parsedBody.Message}
   284  	for _, raw := range parsedBody.Errors {
   285  		switch raw[0] {
   286  		case '"':
   287  			var errString string
   288  			_ = json.Unmarshal(raw, &errString)
   289  			messages = append(messages, errString)
   290  			httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString})
   291  		case '{':
   292  			var errInfo HTTPErrorItem
   293  			_ = json.Unmarshal(raw, &errInfo)
   294  			msg := errInfo.Message
   295  			if errInfo.Code != "custom" {
   296  				msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code))
   297  			}
   298  			if msg != "" {
   299  				messages = append(messages, msg)
   300  			}
   301  			httpError.Errors = append(httpError.Errors, errInfo)
   302  		}
   303  	}
   304  	httpError.Message = strings.Join(messages, "\n")
   305  
   306  	return httpError
   307  }
   308  
   309  func errorCodeToMessage(code string) string {
   310  	// https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors
   311  	switch code {
   312  	case "missing", "missing_field":
   313  		return "is missing"
   314  	case "invalid", "unprocessable":
   315  		return "is invalid"
   316  	case "already_exists":
   317  		return "already exists"
   318  	default:
   319  		return code
   320  	}
   321  }
   322  
   323  var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
   324  
   325  func inspectableMIMEType(t string) bool {
   326  	return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t)
   327  }