go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/auth/internal/fetch.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package internal
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  
    25  	"golang.org/x/net/context/ctxhttp"
    26  
    27  	"go.chromium.org/luci/common/logging"
    28  	"go.chromium.org/luci/common/retry/transient"
    29  )
    30  
    31  // ClientFactory knows how to produce http.Client that attach proper OAuth
    32  // headers.
    33  //
    34  // If 'scopes' is empty, the factory should return a client that makes anonymous
    35  // requests.
    36  type ClientFactory func(ctx context.Context, scopes []string) (*http.Client, error)
    37  
    38  var clientFactory ClientFactory
    39  
    40  // RegisterClientFactory allows external module to provide implementation of
    41  // the ClientFactory.
    42  //
    43  // This is needed to resolve module dependency cycle between server/auth and
    44  // server/auth/internal.
    45  //
    46  // See init() in server/auth/client.go.
    47  //
    48  // If client factory is not set, Do(...) uses http.DefaultClient. This happens
    49  // in unit tests for various auth/* subpackages.
    50  func RegisterClientFactory(f ClientFactory) {
    51  	if clientFactory != nil {
    52  		panic("ClientFactory is already registered")
    53  	}
    54  	clientFactory = f
    55  }
    56  
    57  // Request represents one JSON REST API request.
    58  type Request struct {
    59  	Method  string            // HTTP method to use
    60  	URL     string            // URL to access
    61  	Scopes  []string          // OAuth2 scopes to authenticate with or anonymous call if empty
    62  	Headers map[string]string // optional map with request headers
    63  	Body    any               // object to convert to JSON and send as body or []byte with the body
    64  	Out     any               // where to deserialize the response to
    65  }
    66  
    67  // Do performs an HTTP request with retries on transient errors.
    68  //
    69  // It can be used to make GET or DELETE requests (if Body is nil) or POST or PUT
    70  // requests (if Body is not nil). In latter case the body will be serialized to
    71  // JSON.
    72  //
    73  // Respects context's deadline and cancellation.
    74  func (r *Request) Do(ctx context.Context) error {
    75  	// Grab a client first. Use same client for all retry attempts.
    76  	var client *http.Client
    77  	if clientFactory != nil {
    78  		var err error
    79  		if client, err = clientFactory(ctx, r.Scopes); err != nil {
    80  			return err
    81  		}
    82  	} else {
    83  		client = http.DefaultClient
    84  		if testTransport := ctx.Value(&testTransportKey); testTransport != nil {
    85  			client = &http.Client{Transport: testTransport.(http.RoundTripper)}
    86  		}
    87  	}
    88  
    89  	// Prepare a blob with the request body. Marshal it once, to avoid
    90  	// remarshaling on retries.
    91  	isJSON := false
    92  	var bodyBlob []byte
    93  	if blob, ok := r.Body.([]byte); ok {
    94  		bodyBlob = blob
    95  	} else if r.Body != nil {
    96  		var err error
    97  		if bodyBlob, err = json.Marshal(r.Body); err != nil {
    98  			return err
    99  		}
   100  		isJSON = true
   101  	}
   102  
   103  	return fetchJSON(ctx, client, r.Out, func() (*http.Request, error) {
   104  		req, err := http.NewRequest(r.Method, r.URL, bytes.NewReader(bodyBlob))
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  		for k, v := range r.Headers {
   109  			req.Header.Set(k, v)
   110  		}
   111  		if isJSON {
   112  			req.Header.Set("Content-Type", "application/json")
   113  		}
   114  		return req, nil
   115  	})
   116  }
   117  
   118  // TODO(vadimsh): Add retries on HTTP 500.
   119  
   120  // fetchJSON fetches JSON document by making a request using given client.
   121  func fetchJSON(ctx context.Context, client *http.Client, val any, f func() (*http.Request, error)) error {
   122  	r, err := f()
   123  	if err != nil {
   124  		logging.Errorf(ctx, "auth: URL fetch failed - %s", err)
   125  		return err
   126  	}
   127  	logging.Infof(ctx, "auth: %s %s", r.Method, r.URL)
   128  	resp, err := ctxhttp.Do(ctx, client, r)
   129  	if err != nil {
   130  		logging.Errorf(ctx, "auth: URL fetch failed, can't connect - %s", err)
   131  		return transient.Tag.Apply(err)
   132  	}
   133  	defer func() {
   134  		io.ReadAll(resp.Body)
   135  		resp.Body.Close()
   136  	}()
   137  	if resp.StatusCode >= 300 {
   138  		body, _ := io.ReadAll(resp.Body)
   139  		// Opportunistically try to unmarshal the response. Works with JSON APIs.
   140  		if val != nil {
   141  			json.Unmarshal(body, val)
   142  		}
   143  		logging.Errorf(ctx, "auth: URL fetch failed - HTTP %d - %s", resp.StatusCode, string(body))
   144  		err := fmt.Errorf("auth: HTTP code (%d) when fetching %s", resp.StatusCode, r.URL)
   145  		if resp.StatusCode >= 500 {
   146  			return transient.Tag.Apply(err)
   147  		}
   148  		return err
   149  	}
   150  	if val != nil {
   151  		if err = json.NewDecoder(resp.Body).Decode(val); err != nil {
   152  			logging.Errorf(ctx, "auth: URL fetch failed, bad JSON - %s", err)
   153  			return fmt.Errorf("auth: can't deserialize JSON at %q - %s", r.URL, err)
   154  		}
   155  	}
   156  	return nil
   157  }