github.com/lingyao2333/mo-zero@v1.4.1/rest/httpc/requests.go (about)

     1  package httpc
     2  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptrace"
	nurl "net/url"
	"sort"
	"strings"

	"github.com/lingyao2333/mo-zero/core/lang"
	"github.com/lingyao2333/mo-zero/core/mapping"
	"github.com/lingyao2333/mo-zero/core/trace"
	"github.com/lingyao2333/mo-zero/rest/httpc/internal"
	"github.com/lingyao2333/mo-zero/rest/internal/header"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/propagation"
	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
	oteltrace "go.opentelemetry.io/otel/trace"
)
    25  
    26  var interceptors = []internal.Interceptor{
    27  	internal.LogInterceptor,
    28  }
    29  
    30  // Do sends an HTTP request with the given arguments and returns an HTTP response.
    31  // data is automatically marshal into a *httpRequest, typically it's defined in an API file.
    32  func Do(ctx context.Context, method, url string, data interface{}) (*http.Response, error) {
    33  	req, err := buildRequest(ctx, method, url, data)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	return DoRequest(req)
    39  }
    40  
    41  // DoRequest sends an HTTP request and returns an HTTP response.
    42  func DoRequest(r *http.Request) (*http.Response, error) {
    43  	return request(r, defaultClient{})
    44  }
    45  
    46  type (
    47  	client interface {
    48  		do(r *http.Request) (*http.Response, error)
    49  	}
    50  
    51  	defaultClient struct{}
    52  )
    53  
    54  func (c defaultClient) do(r *http.Request) (*http.Response, error) {
    55  	return http.DefaultClient.Do(r)
    56  }
    57  
    58  func buildFormQuery(u *nurl.URL, val map[string]interface{}) string {
    59  	query := u.Query()
    60  	for k, v := range val {
    61  		query.Add(k, fmt.Sprint(v))
    62  	}
    63  
    64  	return query.Encode()
    65  }
    66  
    67  func buildRequest(ctx context.Context, method, url string, data interface{}) (*http.Request, error) {
    68  	u, err := nurl.Parse(url)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	var val map[string]map[string]interface{}
    74  	if data != nil {
    75  		val, err = mapping.Marshal(data)
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  	}
    80  
    81  	if err := fillPath(u, val[pathKey]); err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	var reader io.Reader
    86  	jsonVars, hasJsonBody := val[jsonKey]
    87  	if hasJsonBody {
    88  		if method == http.MethodGet {
    89  			return nil, ErrGetWithBody
    90  		}
    91  
    92  		var buf bytes.Buffer
    93  		enc := json.NewEncoder(&buf)
    94  		if err := enc.Encode(jsonVars); err != nil {
    95  			return nil, err
    96  		}
    97  
    98  		reader = &buf
    99  	}
   100  
   101  	req, err := http.NewRequestWithContext(ctx, method, u.String(), reader)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	req.URL.RawQuery = buildFormQuery(u, val[formKey])
   107  	fillHeader(req, val[headerKey])
   108  	if hasJsonBody {
   109  		req.Header.Set(header.ContentType, header.JsonContentType)
   110  	}
   111  
   112  	return req, nil
   113  }
   114  
   115  func fillHeader(r *http.Request, val map[string]interface{}) {
   116  	for k, v := range val {
   117  		r.Header.Add(k, fmt.Sprint(v))
   118  	}
   119  }
   120  
   121  func fillPath(u *nurl.URL, val map[string]interface{}) error {
   122  	used := make(map[string]lang.PlaceholderType)
   123  	fields := strings.Split(u.Path, slash)
   124  
   125  	for i := range fields {
   126  		field := fields[i]
   127  		if len(field) > 0 && field[0] == colon {
   128  			name := field[1:]
   129  			ival, ok := val[name]
   130  			if !ok {
   131  				return fmt.Errorf("missing path variable %q", name)
   132  			}
   133  			value := fmt.Sprint(ival)
   134  			if len(value) == 0 {
   135  				return fmt.Errorf("empty path variable %q", name)
   136  			}
   137  			fields[i] = value
   138  			used[name] = lang.Placeholder
   139  		}
   140  	}
   141  
   142  	if len(val) != len(used) {
   143  		for key := range used {
   144  			delete(val, key)
   145  		}
   146  
   147  		var unused []string
   148  		for key := range val {
   149  			unused = append(unused, key)
   150  		}
   151  
   152  		return fmt.Errorf("more path variables are provided: %q", strings.Join(unused, ", "))
   153  	}
   154  
   155  	u.Path = strings.Join(fields, slash)
   156  	return nil
   157  }
   158  
   159  func request(r *http.Request, cli client) (*http.Response, error) {
   160  	tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
   161  	propagator := otel.GetTextMapPropagator()
   162  
   163  	spanName := r.URL.Path
   164  	ctx, span := tracer.Start(
   165  		r.Context(),
   166  		spanName,
   167  		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
   168  		oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),
   169  	)
   170  	defer span.End()
   171  
   172  	respHandlers := make([]internal.ResponseHandler, len(interceptors))
   173  	for i, interceptor := range interceptors {
   174  		var h internal.ResponseHandler
   175  		r, h = interceptor(r)
   176  		respHandlers[i] = h
   177  	}
   178  
   179  	clientTrace := httptrace.ContextClientTrace(ctx)
   180  	if clientTrace != nil {
   181  		ctx = httptrace.WithClientTrace(ctx, clientTrace)
   182  	}
   183  
   184  	r = r.WithContext(ctx)
   185  	propagator.Inject(ctx, propagation.HeaderCarrier(r.Header))
   186  
   187  	resp, err := cli.do(r)
   188  	for i := len(respHandlers) - 1; i >= 0; i-- {
   189  		respHandlers[i](resp, err)
   190  	}
   191  
   192  	if err != nil {
   193  		span.RecordError(err)
   194  		span.SetStatus(codes.Error, err.Error())
   195  		return resp, err
   196  	}
   197  
   198  	span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)
   199  	span.SetStatus(semconv.SpanStatusFromHTTPStatusCode(resp.StatusCode))
   200  
   201  	return resp, err
   202  }