github.com/sereiner/library@v0.0.0-20200518095232-1fa3e640cc5f/net/http/http.go (about)

     1  package http
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"mime/multipart"
    12  	"net"
    13  	"net/http"
    14  	"net/url"
    15  	"os"
    16  	"strings"
    17  
    18  	"time"
    19  
    20  	"github.com/sereiner/library/encoding"
    21  	"github.com/sereiner/library/envs"
    22  )
    23  
    24  type OptionConf struct {
    25  	ConnectionTimeout time.Duration
    26  	RequestTimeout    time.Duration
    27  	certFiles         []string
    28  	cafile            string
    29  	proxy             string
    30  	keepalive         bool
    31  }
    32  
    33  //Option 配置选项
    34  type Option func(*OptionConf)
    35  
    36  //WithConnTimeout 设置请求超时时长
    37  func WithConnTimeout(tm time.Duration) Option {
    38  	return func(o *OptionConf) {
    39  		o.ConnectionTimeout = tm
    40  	}
    41  }
    42  
    43  //WithRequestTimeout 设置请求超时时长
    44  func WithRequestTimeout(tm time.Duration) Option {
    45  	return func(o *OptionConf) {
    46  		o.RequestTimeout = tm
    47  	}
    48  }
    49  
    50  //WithCert 设置请求证书
    51  func WithCert(cerfile string, key string) Option {
    52  	return func(o *OptionConf) {
    53  		o.certFiles = []string{cerfile, key}
    54  	}
    55  }
    56  
    57  //WithCa 设置ca证书
    58  func WithCa(cafile string) Option {
    59  	return func(o *OptionConf) {
    60  		o.cafile = cafile
    61  	}
    62  }
    63  
    64  //WithProxy 使用代理地址
    65  func WithProxy(proxy string) Option {
    66  	return func(o *OptionConf) {
    67  		o.proxy = proxy
    68  	}
    69  }
    70  
    71  //WithKeepalive 设置keep alive
    72  func WithKeepalive(keepalive bool) Option {
    73  	return func(o *OptionConf) {
    74  		o.keepalive = keepalive
    75  	}
    76  }
    77  
    78  //HTTPClient HTTP客户端
    79  type HTTPClient struct {
    80  	*OptionConf
    81  	client   *http.Client
    82  	Response *http.Response
    83  }
    84  
    85  //HTTPClientRequest  http请求
    86  type HTTPClientRequest struct {
    87  	headers  map[string]string
    88  	client   *http.Client
    89  	method   string
    90  	url      string
    91  	params   string
    92  	encoding string
    93  }
    94  
    95  // NewHTTPClient 构建HTTP客户端,用于发送GET POST等请求
    96  func NewHTTPClient(opts ...Option) (client *HTTPClient, err error) {
    97  	client = &HTTPClient{}
    98  	client.OptionConf = &OptionConf{
    99  		ConnectionTimeout: time.Second * time.Duration(envs.GetInt("hydra_http_conn_timeout", 3)),
   100  		RequestTimeout:    time.Second * time.Duration(envs.GetInt("hydra_http_req_timeout", 10))}
   101  	for _, opt := range opts {
   102  		opt(client.OptionConf)
   103  	}
   104  	tlsConf, err := getCert(client.OptionConf)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	client.client = &http.Client{
   109  		Transport: &http.Transport{
   110  			DisableKeepAlives: client.OptionConf.keepalive,
   111  			TLSClientConfig:   tlsConf,
   112  			Proxy:             getProxy(client.OptionConf),
   113  			Dial: func(netw, addr string) (net.Conn, error) {
   114  				c, err := net.DialTimeout(netw, addr, client.OptionConf.ConnectionTimeout)
   115  				if err != nil {
   116  					return nil, err
   117  				}
   118  				c.SetDeadline(time.Now().Add(client.OptionConf.RequestTimeout))
   119  				return c, nil
   120  			},
   121  			MaxIdleConnsPerHost:   0,
   122  			ResponseHeaderTimeout: 0,
   123  		},
   124  	}
   125  	return
   126  }
   127  
   128  func getCert(c *OptionConf) (*tls.Config, error) {
   129  	ssl := &tls.Config{}
   130  	if len(c.certFiles) == 2 {
   131  		cert, err := tls.LoadX509KeyPair(c.certFiles[0], c.certFiles[1])
   132  		if err != nil {
   133  			return nil, fmt.Errorf("cert证书(pem:%s,key:%s),加载失败:%v", c.certFiles[0], c.certFiles[1], err)
   134  		}
   135  		ssl.Certificates = []tls.Certificate{cert}
   136  	}
   137  	if c.cafile != "" {
   138  		caData, err := ioutil.ReadFile(c.cafile)
   139  		if err != nil {
   140  			return nil, fmt.Errorf("ca证书(%s)读取错误:%v", c.cafile, err)
   141  		}
   142  		pool := x509.NewCertPool()
   143  		pool.AppendCertsFromPEM(caData)
   144  		ssl.RootCAs = pool
   145  	}
   146  	if len(ssl.Certificates) == 0 && ssl.RootCAs == nil {
   147  		return nil, nil
   148  	}
   149  	ssl.Rand = rand.Reader
   150  	return ssl, nil
   151  
   152  }
   153  func getProxy(c *OptionConf) func(*http.Request) (*url.URL, error) {
   154  	if c.proxy != "" {
   155  		return func(_ *http.Request) (*url.URL, error) {
   156  			return url.Parse(c.proxy) //根据定义Proxy func(*Request) (*url.URL, error)这里要返回url.URL
   157  		}
   158  	}
   159  	return nil
   160  }
   161  
   162  // Download 发送http请求, method:http请求方法包括:get,post,delete,put等 url: 请求的HTTP地址,不包括参数,params:请求参数,
   163  // header,http请求头多个用/n分隔,每个键值之前用=号连接
   164  func (c *HTTPClient) Download(method string, url string, params string, header map[string]string) (body []byte, status int, err error) {
   165  	req, err := http.NewRequest(strings.ToUpper(method), url, strings.NewReader(params))
   166  	if err != nil {
   167  		return
   168  	}
   169  	req.Close = true
   170  	for i, v := range header {
   171  		req.Header.Set(i, v)
   172  	}
   173  	c.Response, err = c.client.Do(req)
   174  	if c.Response != nil {
   175  		defer c.Response.Body.Close()
   176  	}
   177  	if err != nil {
   178  		return
   179  	}
   180  	status = c.Response.StatusCode
   181  	body, err = ioutil.ReadAll(c.Response.Body)
   182  	return
   183  }
   184  
   185  // Save 发送http请求, method:http请求方法包括:get,post,delete,put等 url: 请求的HTTP地址,不包括参数,params:请求参数,
   186  // header,http请求头多个用/n分隔,每个键值之前用=号连接
   187  func (c *HTTPClient) Save(method string, url string, params string, header map[string]string, path string) (status int, err error) {
   188  	body, status, err := c.Download(method, url, params, header)
   189  	if err != nil {
   190  		return
   191  	}
   192  	fl, err := os.OpenFile(path, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0664)
   193  	if err != nil {
   194  		return
   195  	}
   196  	defer fl.Close()
   197  	n, err := fl.Write(body)
   198  	if err == nil && n < len(body) {
   199  		err = io.ErrShortWrite
   200  	}
   201  	return
   202  }
   203  
   204  // Request 发送http请求, method:http请求方法包括:get,post,delete,put等 url: 请求的HTTP地址,不包括参数,params:请求参数,
   205  // header,http请求头多个用/n分隔,每个键值之前用=号连接
   206  func (c *HTTPClient) Request(method string, url string, params string, charset string, header map[string]string) (content string, status int, err error) {
   207  	req, err := http.NewRequest(strings.ToUpper(method), url, strings.NewReader(params))
   208  	if err != nil {
   209  		return
   210  	}
   211  	req.Close = true
   212  	for i, v := range header {
   213  		req.Header.Set(i, v)
   214  	}
   215  	c.Response, err = c.client.Do(req)
   216  	if c.Response != nil {
   217  		defer c.Response.Body.Close()
   218  	}
   219  	if err != nil {
   220  		return
   221  	}
   222  	body, err := ioutil.ReadAll(c.Response.Body)
   223  	if err != nil {
   224  		return
   225  	}
   226  	status = c.Response.StatusCode
   227  	ct, err := encoding.DecodeBytes(body, charset)
   228  	content = string(ct)
   229  	return
   230  }
   231  
   232  // Get http get请求
   233  func (c *HTTPClient) Get(url string, args ...string) (content string, status int, err error) {
   234  	charset := getEncoding(args...)
   235  	c.Response, err = c.client.Get(url)
   236  	if c.Response != nil {
   237  		defer c.Response.Body.Close()
   238  	}
   239  	if err != nil {
   240  		return
   241  	}
   242  	body, err := ioutil.ReadAll(c.Response.Body)
   243  	if err != nil {
   244  		return
   245  	}
   246  	status = c.Response.StatusCode
   247  	ct, err := encoding.DecodeBytes(body, charset)
   248  	content = string(ct)
   249  	return
   250  }
   251  
   252  // Post http Post请求
   253  func (c *HTTPClient) Post(url string, params string, args ...string) (content string, status int, err error) {
   254  	charset := getEncoding(args...)
   255  	c.Response, err = c.client.Post(url, fmt.Sprintf("application/x-www-form-urlencoded;charset=%s", charset), encoding.GetEncodeReader([]byte(params), charset))
   256  	if c.Response != nil {
   257  		defer c.Response.Body.Close()
   258  	}
   259  	if err != nil {
   260  		return
   261  	}
   262  	body, err := ioutil.ReadAll(c.Response.Body)
   263  	if err != nil {
   264  		return
   265  	}
   266  	status = c.Response.StatusCode
   267  	rcontent, err := encoding.DecodeBytes(body, charset)
   268  	content = string(rcontent)
   269  	return
   270  }
   271  
   272  //Upload 文件上传
   273  func (c *HTTPClient) Upload(url string, params map[string]string, files map[string]string, args ...string) (content string, status int, err error) {
   274  	charset := getEncoding(args...)
   275  	bodyBuffer := &bytes.Buffer{}
   276  	bodyWriter := multipart.NewWriter(bodyBuffer)
   277  
   278  	//字段处理
   279  	for k, v := range params {
   280  		err = bodyWriter.WriteField(k, v)
   281  		if err != nil {
   282  			return "", 0, fmt.Errorf("设置字段失败:%s(%v)", k, v)
   283  		}
   284  	}
   285  
   286  	//文件流处理
   287  	for k, v := range files {
   288  		fw1, err := bodyWriter.CreateFormFile(k, v)
   289  		if err != nil {
   290  			return "", 0, fmt.Errorf("无法创建文件流:%v", v)
   291  		}
   292  		f1, err := os.Open(v)
   293  		if err != nil {
   294  			return "", 0, fmt.Errorf("无法读取文件:%s", v)
   295  		}
   296  		defer f1.Close()
   297  		io.Copy(fw1, f1)
   298  	}
   299  
   300  	contentType := bodyWriter.FormDataContentType()
   301  	bodyWriter.Close()
   302  
   303  	//发送POST请求
   304  	c.Response, err = c.client.Post(url, contentType, encoding.GetEncodeReader(bodyBuffer.Bytes(), charset))
   305  	if err != nil {
   306  		return
   307  	}
   308  	defer c.Response.Body.Close()
   309  
   310  	//处理响应包
   311  	body, err := ioutil.ReadAll(c.Response.Body)
   312  	if err != nil {
   313  		return
   314  	}
   315  	status = c.Response.StatusCode
   316  	rcontent, err := encoding.DecodeBytes(body, charset)
   317  	content = string(rcontent)
   318  	return
   319  }
   320  
   321  func getEncoding(params ...string) (encoding string) {
   322  	if len(params) > 0 {
   323  		encoding = strings.ToUpper(params[0])
   324  		return
   325  	}
   326  	return "UTF-8"
   327  }