github.com/tickstep/library-go@v0.1.1/requester/fetch.go (about)

     1  package requester
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"github.com/tickstep/library-go/requester/rio"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"net/url"
    12  	"strings"
    13  )
    14  
    15  // HttpGet 简单实现 http 访问 GET 请求
    16  func HttpGet(urlStr string) (body []byte, err error) {
    17  	resp, err := DefaultClient.Get(urlStr)
    18  	if resp != nil {
    19  		defer resp.Body.Close()
    20  	}
    21  	if err != nil {
    22  		return nil, err
    23  	}
    24  	return ioutil.ReadAll(resp.Body)
    25  }
    26  
    27  // HttpPost 简单的HTTP POST方法
    28  func HttpPost(urlStr string, postData interface{}) (body []byte, err error) {
    29  	return Fetch("POST", urlStr, postData, nil)
    30  }
    31  
    32  // Req 参见 *HTTPClient.Req, 使用默认 http 客户端
    33  func Req(method string, urlStr string, post interface{}, header map[string]string) (resp *http.Response, err error) {
    34  	return DefaultClient.Req(method, urlStr, post, header)
    35  }
    36  
    37  // Fetch 参见 *HTTPClient.Fetch, 使用默认 http 客户端
    38  func Fetch(method string, urlStr string, post interface{}, header map[string]string) (body []byte, err error) {
    39  	return DefaultClient.Fetch(method, urlStr, post, header)
    40  }
    41  
    42  // Req 实现 http/https 访问,
    43  // 根据给定的 method (GET, POST, HEAD, PUT 等等), urlStr (网址),
    44  // post (post 数据), header (header 请求头数据), 进行网站访问。
    45  // 返回值分别为 *http.Response, 错误信息
    46  func (h *HTTPClient) Req(method string, urlStr string, post interface{}, header map[string]string) (resp *http.Response, err error) {
    47  	h.lazyInit()
    48  	var (
    49  		req           *http.Request
    50  		obody         io.Reader
    51  		contentLength int64
    52  		contentType   string
    53  	)
    54  
    55  	if post != nil {
    56  		isJson := false
    57  		if header != nil {
    58  			if ct, ok := header["Content-Type"]; ok {
    59  				if strings.Contains(strings.ToLower(ct), "application/json") {
    60  					isJson = true
    61  				}
    62  			}
    63  			if ct, ok := header["content-type"]; ok {
    64  				if strings.Contains(strings.ToLower(ct), "application/json") {
    65  					isJson = true
    66  				}
    67  			}
    68  		}
    69  		if isJson {
    70  			switch value := post.(type) {
    71  			case io.Reader:
    72  				obody = value
    73  			case map[string]string:
    74  				paramJson, _ := json.Marshal(value)
    75  				obody = strings.NewReader(string(paramJson))
    76  			case map[string]interface{}:
    77  				paramJson, _ := json.Marshal(value)
    78  				obody = strings.NewReader(string(paramJson))
    79  			case map[interface{}]interface{}:
    80  				paramJson, _ := json.Marshal(value)
    81  				obody = strings.NewReader(string(paramJson))
    82  			case string:
    83  				obody = strings.NewReader(value)
    84  			case []byte:
    85  				obody = bytes.NewReader(value[:])
    86  			default:
    87  				paramJson, _ := json.Marshal(value)
    88  				obody = strings.NewReader(string(paramJson))
    89  			}
    90  		} else {
    91  			switch value := post.(type) {
    92  			case io.Reader:
    93  				obody = value
    94  			case map[string]string:
    95  				query := url.Values{}
    96  				for k := range value {
    97  					query.Set(k, value[k])
    98  				}
    99  				obody = strings.NewReader(query.Encode())
   100  			case map[string]interface{}:
   101  				query := url.Values{}
   102  				for k := range value {
   103  					query.Set(k, fmt.Sprint(value[k]))
   104  				}
   105  				obody = strings.NewReader(query.Encode())
   106  			case map[interface{}]interface{}:
   107  				query := url.Values{}
   108  				for k := range value {
   109  					query.Set(fmt.Sprint(k), fmt.Sprint(value[k]))
   110  				}
   111  				obody = strings.NewReader(query.Encode())
   112  			case string:
   113  				obody = strings.NewReader(value)
   114  			case []byte:
   115  				obody = bytes.NewReader(value[:])
   116  			default:
   117  				return nil, fmt.Errorf("requester.Req: unknown post type: %s", value)
   118  			}
   119  		}
   120  
   121  		switch value := post.(type) {
   122  		case ContentLengther:
   123  			contentLength = value.ContentLength()
   124  		case rio.Lener:
   125  			contentLength = int64(value.Len())
   126  		case rio.Lener64:
   127  			contentLength = value.Len()
   128  		}
   129  
   130  		switch value := post.(type) {
   131  		case ContentTyper:
   132  			contentType = value.ContentType()
   133  		}
   134  	}
   135  	req, err = http.NewRequest(method, urlStr, obody)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	if req.ContentLength <= 0 && contentLength != 0 {
   141  		req.ContentLength = contentLength
   142  	}
   143  
   144  	// 设置浏览器标识
   145  	req.Header.Set("User-Agent", h.UserAgent)
   146  
   147  	// 设置Content-Type
   148  	if contentType != "" {
   149  		req.Header.Set("Content-Type", contentType)
   150  	}
   151  
   152  	if header != nil {
   153  		// 处理Host
   154  		if host, ok := header["Host"]; ok {
   155  			req.Host = host
   156  		}
   157  
   158  		for key := range header {
   159  			req.Header.Set(key, header[key])
   160  		}
   161  	}
   162  
   163  	return h.Client.Do(req)
   164  }
   165  
   166  // Fetch 实现 http/https 访问,
   167  // 根据给定的 method (GET, POST, HEAD, PUT 等等), urlStr (网址),
   168  // post (post 数据), header (header 请求头数据), 进行网站访问。
   169  // 返回值分别为 网站主体, 错误信息
   170  func (h *HTTPClient) Fetch(method string, urlStr string, post interface{}, header map[string]string) (body []byte, err error) {
   171  	h.lazyInit()
   172  	resp, err := h.Req(method, urlStr, post, header)
   173  	if resp != nil {
   174  		defer resp.Body.Close()
   175  	}
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	return ioutil.ReadAll(resp.Body)
   181  }
   182  
   183  func (h *HTTPClient) DoGet(urlStr string) (body []byte, err error) {
   184  	return h.Fetch("GET", urlStr, nil, nil)
   185  }
   186  
   187  func (h *HTTPClient) DoPost(urlStr string, post interface{}) (body []byte, err error) {
   188  	headers := map[string]string{
   189  		"Content-Type": "application/x-www-form-urlencoded; charset=UTF-8",
   190  	}
   191  	return h.Fetch("POST", urlStr, post, headers)
   192  }