github.com/zooyer/miskit@v1.0.71/zrpc/http.go (about)

     1  package zrpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"mime"
    12  	"net"
    13  	"net/http"
    14  	url2 "net/url"
    15  	"reflect"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/zooyer/miskit/log"
    20  	"github.com/zooyer/miskit/metric"
    21  	"github.com/zooyer/miskit/trace"
    22  )
    23  
    24  // HTTP客户端
    25  type Client struct {
    26  	name    string
    27  	retry   int
    28  	timeout time.Duration
    29  	logger  *log.Logger
    30  	client  http.Client
    31  	option  []Option
    32  }
    33  
    34  // HTTP请求
    35  type Request http.Request
    36  
    37  type Response struct {
    38  	Errno   int             `json:"errno"`
    39  	Message string          `json:"message"`
    40  	Data    json.RawMessage `json:"data,omitempty"`
    41  }
    42  
    43  // HTTP请求选项
    44  type Option func(ctx context.Context, req *Request)
    45  
    46  // 请求表单(适用于GET参数、POST表单等)
    47  func NewForm(v interface{}) url2.Values {
    48  	var values = make(url2.Values)
    49  	switch val := v.(type) {
    50  	case url2.Values:
    51  		return val
    52  	case map[string]interface{}:
    53  		for k, v := range val {
    54  			values.Set(k, fmt.Sprint(v))
    55  		}
    56  		return values
    57  	case map[string]string:
    58  		for k, v := range val {
    59  			values.Set(k, v)
    60  		}
    61  		return values
    62  	}
    63  
    64  	val := reflect.ValueOf(v)
    65  	for val.Kind() == reflect.Ptr {
    66  		val = val.Elem()
    67  	}
    68  
    69  	switch val.Kind() {
    70  	case reflect.Map:
    71  		for it := val.MapRange(); it.Next(); {
    72  			key := fmt.Sprint(it.Key().Interface())
    73  			val := fmt.Sprint(it.Value().Interface())
    74  			values.Set(key, val)
    75  		}
    76  		return values
    77  	case reflect.Struct:
    78  		for i := 0; i < val.NumField(); i++ {
    79  			var field = val.Field(i)
    80  			if tag := val.Type().Field(i).Tag; tag.Get("binding") == "required" || !field.IsZero() {
    81  				var name = tag.Get("json")
    82  				if name == "" || name == "-" {
    83  					name = val.Type().Field(i).Name
    84  				}
    85  				values.Set(name, fmt.Sprint(field.Interface()))
    86  			}
    87  		}
    88  	}
    89  
    90  	return values
    91  }
    92  
    93  // 创建HTTP客户端
    94  func New(name string, retry int, timeout time.Duration, logger *log.Logger, opts ...Option) *Client {
    95  	var connTimeout = timeout / 5
    96  	if timeout != 0 {
    97  		if connTimeout.Milliseconds() < 5 {
    98  			connTimeout = 5 * time.Millisecond
    99  		}
   100  		timeout -= connTimeout
   101  
   102  		if timeout <= 0 {
   103  			panic("timeout must be greater than 5ms")
   104  		}
   105  	}
   106  
   107  	return &Client{
   108  		name:    name,
   109  		retry:   retry,
   110  		timeout: timeout,
   111  		logger:  logger,
   112  		option:  opts,
   113  		client: http.Client{
   114  			Transport: &http.Transport{
   115  				DialContext: (&net.Dialer{
   116  					Timeout:       connTimeout,
   117  					KeepAlive:     5 * time.Second,
   118  					FallbackDelay: 0,
   119  				}).DialContext,
   120  				MaxIdleConns:    50,
   121  				IdleConnTimeout: 5 * time.Second,
   122  			},
   123  			Timeout: timeout,
   124  		},
   125  	}
   126  }
   127  
   128  // GET请求
   129  func (c *Client) Get(ctx context.Context, url string, params url2.Values, response interface{}, opts ...Option) (data []byte, code int, err error) {
   130  	return c.do(ctx, "GET", "", url, params, response, opts...)
   131  }
   132  
   133  // POST请求
   134  func (c *Client) post(ctx context.Context, url, contentType string, request, response interface{}, opts ...Option) (data []byte, code int, err error) {
   135  	return c.do(ctx, "POST", contentType, url, request, response, opts...)
   136  }
   137  
   138  // POST表单请求
   139  func (c *Client) PostForm(ctx context.Context, url string, values url2.Values, response interface{}, opts ...Option) (data []byte, code int, err error) {
   140  	return c.post(ctx, url, "application/x-www-form-urlencoded", values, response, opts...)
   141  }
   142  
   143  // POST JSON请求
   144  func (c *Client) PostJSON(ctx context.Context, url string, request, response interface{}, opts ...Option) (data []byte, code int, err error) {
   145  	return c.post(ctx, url, "application/json", request, response, opts...)
   146  }
   147  
   148  // 创建HTTP请求
   149  func (c *Client) newRequest(ctx context.Context, method, contentType, url string, request interface{}, opts ...Option) (req *http.Request, err error) {
   150  	var body io.Reader
   151  
   152  	if method != "GET" {
   153  		content, _, _ := mime.ParseMediaType(contentType)
   154  		switch content {
   155  		case "application/json":
   156  			data, err := json.Marshal(request)
   157  			if err != nil {
   158  				return nil, err
   159  			}
   160  			body = bytes.NewReader(data)
   161  		case "application/x-www-form-urlencoded":
   162  			switch form := request.(type) {
   163  			case url2.Values:
   164  				body = strings.NewReader(form.Encode())
   165  			case map[string]string:
   166  				var values = make(url2.Values)
   167  				for key, val := range form {
   168  					values.Set(key, val)
   169  				}
   170  				body = strings.NewReader(values.Encode())
   171  			case map[string]interface{}:
   172  				var values = make(url2.Values)
   173  				for key, val := range form {
   174  					values.Set(key, fmt.Sprint(val))
   175  				}
   176  				body = strings.NewReader(values.Encode())
   177  			default:
   178  				return nil, errors.New("not support content type and request type:" + contentType)
   179  			}
   180  		default:
   181  			return nil, errors.New("not support content type:" + contentType)
   182  		}
   183  	}
   184  
   185  	if req, err = http.NewRequestWithContext(ctx, method, url, body); err != nil {
   186  		return
   187  	}
   188  
   189  	req.Header.Set("Content-Type", contentType)
   190  
   191  	for _, opt := range append(c.option, opts...) {
   192  		if opt != nil {
   193  			opt(ctx, (*Request)(req))
   194  		}
   195  	}
   196  
   197  	if method == "GET" && request != nil {
   198  		var query = req.URL.Query()
   199  		switch form := request.(type) {
   200  		case url2.Values:
   201  			for key, values := range form {
   202  				for _, val := range values {
   203  					query.Add(key, val)
   204  				}
   205  			}
   206  		case map[string]string:
   207  			for key, val := range form {
   208  				query.Add(key, val)
   209  			}
   210  		case map[string]interface{}:
   211  			for key, val := range form {
   212  				query.Add(key, fmt.Sprint(val))
   213  			}
   214  		default:
   215  			return nil, errors.New("not support request type")
   216  		}
   217  
   218  		req.URL.RawQuery = query.Encode()
   219  	}
   220  
   221  	return
   222  }
   223  
   224  func (c *Client) marshalJSON(v interface{}) string {
   225  	data, _ := json.Marshal(v)
   226  	return string(data)
   227  }
   228  
   229  // 执行HTTP请求
   230  func (c *Client) do(ctx context.Context, method, contentType, url string, request, response interface{}, opts ...Option) (body []byte, code int, err error) {
   231  	var (
   232  		start = time.Now()
   233  		retry int
   234  		req   *http.Request
   235  		resp  *http.Response
   236  		child = trace.Get(ctx).GenChild()
   237  	)
   238  
   239  	defer func() {
   240  		var (
   241  			code    = code
   242  			caller  = ""
   243  			callee  = url
   244  			latency = time.Since(start)
   245  		)
   246  
   247  		if req != nil {
   248  			callee = req.URL.Path
   249  		} else {
   250  			callee = strings.TrimPrefix(callee, "http://")
   251  			callee = strings.TrimPrefix(callee, "https://")
   252  			if index := strings.Index(callee, "/"); index >= 0 {
   253  				callee = callee[index:]
   254  			}
   255  		}
   256  
   257  		if err != nil && (code == 0 || code == http.StatusOK) {
   258  			code = 599
   259  		}
   260  
   261  		if child != nil && child.Request != nil {
   262  			caller = child.Request.URL.Path
   263  		}
   264  
   265  		metric.Rpc("zrpc", caller, callee, code, latency, map[string]interface{}{
   266  			"name": c.name,
   267  		})
   268  
   269  		if c.logger != nil {
   270  			output := c.logger.Info
   271  			if err != nil {
   272  				output = c.logger.Error
   273  			}
   274  
   275  			c.logger.Tag(
   276  				false,
   277  				"rpc", "http",
   278  				"name", c.name,
   279  				"method", method,
   280  				"latency", latency,
   281  				"retry", retry,
   282  			)
   283  			if child != nil {
   284  				c.logger.Tag(false, "cspan_id", child.SpanID)
   285  			}
   286  			if req != nil {
   287  				c.logger.Tag(false, "url", req.URL.String())
   288  			} else {
   289  				c.logger.Tag(false, "url", url)
   290  			}
   291  			if request != nil {
   292  				c.logger.Tag(false, "req", c.marshalJSON(request))
   293  			}
   294  			if bytes.ContainsAny(body, "\t\r\n") {
   295  				if data, err := json.Marshal(json.RawMessage(body)); err == nil {
   296  					body = data
   297  				} else {
   298  					body = bytes.ReplaceAll(body, []byte("\r"), nil)
   299  					body = bytes.ReplaceAll(body, []byte("\n"), nil)
   300  				}
   301  			}
   302  			if len(body) > 0 {
   303  				c.logger.Tag(false, "resp", string(body))
   304  			}
   305  			if err != nil {
   306  				c.logger.Tag(false, "error", err.Error())
   307  			}
   308  
   309  			output(ctx)
   310  		}
   311  	}()
   312  
   313  	// 创建请求
   314  	if req, err = c.newRequest(ctx, method, contentType, url, request, opts...); err != nil {
   315  		return
   316  	}
   317  
   318  	// 设置trace
   319  	if child != nil {
   320  		child.SetHeader(req.Header)
   321  	}
   322  
   323  	// 请求重试
   324  	for i := 0; i < c.retry+1; i++ {
   325  		if resp, err = c.client.Do(req); err == nil {
   326  			break
   327  		}
   328  		retry++
   329  	}
   330  	if err != nil {
   331  		return
   332  	}
   333  	defer resp.Body.Close()
   334  
   335  	// 读取响应
   336  	if body, err = ioutil.ReadAll(resp.Body); err != nil {
   337  		return
   338  	}
   339  
   340  	// 断言HTTP响应
   341  	if code = resp.StatusCode; code != http.StatusOK {
   342  		var res Response
   343  		if err = json.Unmarshal(body, &res); err == nil && res.Errno != 0 && res.Message != "" {
   344  			if res.Message != "" {
   345  				return body, code, errors.New(res.Message)
   346  			}
   347  			return body, code, fmt.Errorf("%s: http resonse code:%d, errno:%d, message:%s", c.name, code, res.Errno, res.Message)
   348  		}
   349  		return body, code, fmt.Errorf("%s: http response code:%d, status:%s", c.name, resp.StatusCode, resp.Status)
   350  	}
   351  
   352  	// 解析业务层响应
   353  	if response != nil {
   354  		// 解析body
   355  		var res Response
   356  		if err = json.Unmarshal(body, &res); err != nil {
   357  			return body, code, err
   358  		}
   359  
   360  		// 断言业务层errno
   361  		if errno := res.Errno; errno != 0 {
   362  			return body, code, fmt.Errorf("%s: http resonse code:%d, errno:%d, message:%s", c.name, code, res.Errno, res.Message)
   363  		}
   364  
   365  		if err = json.Unmarshal(res.Data, response); err != nil {
   366  			return body, code, err
   367  		}
   368  	}
   369  
   370  	return body, code, nil
   371  }