github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/request/request.go (about) 1 package request 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "net/url" 11 "strings" 12 "sync" 13 14 model "github.com/cloudreve/Cloudreve/v3/models" 15 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 16 "github.com/cloudreve/Cloudreve/v3/pkg/conf" 17 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 18 "github.com/cloudreve/Cloudreve/v3/pkg/util" 19 ) 20 21 // GeneralClient 通用 HTTP Client 22 var GeneralClient Client = NewClient() 23 24 // Response 请求的响应或错误信息 25 type Response struct { 26 Err error 27 Response *http.Response 28 } 29 30 // Client 请求客户端 31 type Client interface { 32 Request(method, target string, body io.Reader, opts ...Option) *Response 33 } 34 35 // HTTPClient 实现 Client 接口 36 type HTTPClient struct { 37 mu sync.Mutex 38 options *options 39 tpsLimiter TPSLimiter 40 } 41 42 func NewClient(opts ...Option) Client { 43 client := &HTTPClient{ 44 options: newDefaultOption(), 45 tpsLimiter: globalTPSLimiter, 46 } 47 48 for _, o := range opts { 49 o.apply(client.options) 50 } 51 52 return client 53 } 54 55 // Request 发送HTTP请求 56 func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { 57 // 应用额外设置 58 c.mu.Lock() 59 options := c.options.clone() 60 c.mu.Unlock() 61 for _, o := range opts { 62 o.apply(&options) 63 } 64 65 // 创建请求客户端 66 client := &http.Client{Timeout: options.timeout} 67 68 // size为0时将body设为nil 69 if options.contentLength == 0 { 70 body = nil 71 } 72 73 // 确定请求URL 74 if options.endpoint != nil { 75 targetPath, err := url.Parse(target) 76 if err != nil { 77 return &Response{Err: err} 78 } 79 80 targetURL := *options.endpoint 81 target = targetURL.ResolveReference(targetPath).String() 82 } 83 84 // 创建请求 85 var ( 86 req *http.Request 87 err error 88 ) 89 if options.ctx != nil { 90 req, err = http.NewRequestWithContext(options.ctx, method, target, body) 91 } else { 92 req, err = http.NewRequest(method, target, body) 93 } 94 if err != nil { 95 return &Response{Err: err} 96 } 97 98 // 添加请求相关设置 99 if options.header != nil { 100 for k, v := range options.header { 101 req.Header.Add(k, strings.Join(v, " ")) 102 } 103 } 104 105 if options.masterMeta && conf.SystemConfig.Mode == "master" { 106 req.Header.Add(auth.CrHeaderPrefix+"Site-Url", model.GetSiteURL().String()) 107 req.Header.Add(auth.CrHeaderPrefix+"Site-Id", model.GetSettingByName("siteID")) 108 req.Header.Add(auth.CrHeaderPrefix+"Cloudreve-Version", conf.BackendVersion) 109 } 110 111 if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" { 112 req.Header.Add(auth.CrHeaderPrefix+"Node-Id", options.slaveNodeID) 113 } 114 115 if options.contentLength != -1 { 116 req.ContentLength = options.contentLength 117 } 118 119 // 签名请求 120 if options.sign != nil { 121 switch method { 122 case "PUT", "POST", "PATCH": 123 auth.SignRequest(options.sign, req, options.signTTL) 124 default: 125 if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil { 126 req.URL = resURL 127 } 128 } 129 } 130 131 if options.tps > 0 { 132 c.tpsLimiter.Limit(options.ctx, options.tpsLimiterToken, options.tps, options.tpsBurst) 133 } 134 135 // 发送请求 136 resp, err := client.Do(req) 137 if err != nil { 138 return &Response{Err: err} 139 } 140 141 return &Response{Err: nil, Response: resp} 142 } 143 144 // GetResponse 检查响应并获取响应正文 145 func (resp *Response) GetResponse() (string, error) { 146 if resp.Err != nil { 147 return "", resp.Err 148 } 149 respBody, err := ioutil.ReadAll(resp.Response.Body) 150 _ = resp.Response.Body.Close() 151 152 return string(respBody), err 153 } 154 155 // CheckHTTPResponse 检查请求响应HTTP状态码 156 func (resp *Response) CheckHTTPResponse(status int) *Response { 157 if resp.Err != nil { 158 return resp 159 } 160 161 // 检查HTTP状态码 162 if resp.Response.StatusCode != status { 163 resp.Err = fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode) 164 } 165 return resp 166 } 167 168 // DecodeResponse 尝试解析为serializer.Response,并对状态码进行检查 169 func (resp *Response) DecodeResponse() (*serializer.Response, error) { 170 if resp.Err != nil { 171 return nil, resp.Err 172 } 173 174 respString, err := resp.GetResponse() 175 if err != nil { 176 return nil, err 177 } 178 179 var res serializer.Response 180 err = json.Unmarshal([]byte(respString), &res) 181 if err != nil { 182 util.Log().Debug("Failed to parse response: %s", string(respString)) 183 return nil, err 184 } 185 return &res, nil 186 } 187 188 // NopRSCloser 实现不完整seeker 189 type NopRSCloser struct { 190 body io.ReadCloser 191 status *rscStatus 192 } 193 194 type rscStatus struct { 195 // http.ServeContent 会读取一小块以决定内容类型, 196 // 但是响应body无法实现seek,所以此项为真时第一个read会返回假数据 197 IgnoreFirst bool 198 199 Size int64 200 } 201 202 // GetRSCloser 返回带有空seeker的RSCloser,供http.ServeContent使用 203 func (resp *Response) GetRSCloser() (*NopRSCloser, error) { 204 if resp.Err != nil { 205 return nil, resp.Err 206 } 207 208 return &NopRSCloser{ 209 body: resp.Response.Body, 210 status: &rscStatus{ 211 Size: resp.Response.ContentLength, 212 }, 213 }, resp.Err 214 } 215 216 // SetFirstFakeChunk 开启第一次read返回空数据 217 // TODO 测试 218 func (instance NopRSCloser) SetFirstFakeChunk() { 219 instance.status.IgnoreFirst = true 220 } 221 222 // SetContentLength 设置数据流大小 223 func (instance NopRSCloser) SetContentLength(size int64) { 224 instance.status.Size = size 225 } 226 227 // Read 实现 NopRSCloser reader 228 func (instance NopRSCloser) Read(p []byte) (n int, err error) { 229 if instance.status.IgnoreFirst && len(p) == 512 { 230 return 0, io.EOF 231 } 232 return instance.body.Read(p) 233 } 234 235 // Close 实现 NopRSCloser closer 236 func (instance NopRSCloser) Close() error { 237 return instance.body.Close() 238 } 239 240 // Seek 实现 NopRSCloser seeker, 只实现seek开头/结尾以便http.ServeContent用于确定正文大小 241 func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) { 242 // 进行第一次Seek操作后,取消忽略选项 243 if instance.status.IgnoreFirst { 244 instance.status.IgnoreFirst = false 245 } 246 if offset == 0 { 247 switch whence { 248 case io.SeekStart: 249 return 0, nil 250 case io.SeekEnd: 251 return instance.status.Size, nil 252 } 253 } 254 return 0, errors.New("not implemented") 255 256 } 257 258 // BlackHole 将客户端发来的数据放入黑洞 259 func BlackHole(r io.Reader) { 260 if !model.IsTrueVal(model.GetSettingByName("reset_after_upload_failed")) { 261 io.Copy(ioutil.Discard, r) 262 } 263 }