github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/request/request.go (about)

     1  package request
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"sync"

	model "github.com/cloudreve/Cloudreve/v3/models"
	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
	"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
    20  
    21  // GeneralClient 通用 HTTP Client
    22  var GeneralClient Client = NewClient()
    23  
// Response holds the outcome of a Client.Request call: either an error in
// Err (request construction, URL resolution or transport failure) or the raw
// HTTP response in Response. When Request fails, only Err is set and
// Response is nil. The helper methods on Response check Err first and
// return it without touching Response once it is set.
type Response struct {
	Err      error
	Response *http.Response
}
    29  
// Client sends HTTP requests. Request takes the HTTP method, the target URL
// (resolved against a configured endpoint by HTTPClient), an optional body
// and per-call options, and returns the outcome wrapped in a *Response.
// The HTTPClient implementation in this file reports every failure through
// Response.Err rather than a separate error return.
type Client interface {
	Request(method, target string, body io.Reader, opts ...Option) *Response
}
    34  
// HTTPClient is the Client implementation returned by NewClient.
//
// options holds the defaults applied to every request; Request clones them
// while holding mu and then layers the per-call options on the clone, so
// per-call options never modify these defaults. tpsLimiter throttles
// requests that configure a positive TPS limit (NewClient sets it to
// globalTPSLimiter).
type HTTPClient struct {
	mu         sync.Mutex
	options    *options
	tpsLimiter TPSLimiter
}
    41  
    42  func NewClient(opts ...Option) Client {
    43  	client := &HTTPClient{
    44  		options:    newDefaultOption(),
    45  		tpsLimiter: globalTPSLimiter,
    46  	}
    47  
    48  	for _, o := range opts {
    49  		o.apply(client.options)
    50  	}
    51  
    52  	return client
    53  }
    54  
    55  // Request 发送HTTP请求
    56  func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
    57  	// 应用额外设置
    58  	c.mu.Lock()
    59  	options := c.options.clone()
    60  	c.mu.Unlock()
    61  	for _, o := range opts {
    62  		o.apply(&options)
    63  	}
    64  
    65  	// 创建请求客户端
    66  	client := &http.Client{Timeout: options.timeout}
    67  
    68  	// size为0时将body设为nil
    69  	if options.contentLength == 0 {
    70  		body = nil
    71  	}
    72  
    73  	// 确定请求URL
    74  	if options.endpoint != nil {
    75  		targetPath, err := url.Parse(target)
    76  		if err != nil {
    77  			return &Response{Err: err}
    78  		}
    79  
    80  		targetURL := *options.endpoint
    81  		target = targetURL.ResolveReference(targetPath).String()
    82  	}
    83  
    84  	// 创建请求
    85  	var (
    86  		req *http.Request
    87  		err error
    88  	)
    89  	if options.ctx != nil {
    90  		req, err = http.NewRequestWithContext(options.ctx, method, target, body)
    91  	} else {
    92  		req, err = http.NewRequest(method, target, body)
    93  	}
    94  	if err != nil {
    95  		return &Response{Err: err}
    96  	}
    97  
    98  	// 添加请求相关设置
    99  	if options.header != nil {
   100  		for k, v := range options.header {
   101  			req.Header.Add(k, strings.Join(v, " "))
   102  		}
   103  	}
   104  
   105  	if options.masterMeta && conf.SystemConfig.Mode == "master" {
   106  		req.Header.Add(auth.CrHeaderPrefix+"Site-Url", model.GetSiteURL().String())
   107  		req.Header.Add(auth.CrHeaderPrefix+"Site-Id", model.GetSettingByName("siteID"))
   108  		req.Header.Add(auth.CrHeaderPrefix+"Cloudreve-Version", conf.BackendVersion)
   109  	}
   110  
   111  	if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" {
   112  		req.Header.Add(auth.CrHeaderPrefix+"Node-Id", options.slaveNodeID)
   113  	}
   114  
   115  	if options.contentLength != -1 {
   116  		req.ContentLength = options.contentLength
   117  	}
   118  
   119  	// 签名请求
   120  	if options.sign != nil {
   121  		switch method {
   122  		case "PUT", "POST", "PATCH":
   123  			auth.SignRequest(options.sign, req, options.signTTL)
   124  		default:
   125  			if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil {
   126  				req.URL = resURL
   127  			}
   128  		}
   129  	}
   130  
   131  	if options.tps > 0 {
   132  		c.tpsLimiter.Limit(options.ctx, options.tpsLimiterToken, options.tps, options.tpsBurst)
   133  	}
   134  
   135  	// 发送请求
   136  	resp, err := client.Do(req)
   137  	if err != nil {
   138  		return &Response{Err: err}
   139  	}
   140  
   141  	return &Response{Err: nil, Response: resp}
   142  }
   143  
   144  // GetResponse 检查响应并获取响应正文
   145  func (resp *Response) GetResponse() (string, error) {
   146  	if resp.Err != nil {
   147  		return "", resp.Err
   148  	}
   149  	respBody, err := ioutil.ReadAll(resp.Response.Body)
   150  	_ = resp.Response.Body.Close()
   151  
   152  	return string(respBody), err
   153  }
   154  
   155  // CheckHTTPResponse 检查请求响应HTTP状态码
   156  func (resp *Response) CheckHTTPResponse(status int) *Response {
   157  	if resp.Err != nil {
   158  		return resp
   159  	}
   160  
   161  	// 检查HTTP状态码
   162  	if resp.Response.StatusCode != status {
   163  		resp.Err = fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode)
   164  	}
   165  	return resp
   166  }
   167  
   168  // DecodeResponse 尝试解析为serializer.Response,并对状态码进行检查
   169  func (resp *Response) DecodeResponse() (*serializer.Response, error) {
   170  	if resp.Err != nil {
   171  		return nil, resp.Err
   172  	}
   173  
   174  	respString, err := resp.GetResponse()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	var res serializer.Response
   180  	err = json.Unmarshal([]byte(respString), &res)
   181  	if err != nil {
   182  		util.Log().Debug("Failed to parse response: %s", string(respString))
   183  		return nil, err
   184  	}
   185  	return &res, nil
   186  }
   187  
// NopRSCloser wraps a non-seekable body and provides Read, Seek and Close,
// but Seek is only partially implemented: it supports just what
// http.ServeContent needs to learn the content size.
//
// status is a pointer, so updates made by the value-receiver methods (such
// as the fake-first-chunk flag and the size) are visible through every copy
// of the NopRSCloser.
type NopRSCloser struct {
	body   io.ReadCloser
	status *rscStatus
}
   193  
// rscStatus is the mutable state shared by all copies of a NopRSCloser.
type rscStatus struct {
	// http.ServeContent reads a small chunk to decide the content type, but a
	// response body cannot seek back afterwards. While this is true, a
	// 512-byte read returns fake (empty) data instead of consuming the body.
	IgnoreFirst bool

	// Size is the value reported by Seek(0, io.SeekEnd).
	Size int64
}
   201  
   202  // GetRSCloser 返回带有空seeker的RSCloser,供http.ServeContent使用
   203  func (resp *Response) GetRSCloser() (*NopRSCloser, error) {
   204  	if resp.Err != nil {
   205  		return nil, resp.Err
   206  	}
   207  
   208  	return &NopRSCloser{
   209  		body: resp.Response.Body,
   210  		status: &rscStatus{
   211  			Size: resp.Response.ContentLength,
   212  		},
   213  	}, resp.Err
   214  }
   215  
   216  // SetFirstFakeChunk 开启第一次read返回空数据
   217  // TODO 测试
   218  func (instance NopRSCloser) SetFirstFakeChunk() {
   219  	instance.status.IgnoreFirst = true
   220  }
   221  
   222  // SetContentLength 设置数据流大小
   223  func (instance NopRSCloser) SetContentLength(size int64) {
   224  	instance.status.Size = size
   225  }
   226  
   227  // Read 实现 NopRSCloser reader
   228  func (instance NopRSCloser) Read(p []byte) (n int, err error) {
   229  	if instance.status.IgnoreFirst && len(p) == 512 {
   230  		return 0, io.EOF
   231  	}
   232  	return instance.body.Read(p)
   233  }
   234  
   235  // Close 实现 NopRSCloser closer
   236  func (instance NopRSCloser) Close() error {
   237  	return instance.body.Close()
   238  }
   239  
   240  // Seek 实现 NopRSCloser seeker, 只实现seek开头/结尾以便http.ServeContent用于确定正文大小
   241  func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) {
   242  	// 进行第一次Seek操作后,取消忽略选项
   243  	if instance.status.IgnoreFirst {
   244  		instance.status.IgnoreFirst = false
   245  	}
   246  	if offset == 0 {
   247  		switch whence {
   248  		case io.SeekStart:
   249  			return 0, nil
   250  		case io.SeekEnd:
   251  			return instance.status.Size, nil
   252  		}
   253  	}
   254  	return 0, errors.New("not implemented")
   255  
   256  }
   257  
   258  // BlackHole 将客户端发来的数据放入黑洞
   259  func BlackHole(r io.Reader) {
   260  	if !model.IsTrueVal(model.GetSettingByName("reset_after_upload_failed")) {
   261  		io.Copy(ioutil.Discard, r)
   262  	}
   263  }