github.com/Mrs4s/go-cqhttp@v1.2.0/internal/download/download.go (about) 1 // Package download provide download utility functions 2 package download 3 4 import ( 5 "bufio" 6 "compress/gzip" 7 "crypto/tls" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "os" 13 "strconv" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/RomiChan/syncx" 19 "github.com/pkg/errors" 20 "github.com/tidwall/gjson" 21 22 "github.com/Mrs4s/go-cqhttp/internal/base" 23 ) 24 25 var client = newClient(time.Second * 15) 26 var clients syncx.Map[time.Duration, *http.Client] 27 28 var clienth2 = &http.Client{ 29 Transport: &http.Transport{ 30 Proxy: func(request *http.Request) (*url.URL, error) { 31 if base.Proxy == "" { 32 return http.ProxyFromEnvironment(request) 33 } 34 return url.Parse(base.Proxy) 35 }, 36 ForceAttemptHTTP2: true, 37 MaxIdleConnsPerHost: 999, 38 }, 39 Timeout: time.Second * 15, 40 } 41 42 func newClient(t time.Duration) *http.Client { 43 return &http.Client{ 44 Transport: &http.Transport{ 45 Proxy: func(request *http.Request) (*url.URL, error) { 46 if base.Proxy == "" { 47 return http.ProxyFromEnvironment(request) 48 } 49 return url.Parse(base.Proxy) 50 }, 51 // Disable http2 52 TLSNextProto: map[string]func(authority string, c *tls.Conn) http.RoundTripper{}, 53 MaxIdleConnsPerHost: 999, 54 }, 55 Timeout: t, 56 } 57 } 58 59 // ErrOverSize 响应主体过大时返回此错误 60 var ErrOverSize = errors.New("oversize") 61 62 // UserAgent HTTP请求时使用的UA 63 const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66" 64 65 // WithTimeout get a download instance with timeout t 66 func (r Request) WithTimeout(t time.Duration) *Request { 67 if c, ok := clients.Load(t); ok { 68 r.custcli = c 69 } else { 70 c := newClient(t) 71 clients.Store(t, c) 72 r.custcli = c 73 } 74 return &r 75 } 76 77 // SetTimeout set internal/download client timeout 78 func SetTimeout(t time.Duration) { 79 if t == 0 { 80 t = time.Second * 10 81 } 82 client.Timeout = t 83 clienth2.Timeout = t 84 } 85 86 // Request is a file download request 87 type Request struct { 88 Method string 89 URL string 90 Header map[string]string 91 Limit int64 92 Body io.Reader 93 custcli *http.Client 94 } 95 96 func (r Request) client() *http.Client { 97 if r.custcli != nil { 98 return r.custcli 99 } 100 if strings.Contains(r.URL, "go-cqhttp.org") { 101 return clienth2 102 } 103 return client 104 } 105 106 func (r Request) do() (*http.Response, error) { 107 if r.Method == "" { 108 r.Method = http.MethodGet 109 } 110 req, err := http.NewRequest(r.Method, r.URL, r.Body) 111 if err != nil { 112 return nil, err 113 } 114 115 req.Header["User-Agent"] = []string{UserAgent} 116 for k, v := range r.Header { 117 req.Header.Set(k, v) 118 } 119 120 return r.client().Do(req) 121 } 122 123 func (r Request) body() (io.ReadCloser, error) { 124 resp, err := r.do() 125 if err != nil { 126 return nil, err 127 } 128 129 limit := r.Limit // check file size limit 130 if limit > 0 && resp.ContentLength > limit { 131 _ = resp.Body.Close() 132 return nil, ErrOverSize 133 } 134 135 if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { 136 return gzipReadCloser(resp.Body) 137 } 138 return resp.Body, err 139 } 140 141 // Bytes 对给定URL发送请求,返回响应主体 142 func (r Request) Bytes() ([]byte, error) { 143 rd, err := r.body() 144 if err != nil { 145 return nil, err 146 } 147 defer rd.Close() 148 defer r.client().CloseIdleConnections() 149 return io.ReadAll(rd) 150 } 151 152 // JSON 发送请求, 并转换响应为JSON 153 func (r Request) JSON() (gjson.Result, error) { 154 rd, err := r.body() 155 if err != nil { 156 return gjson.Result{}, err 157 } 158 defer rd.Close() 159 defer r.client().CloseIdleConnections() 160 161 var sb strings.Builder 162 _, err = io.Copy(&sb, rd) 163 if err != nil { 164 return gjson.Result{}, err 165 } 166 167 return gjson.Parse(sb.String()), nil 168 } 169 170 func writeToFile(reader io.ReadCloser, path string) error { 171 file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o644) 172 if err != nil { 173 return err 174 } 175 defer func() { _ = file.Close() }() 176 _, err = file.ReadFrom(reader) 177 return err 178 } 179 180 // WriteToFile 下载到制定目录 181 func (r Request) WriteToFile(path string) error { 182 rd, err := r.body() 183 if err != nil { 184 return err 185 } 186 defer rd.Close() 187 defer r.client().CloseIdleConnections() 188 return writeToFile(rd, path) 189 } 190 191 // WriteToFileMultiThreading 多线程下载到制定目录 192 func (r Request) WriteToFileMultiThreading(path string, thread int) error { 193 if thread < 2 { 194 return r.WriteToFile(path) 195 } 196 197 defer r.client().CloseIdleConnections() 198 limit := r.Limit 199 type BlockMetaData struct { 200 BeginOffset int64 201 EndOffset int64 202 DownloadedSize int64 203 } 204 var blocks []*BlockMetaData 205 var contentLength int64 206 errUnsupportedMultiThreading := errors.New("unsupported multi-threading") 207 // 初始化分块或直接下载 208 initOrDownload := func() error { 209 header := make(map[string]string, len(r.Header)) 210 for k, v := range r.Header { // copy headers 211 header[k] = v 212 } 213 header["range"] = "bytes=0-" 214 req := Request{ 215 URL: r.URL, 216 Header: header, 217 } 218 resp, err := req.do() 219 if err != nil { 220 return err 221 } 222 defer resp.Body.Close() 223 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 224 return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) 225 } 226 if resp.StatusCode == http.StatusOK { 227 if limit > 0 && resp.ContentLength > limit { 228 return ErrOverSize 229 } 230 if err = writeToFile(resp.Body, path); err != nil { 231 return err 232 } 233 return errUnsupportedMultiThreading 234 } 235 if resp.StatusCode == http.StatusPartialContent { 236 contentLength = resp.ContentLength 237 if limit > 0 && resp.ContentLength > limit { 238 return ErrOverSize 239 } 240 blockSize := contentLength 241 if contentLength > 1024*1024 { 242 blockSize = (contentLength / int64(thread)) - 10 243 } 244 if blockSize == contentLength { 245 return writeToFile(resp.Body, path) 246 } 247 var tmp int64 248 for tmp+blockSize < contentLength { 249 blocks = append(blocks, &BlockMetaData{ 250 BeginOffset: tmp, 251 EndOffset: tmp + blockSize - 1, 252 }) 253 tmp += blockSize 254 } 255 blocks = append(blocks, &BlockMetaData{ 256 BeginOffset: tmp, 257 EndOffset: contentLength - 1, 258 }) 259 return nil 260 } 261 return errors.New("unknown status code") 262 } 263 // 下载分块 264 downloadBlock := func(block *BlockMetaData) error { 265 file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) 266 if err != nil { 267 return err 268 } 269 defer file.Close() 270 _, _ = file.Seek(block.BeginOffset, io.SeekStart) 271 writer := bufio.NewWriter(file) 272 defer writer.Flush() 273 274 header := make(map[string]string, len(r.Header)) 275 for k, v := range r.Header { // copy headers 276 header[k] = v 277 } 278 header["range"] = fmt.Sprintf("bytes=%d-%d", block.BeginOffset, block.EndOffset) 279 req := Request{ 280 URL: r.URL, 281 Header: header, 282 } 283 resp, err := req.do() 284 if err != nil { 285 return err 286 } 287 defer resp.Body.Close() 288 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 289 return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) 290 } 291 buffer := make([]byte, 1024) 292 i, err := resp.Body.Read(buffer) 293 for { 294 if err != nil && err != io.EOF { 295 return err 296 } 297 i64 := int64(len(buffer[:i])) 298 needSize := block.EndOffset + 1 - block.BeginOffset 299 if i64 > needSize { 300 i64 = needSize 301 err = io.EOF 302 } 303 _, e := writer.Write(buffer[:i64]) 304 if e != nil { 305 return e 306 } 307 block.BeginOffset += i64 308 block.DownloadedSize += i64 309 if err == io.EOF || block.BeginOffset > block.EndOffset { 310 break 311 } 312 i, err = resp.Body.Read(buffer) 313 } 314 return nil 315 } 316 317 if err := initOrDownload(); err != nil { 318 if err == errUnsupportedMultiThreading { 319 return nil 320 } 321 return err 322 } 323 wg := sync.WaitGroup{} 324 wg.Add(len(blocks)) 325 var lastErr error 326 for i := range blocks { 327 go func(b *BlockMetaData) { 328 defer wg.Done() 329 if err := downloadBlock(b); err != nil { 330 lastErr = err 331 } 332 }(blocks[i]) 333 } 334 wg.Wait() 335 return lastErr 336 } 337 338 type gzipCloser struct { 339 f io.Closer 340 r *gzip.Reader 341 } 342 343 // gzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser 344 func gzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) { 345 gzipReader, err := gzip.NewReader(reader) 346 if err != nil { 347 return nil, err 348 } 349 return &gzipCloser{ 350 f: reader, 351 r: gzipReader, 352 }, nil 353 } 354 355 // Read impls io.Reader 356 func (g *gzipCloser) Read(p []byte) (n int, err error) { 357 return g.r.Read(p) 358 } 359 360 // Close impls io.Closer 361 func (g *gzipCloser) Close() error { 362 _ = g.f.Close() 363 return g.r.Close() 364 }