github.com/abdfnx/gh-api@v0.0.0-20210414084727-f5432eec23b8/api/client.go (about) 1 package api 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "net/url" 11 "regexp" 12 "strings" 13 14 "github.com/abdfnx/gh-api/internal/ghinstance" 15 "github.com/henvic/httpretty" 16 "github.com/shurcooL/graphql" 17 ) 18 19 // ClientOption represents an argument to NewClient 20 type ClientOption = func(http.RoundTripper) http.RoundTripper 21 22 // NewHTTPClient initializes an http.Client 23 func NewHTTPClient(opts ...ClientOption) *http.Client { 24 tr := http.DefaultTransport 25 for _, opt := range opts { 26 tr = opt(tr) 27 } 28 return &http.Client{Transport: tr} 29 } 30 31 // NewClient initializes a Client 32 func NewClient(opts ...ClientOption) *Client { 33 client := &Client{http: NewHTTPClient(opts...)} 34 return client 35 } 36 37 // NewClientFromHTTP takes in an http.Client instance 38 func NewClientFromHTTP(httpClient *http.Client) *Client { 39 client := &Client{http: httpClient} 40 return client 41 } 42 43 // AddHeader turns a RoundTripper into one that adds a request header 44 func AddHeader(name, value string) ClientOption { 45 return func(tr http.RoundTripper) http.RoundTripper { 46 return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { 47 if req.Header.Get(name) == "" { 48 req.Header.Add(name, value) 49 } 50 return tr.RoundTrip(req) 51 }} 52 } 53 } 54 55 // AddHeaderFunc is an AddHeader that gets the string value from a function 56 func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption { 57 return func(tr http.RoundTripper) http.RoundTripper { 58 return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { 59 if req.Header.Get(name) != "" { 60 return tr.RoundTrip(req) 61 } 62 value, err := getValue(req) 63 if err != nil { 64 return nil, err 65 } 66 if value != "" { 67 req.Header.Add(name, value) 68 } 69 return tr.RoundTrip(req) 70 }} 71 } 72 } 73 74 // VerboseLog enables request/response logging within a RoundTripper 75 func VerboseLog(out io.Writer, logTraffic bool, colorize bool) ClientOption { 76 logger := &httpretty.Logger{ 77 Time: true, 78 TLS: false, 79 Colors: colorize, 80 RequestHeader: logTraffic, 81 RequestBody: logTraffic, 82 ResponseHeader: logTraffic, 83 ResponseBody: logTraffic, 84 Formatters: []httpretty.Formatter{&httpretty.JSONFormatter{}}, 85 MaxResponseBody: 10000, 86 } 87 logger.SetOutput(out) 88 logger.SetBodyFilter(func(h http.Header) (skip bool, err error) { 89 return !inspectableMIMEType(h.Get("Content-Type")), nil 90 }) 91 return logger.RoundTripper 92 } 93 94 // ReplaceTripper substitutes the underlying RoundTripper with a custom one 95 func ReplaceTripper(tr http.RoundTripper) ClientOption { 96 return func(http.RoundTripper) http.RoundTripper { 97 return tr 98 } 99 } 100 101 type funcTripper struct { 102 roundTrip func(*http.Request) (*http.Response, error) 103 } 104 105 func (tr funcTripper) RoundTrip(req *http.Request) (*http.Response, error) { 106 return tr.roundTrip(req) 107 } 108 109 // Client facilitates making HTTP requests to the GitHub API 110 type Client struct { 111 http *http.Client 112 } 113 114 func (c *Client) HTTP() *http.Client { 115 return c.http 116 } 117 118 type graphQLResponse struct { 119 Data interface{} 120 Errors []GraphQLError 121 } 122 123 // GraphQLError is a single error returned in a GraphQL response 124 type GraphQLError struct { 125 Type string 126 Path []string 127 Message string 128 } 129 130 // GraphQLErrorResponse contains errors returned in a GraphQL response 131 type GraphQLErrorResponse struct { 132 Errors []GraphQLError 133 } 134 135 func (gr GraphQLErrorResponse) Error() string { 136 errorMessages := make([]string, 0, len(gr.Errors)) 137 for _, e := range gr.Errors { 138 errorMessages = append(errorMessages, e.Message) 139 } 140 return fmt.Sprintf("GraphQL error: %s", strings.Join(errorMessages, "\n")) 141 } 142 143 // HTTPError is an error returned by a failed API call 144 type HTTPError struct { 145 StatusCode int 146 RequestURL *url.URL 147 Message string 148 OAuthScopes string 149 Errors []HTTPErrorItem 150 } 151 152 type HTTPErrorItem struct { 153 Message string 154 Resource string 155 Field string 156 Code string 157 } 158 159 func (err HTTPError) Error() string { 160 if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 { 161 return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1]) 162 } else if err.Message != "" { 163 return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) 164 } 165 return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) 166 } 167 168 // GraphQL performs a GraphQL request and parses the response 169 func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error { 170 reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) 171 if err != nil { 172 return err 173 } 174 175 req, err := http.NewRequest("POST", ghinstance.GraphQLEndpoint(hostname), bytes.NewBuffer(reqBody)) 176 if err != nil { 177 return err 178 } 179 180 req.Header.Set("Content-Type", "application/json; charset=utf-8") 181 182 resp, err := c.http.Do(req) 183 if err != nil { 184 return err 185 } 186 defer resp.Body.Close() 187 188 return handleResponse(resp, data) 189 } 190 191 func graphQLClient(h *http.Client, hostname string) *graphql.Client { 192 return graphql.NewClient(ghinstance.GraphQLEndpoint(hostname), h) 193 } 194 195 // REST performs a REST request and parses the response. 196 func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error { 197 url := ghinstance.RESTPrefix(hostname) + p 198 req, err := http.NewRequest(method, url, body) 199 if err != nil { 200 return err 201 } 202 203 req.Header.Set("Content-Type", "application/json; charset=utf-8") 204 205 resp, err := c.http.Do(req) 206 if err != nil { 207 return err 208 } 209 defer resp.Body.Close() 210 211 success := resp.StatusCode >= 200 && resp.StatusCode < 300 212 if !success { 213 return HandleHTTPError(resp) 214 } 215 216 if resp.StatusCode == http.StatusNoContent { 217 return nil 218 } 219 220 b, err := ioutil.ReadAll(resp.Body) 221 if err != nil { 222 return err 223 } 224 225 err = json.Unmarshal(b, &data) 226 if err != nil { 227 return err 228 } 229 230 return nil 231 } 232 233 func handleResponse(resp *http.Response, data interface{}) error { 234 success := resp.StatusCode >= 200 && resp.StatusCode < 300 235 236 if !success { 237 return HandleHTTPError(resp) 238 } 239 240 body, err := ioutil.ReadAll(resp.Body) 241 if err != nil { 242 return err 243 } 244 245 gr := &graphQLResponse{Data: data} 246 err = json.Unmarshal(body, &gr) 247 if err != nil { 248 return err 249 } 250 251 if len(gr.Errors) > 0 { 252 return &GraphQLErrorResponse{Errors: gr.Errors} 253 } 254 return nil 255 } 256 257 func HandleHTTPError(resp *http.Response) error { 258 httpError := HTTPError{ 259 StatusCode: resp.StatusCode, 260 RequestURL: resp.Request.URL, 261 OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), 262 } 263 264 if !jsonTypeRE.MatchString(resp.Header.Get("Content-Type")) { 265 httpError.Message = resp.Status 266 return httpError 267 } 268 269 body, err := ioutil.ReadAll(resp.Body) 270 if err != nil { 271 httpError.Message = err.Error() 272 return httpError 273 } 274 275 var parsedBody struct { 276 Message string `json:"message"` 277 Errors []json.RawMessage 278 } 279 if err := json.Unmarshal(body, &parsedBody); err != nil { 280 return httpError 281 } 282 283 messages := []string{parsedBody.Message} 284 for _, raw := range parsedBody.Errors { 285 switch raw[0] { 286 case '"': 287 var errString string 288 _ = json.Unmarshal(raw, &errString) 289 messages = append(messages, errString) 290 httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString}) 291 case '{': 292 var errInfo HTTPErrorItem 293 _ = json.Unmarshal(raw, &errInfo) 294 msg := errInfo.Message 295 if errInfo.Code != "custom" { 296 msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) 297 } 298 if msg != "" { 299 messages = append(messages, msg) 300 } 301 httpError.Errors = append(httpError.Errors, errInfo) 302 } 303 } 304 httpError.Message = strings.Join(messages, "\n") 305 306 return httpError 307 } 308 309 func errorCodeToMessage(code string) string { 310 // https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors 311 switch code { 312 case "missing", "missing_field": 313 return "is missing" 314 case "invalid", "unprocessable": 315 return "is invalid" 316 case "already_exists": 317 return "already exists" 318 default: 319 return code 320 } 321 } 322 323 var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) 324 325 func inspectableMIMEType(t string) bool { 326 return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t) 327 }