github.com/Mrs4s/go-cqhttp@v1.2.0/internal/download/download.go (about)

// Package download provides download utility functions.
     2  package download
     3  
     4  import (
     5  	"bufio"
     6  	"compress/gzip"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/RomiChan/syncx"
    19  	"github.com/pkg/errors"
    20  	"github.com/tidwall/gjson"
    21  
    22  	"github.com/Mrs4s/go-cqhttp/internal/base"
    23  )
    24  
    25  var client = newClient(time.Second * 15)
    26  var clients syncx.Map[time.Duration, *http.Client]
    27  
    28  var clienth2 = &http.Client{
    29  	Transport: &http.Transport{
    30  		Proxy: func(request *http.Request) (*url.URL, error) {
    31  			if base.Proxy == "" {
    32  				return http.ProxyFromEnvironment(request)
    33  			}
    34  			return url.Parse(base.Proxy)
    35  		},
    36  		ForceAttemptHTTP2:   true,
    37  		MaxIdleConnsPerHost: 999,
    38  	},
    39  	Timeout: time.Second * 15,
    40  }
    41  
    42  func newClient(t time.Duration) *http.Client {
    43  	return &http.Client{
    44  		Transport: &http.Transport{
    45  			Proxy: func(request *http.Request) (*url.URL, error) {
    46  				if base.Proxy == "" {
    47  					return http.ProxyFromEnvironment(request)
    48  				}
    49  				return url.Parse(base.Proxy)
    50  			},
    51  			// Disable http2
    52  			TLSNextProto:        map[string]func(authority string, c *tls.Conn) http.RoundTripper{},
    53  			MaxIdleConnsPerHost: 999,
    54  		},
    55  		Timeout: t,
    56  	}
    57  }
    58  
// ErrOverSize is returned when the response body is larger than the
// request's size limit.
var ErrOverSize = errors.New("oversize")
    61  
// UserAgent is the User-Agent header value sent with HTTP requests.
const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66"
    64  
    65  // WithTimeout get a download instance with timeout t
    66  func (r Request) WithTimeout(t time.Duration) *Request {
    67  	if c, ok := clients.Load(t); ok {
    68  		r.custcli = c
    69  	} else {
    70  		c := newClient(t)
    71  		clients.Store(t, c)
    72  		r.custcli = c
    73  	}
    74  	return &r
    75  }
    76  
    77  // SetTimeout set internal/download client timeout
    78  func SetTimeout(t time.Duration) {
    79  	if t == 0 {
    80  		t = time.Second * 10
    81  	}
    82  	client.Timeout = t
    83  	clienth2.Timeout = t
    84  }
    85  
// Request is a file download request.
type Request struct {
	Method  string            // HTTP method; an empty value is sent as GET
	URL     string            // target URL
	Header  map[string]string // extra request headers; set after the default User-Agent
	Limit   int64             // maximum accepted Content-Length; values <= 0 disable the check
	Body    io.Reader         // request body handed to http.NewRequest; may be nil
	custcli *http.Client      // client override installed by WithTimeout; nil selects a shared client
}
    95  
    96  func (r Request) client() *http.Client {
    97  	if r.custcli != nil {
    98  		return r.custcli
    99  	}
   100  	if strings.Contains(r.URL, "go-cqhttp.org") {
   101  		return clienth2
   102  	}
   103  	return client
   104  }
   105  
   106  func (r Request) do() (*http.Response, error) {
   107  	if r.Method == "" {
   108  		r.Method = http.MethodGet
   109  	}
   110  	req, err := http.NewRequest(r.Method, r.URL, r.Body)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	req.Header["User-Agent"] = []string{UserAgent}
   116  	for k, v := range r.Header {
   117  		req.Header.Set(k, v)
   118  	}
   119  
   120  	return r.client().Do(req)
   121  }
   122  
   123  func (r Request) body() (io.ReadCloser, error) {
   124  	resp, err := r.do()
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	limit := r.Limit // check file size limit
   130  	if limit > 0 && resp.ContentLength > limit {
   131  		_ = resp.Body.Close()
   132  		return nil, ErrOverSize
   133  	}
   134  
   135  	if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
   136  		return gzipReadCloser(resp.Body)
   137  	}
   138  	return resp.Body, err
   139  }
   140  
   141  // Bytes 对给定URL发送请求,返回响应主体
   142  func (r Request) Bytes() ([]byte, error) {
   143  	rd, err := r.body()
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	defer rd.Close()
   148  	defer r.client().CloseIdleConnections()
   149  	return io.ReadAll(rd)
   150  }
   151  
   152  // JSON 发送请求, 并转换响应为JSON
   153  func (r Request) JSON() (gjson.Result, error) {
   154  	rd, err := r.body()
   155  	if err != nil {
   156  		return gjson.Result{}, err
   157  	}
   158  	defer rd.Close()
   159  	defer r.client().CloseIdleConnections()
   160  
   161  	var sb strings.Builder
   162  	_, err = io.Copy(&sb, rd)
   163  	if err != nil {
   164  		return gjson.Result{}, err
   165  	}
   166  
   167  	return gjson.Parse(sb.String()), nil
   168  }
   169  
   170  func writeToFile(reader io.ReadCloser, path string) error {
   171  	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o644)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	defer func() { _ = file.Close() }()
   176  	_, err = file.ReadFrom(reader)
   177  	return err
   178  }
   179  
   180  // WriteToFile 下载到制定目录
   181  func (r Request) WriteToFile(path string) error {
   182  	rd, err := r.body()
   183  	if err != nil {
   184  		return err
   185  	}
   186  	defer rd.Close()
   187  	defer r.client().CloseIdleConnections()
   188  	return writeToFile(rd, path)
   189  }
   190  
   191  // WriteToFileMultiThreading 多线程下载到制定目录
   192  func (r Request) WriteToFileMultiThreading(path string, thread int) error {
   193  	if thread < 2 {
   194  		return r.WriteToFile(path)
   195  	}
   196  
   197  	defer r.client().CloseIdleConnections()
   198  	limit := r.Limit
   199  	type BlockMetaData struct {
   200  		BeginOffset    int64
   201  		EndOffset      int64
   202  		DownloadedSize int64
   203  	}
   204  	var blocks []*BlockMetaData
   205  	var contentLength int64
   206  	errUnsupportedMultiThreading := errors.New("unsupported multi-threading")
   207  	// 初始化分块或直接下载
   208  	initOrDownload := func() error {
   209  		header := make(map[string]string, len(r.Header))
   210  		for k, v := range r.Header { // copy headers
   211  			header[k] = v
   212  		}
   213  		header["range"] = "bytes=0-"
   214  		req := Request{
   215  			URL:    r.URL,
   216  			Header: header,
   217  		}
   218  		resp, err := req.do()
   219  		if err != nil {
   220  			return err
   221  		}
   222  		defer resp.Body.Close()
   223  		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   224  			return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
   225  		}
   226  		if resp.StatusCode == http.StatusOK {
   227  			if limit > 0 && resp.ContentLength > limit {
   228  				return ErrOverSize
   229  			}
   230  			if err = writeToFile(resp.Body, path); err != nil {
   231  				return err
   232  			}
   233  			return errUnsupportedMultiThreading
   234  		}
   235  		if resp.StatusCode == http.StatusPartialContent {
   236  			contentLength = resp.ContentLength
   237  			if limit > 0 && resp.ContentLength > limit {
   238  				return ErrOverSize
   239  			}
   240  			blockSize := contentLength
   241  			if contentLength > 1024*1024 {
   242  				blockSize = (contentLength / int64(thread)) - 10
   243  			}
   244  			if blockSize == contentLength {
   245  				return writeToFile(resp.Body, path)
   246  			}
   247  			var tmp int64
   248  			for tmp+blockSize < contentLength {
   249  				blocks = append(blocks, &BlockMetaData{
   250  					BeginOffset: tmp,
   251  					EndOffset:   tmp + blockSize - 1,
   252  				})
   253  				tmp += blockSize
   254  			}
   255  			blocks = append(blocks, &BlockMetaData{
   256  				BeginOffset: tmp,
   257  				EndOffset:   contentLength - 1,
   258  			})
   259  			return nil
   260  		}
   261  		return errors.New("unknown status code")
   262  	}
   263  	// 下载分块
   264  	downloadBlock := func(block *BlockMetaData) error {
   265  		file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666)
   266  		if err != nil {
   267  			return err
   268  		}
   269  		defer file.Close()
   270  		_, _ = file.Seek(block.BeginOffset, io.SeekStart)
   271  		writer := bufio.NewWriter(file)
   272  		defer writer.Flush()
   273  
   274  		header := make(map[string]string, len(r.Header))
   275  		for k, v := range r.Header { // copy headers
   276  			header[k] = v
   277  		}
   278  		header["range"] = fmt.Sprintf("bytes=%d-%d", block.BeginOffset, block.EndOffset)
   279  		req := Request{
   280  			URL:    r.URL,
   281  			Header: header,
   282  		}
   283  		resp, err := req.do()
   284  		if err != nil {
   285  			return err
   286  		}
   287  		defer resp.Body.Close()
   288  		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   289  			return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
   290  		}
   291  		buffer := make([]byte, 1024)
   292  		i, err := resp.Body.Read(buffer)
   293  		for {
   294  			if err != nil && err != io.EOF {
   295  				return err
   296  			}
   297  			i64 := int64(len(buffer[:i]))
   298  			needSize := block.EndOffset + 1 - block.BeginOffset
   299  			if i64 > needSize {
   300  				i64 = needSize
   301  				err = io.EOF
   302  			}
   303  			_, e := writer.Write(buffer[:i64])
   304  			if e != nil {
   305  				return e
   306  			}
   307  			block.BeginOffset += i64
   308  			block.DownloadedSize += i64
   309  			if err == io.EOF || block.BeginOffset > block.EndOffset {
   310  				break
   311  			}
   312  			i, err = resp.Body.Read(buffer)
   313  		}
   314  		return nil
   315  	}
   316  
   317  	if err := initOrDownload(); err != nil {
   318  		if err == errUnsupportedMultiThreading {
   319  			return nil
   320  		}
   321  		return err
   322  	}
   323  	wg := sync.WaitGroup{}
   324  	wg.Add(len(blocks))
   325  	var lastErr error
   326  	for i := range blocks {
   327  		go func(b *BlockMetaData) {
   328  			defer wg.Done()
   329  			if err := downloadBlock(b); err != nil {
   330  				lastErr = err
   331  			}
   332  		}(blocks[i])
   333  	}
   334  	wg.Wait()
   335  	return lastErr
   336  }
   337  
   338  type gzipCloser struct {
   339  	f io.Closer
   340  	r *gzip.Reader
   341  }
   342  
   343  // gzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser
   344  func gzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) {
   345  	gzipReader, err := gzip.NewReader(reader)
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  	return &gzipCloser{
   350  		f: reader,
   351  		r: gzipReader,
   352  	}, nil
   353  }
   354  
   355  // Read impls io.Reader
   356  func (g *gzipCloser) Read(p []byte) (n int, err error) {
   357  	return g.r.Read(p)
   358  }
   359  
   360  // Close impls io.Closer
   361  func (g *gzipCloser) Close() error {
   362  	_ = g.f.Close()
   363  	return g.r.Close()
   364  }