github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/onedrive/api.go (about)

     1  package onedrive
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"path"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	model "github.com/cloudreve/Cloudreve/v3/models"
    17  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
    18  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
    19  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
    20  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
    21  	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
    22  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    23  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    24  )
    25  
    26  const (
    27  	// SmallFileSize 单文件上传接口最大尺寸
    28  	SmallFileSize uint64 = 4 * 1024 * 1024
    29  	// ChunkSize 服务端中转分片上传分片大小
    30  	ChunkSize uint64 = 10 * 1024 * 1024
    31  	// ListRetry 列取请求重试次数
    32  	ListRetry       = 1
    33  	chunkRetrySleep = time.Second * 5
    34  
    35  	notFoundError = "itemNotFound"
    36  )
    37  
    38  // GetSourcePath 获取文件的绝对路径
    39  func (info *FileInfo) GetSourcePath() string {
    40  	res, err := url.PathUnescape(info.ParentReference.Path)
    41  	if err != nil {
    42  		return ""
    43  	}
    44  
    45  	return strings.TrimPrefix(
    46  		path.Join(
    47  			strings.TrimPrefix(res, "/drive/root:"),
    48  			info.Name,
    49  		),
    50  		"/",
    51  	)
    52  }
    53  
    54  func (client *Client) getRequestURL(api string, opts ...Option) string {
    55  	options := newDefaultOption()
    56  	for _, o := range opts {
    57  		o.apply(options)
    58  	}
    59  
    60  	base, _ := url.Parse(client.Endpoints.EndpointURL)
    61  	if base == nil {
    62  		return ""
    63  	}
    64  
    65  	if options.useDriverResource {
    66  		base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
    67  	} else {
    68  		base.Path = path.Join(base.Path, api)
    69  	}
    70  
    71  	return base.String()
    72  }
    73  
    74  // ListChildren 根据路径列取子对象
    75  func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
    76  	var requestURL string
    77  	dst := strings.TrimPrefix(path, "/")
    78  	if dst == "" {
    79  		requestURL = client.getRequestURL("root/children")
    80  	} else {
    81  		requestURL = client.getRequestURL("root:/" + dst + ":/children")
    82  	}
    83  
    84  	res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
    85  	if err != nil {
    86  		retried := 0
    87  		if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok {
    88  			retried = v
    89  		}
    90  		if retried < ListRetry {
    91  			retried++
    92  			util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
    93  			time.Sleep(time.Duration(5) * time.Second)
    94  			return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path)
    95  		}
    96  		return nil, err
    97  	}
    98  
    99  	var (
   100  		decodeErr error
   101  		fileInfo  ListResponse
   102  	)
   103  	decodeErr = json.Unmarshal([]byte(res), &fileInfo)
   104  	if decodeErr != nil {
   105  		return nil, decodeErr
   106  	}
   107  
   108  	return fileInfo.Value, nil
   109  }
   110  
   111  // Meta 根据资源ID或文件路径获取文件元信息
   112  func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
   113  	var requestURL string
   114  	if id != "" {
   115  		requestURL = client.getRequestURL("items/" + id)
   116  	} else {
   117  		dst := strings.TrimPrefix(path, "/")
   118  		requestURL = client.getRequestURL("root:/" + dst)
   119  	}
   120  
   121  	res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	var (
   127  		decodeErr error
   128  		fileInfo  FileInfo
   129  	)
   130  	decodeErr = json.Unmarshal([]byte(res), &fileInfo)
   131  	if decodeErr != nil {
   132  		return nil, decodeErr
   133  	}
   134  
   135  	return &fileInfo, nil
   136  
   137  }
   138  
   139  // CreateUploadSession 创建分片上传会话
   140  func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
   141  	options := newDefaultOption()
   142  	for _, o := range opts {
   143  		o.apply(options)
   144  	}
   145  
   146  	dst = strings.TrimPrefix(dst, "/")
   147  	requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
   148  	body := map[string]map[string]interface{}{
   149  		"item": {
   150  			"@microsoft.graph.conflictBehavior": options.conflictBehavior,
   151  		},
   152  	}
   153  	bodyBytes, _ := json.Marshal(body)
   154  
   155  	res, err := client.requestWithStr(ctx, "POST", requestURL, string(bodyBytes), 200)
   156  	if err != nil {
   157  		return "", err
   158  	}
   159  
   160  	var (
   161  		decodeErr     error
   162  		uploadSession UploadSessionResponse
   163  	)
   164  	decodeErr = json.Unmarshal([]byte(res), &uploadSession)
   165  	if decodeErr != nil {
   166  		return "", decodeErr
   167  	}
   168  
   169  	return uploadSession.UploadURL, nil
   170  }
   171  
   172  // GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
   173  func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
   174  	siteUrlParsed, err := url.Parse(siteUrl)
   175  	if err != nil {
   176  		return "", err
   177  	}
   178  
   179  	hostName := siteUrlParsed.Hostname()
   180  	relativePath := strings.Trim(siteUrlParsed.Path, "/")
   181  	requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
   182  	res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
   183  	if reqErr != nil {
   184  		return "", reqErr
   185  	}
   186  
   187  	var (
   188  		decodeErr error
   189  		siteInfo  Site
   190  	)
   191  	decodeErr = json.Unmarshal([]byte(res), &siteInfo)
   192  	if decodeErr != nil {
   193  		return "", decodeErr
   194  	}
   195  
   196  	return siteInfo.ID, nil
   197  }
   198  
   199  // GetUploadSessionStatus 查询上传会话状态
   200  func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
   201  	res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	var (
   207  		decodeErr     error
   208  		uploadSession UploadSessionResponse
   209  	)
   210  	decodeErr = json.Unmarshal([]byte(res), &uploadSession)
   211  	if decodeErr != nil {
   212  		return nil, decodeErr
   213  	}
   214  
   215  	return &uploadSession, nil
   216  }
   217  
   218  // UploadChunk 上传分片
   219  func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
   220  	res, err := client.request(
   221  		ctx, "PUT", uploadURL, content,
   222  		request.WithContentLength(current.Length()),
   223  		request.WithHeader(http.Header{
   224  			"Content-Range": {current.RangeHeader()},
   225  		}),
   226  		request.WithoutHeader([]string{"Authorization", "Content-Type"}),
   227  		request.WithTimeout(0),
   228  	)
   229  	if err != nil {
   230  		return nil, fmt.Errorf("failed to upload OneDrive chunk #%d: %w", current.Index(), err)
   231  	}
   232  
   233  	if current.IsLast() {
   234  		return nil, nil
   235  	}
   236  
   237  	var (
   238  		decodeErr error
   239  		uploadRes UploadSessionResponse
   240  	)
   241  	decodeErr = json.Unmarshal([]byte(res), &uploadRes)
   242  	if decodeErr != nil {
   243  		return nil, decodeErr
   244  	}
   245  
   246  	return &uploadRes, nil
   247  }
   248  
   249  // Upload 上传文件
   250  func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
   251  	fileInfo := file.Info()
   252  	// 决定是否覆盖文件
   253  	overwrite := "fail"
   254  	if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
   255  		overwrite = "replace"
   256  	}
   257  
   258  	size := int(fileInfo.Size)
   259  	dst := fileInfo.SavePath
   260  
   261  	// 小文件,使用简单上传接口上传
   262  	if size <= int(SmallFileSize) {
   263  		_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
   264  		return err
   265  	}
   266  
   267  	// 大文件,进行分片
   268  	// 创建上传会话
   269  	uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	// Initial chunk groups
   275  	chunks := chunk.NewChunkGroup(file, client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
   276  		Max:   model.GetIntSetting("chunk_retries", 5),
   277  		Sleep: chunkRetrySleep,
   278  	}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
   279  
   280  	uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
   281  		_, err := client.UploadChunk(ctx, uploadURL, content, current)
   282  		return err
   283  	}
   284  
   285  	// upload chunks
   286  	for chunks.Next() {
   287  		if err := chunks.Process(uploadFunc); err != nil {
   288  			return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
   289  		}
   290  	}
   291  
   292  	return nil
   293  }
   294  
   295  // DeleteUploadSession 删除上传会话
   296  func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
   297  	_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	return nil
   303  }
   304  
   305  // SimpleUpload 上传小文件到dst
   306  func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
   307  	options := newDefaultOption()
   308  	for _, o := range opts {
   309  		o.apply(options)
   310  	}
   311  
   312  	dst = strings.TrimPrefix(dst, "/")
   313  	requestURL := client.getRequestURL("root:/" + dst + ":/content")
   314  	requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
   315  
   316  	res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
   317  		request.WithTimeout(0),
   318  	)
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  
   323  	var (
   324  		decodeErr error
   325  		uploadRes UploadResult
   326  	)
   327  	decodeErr = json.Unmarshal([]byte(res), &uploadRes)
   328  	if decodeErr != nil {
   329  		return nil, decodeErr
   330  	}
   331  
   332  	return &uploadRes, nil
   333  }
   334  
   335  // BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为
   336  // 20个一组,调用Delete并行删除
   337  // TODO 测试
   338  func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
   339  	groupNum := len(dst)/20 + 1
   340  	finalRes := make([]string, 0, len(dst))
   341  	res := make([]string, 0, 20)
   342  	var err error
   343  
   344  	for i := 0; i < groupNum; i++ {
   345  		end := 20*i + 20
   346  		if i == groupNum-1 {
   347  			end = len(dst)
   348  		}
   349  		res, err = client.Delete(ctx, dst[20*i:end])
   350  		finalRes = append(finalRes, res...)
   351  	}
   352  
   353  	return finalRes, err
   354  }
   355  
   356  // Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误,
   357  // 由于API限制,最多删除20个
   358  func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
   359  	body := client.makeBatchDeleteRequestsBody(dst)
   360  	res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
   361  		WithDriverResource(false)), body, 200)
   362  	if err != nil {
   363  		return dst, err
   364  	}
   365  
   366  	var (
   367  		decodeErr error
   368  		deleteRes BatchResponses
   369  	)
   370  	decodeErr = json.Unmarshal([]byte(res), &deleteRes)
   371  	if decodeErr != nil {
   372  		return dst, decodeErr
   373  	}
   374  
   375  	// 取得删除失败的文件
   376  	failed := getDeleteFailed(&deleteRes)
   377  	if len(failed) != 0 {
   378  		return failed, ErrDeleteFile
   379  	}
   380  	return failed, nil
   381  }
   382  
   383  func getDeleteFailed(res *BatchResponses) []string {
   384  	var failed = make([]string, 0, len(res.Responses))
   385  	for _, v := range res.Responses {
   386  		if v.Status != 204 && v.Status != 404 {
   387  			failed = append(failed, v.ID)
   388  		}
   389  	}
   390  	return failed
   391  }
   392  
   393  // makeBatchDeleteRequestsBody 生成批量删除请求正文
   394  func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
   395  	req := BatchRequests{
   396  		Requests: make([]BatchRequest, len(files)),
   397  	}
   398  	for i, v := range files {
   399  		v = strings.TrimPrefix(v, "/")
   400  		filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
   401  		filePath.Path = path.Join(filePath.Path, v)
   402  		req.Requests[i] = BatchRequest{
   403  			ID:     v,
   404  			Method: "DELETE",
   405  			URL:    filePath.EscapedPath(),
   406  		}
   407  	}
   408  
   409  	res, _ := json.Marshal(req)
   410  	return string(res)
   411  }
   412  
   413  // GetThumbURL 获取给定尺寸的缩略图URL
   414  func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
   415  	dst = strings.TrimPrefix(dst, "/")
   416  	requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
   417  
   418  	res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
   419  	if err != nil {
   420  		return "", err
   421  	}
   422  
   423  	var (
   424  		decodeErr error
   425  		thumbRes  ThumbResponse
   426  	)
   427  	decodeErr = json.Unmarshal([]byte(res), &thumbRes)
   428  	if decodeErr != nil {
   429  		return "", decodeErr
   430  	}
   431  
   432  	if thumbRes.URL != "" {
   433  		return thumbRes.URL, nil
   434  	}
   435  
   436  	if len(thumbRes.Value) == 1 {
   437  		if res, ok := thumbRes.Value[0]["large"]; ok {
   438  			return res.(map[string]interface{})["url"].(string), nil
   439  		}
   440  	}
   441  
   442  	return "", ErrThumbSizeNotFound
   443  }
   444  
   445  // MonitorUpload 监控客户端分片上传进度
   446  func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) {
   447  	// 回调完成通知chan
   448  	callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1)
   449  	defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan)
   450  
   451  	timeout := model.GetIntSetting("onedrive_monitor_timeout", 600)
   452  	interval := model.GetIntSetting("onedrive_callback_check", 20)
   453  
   454  	for {
   455  		select {
   456  		case <-callbackChan:
   457  			util.Log().Debug("Client finished OneDrive callback.")
   458  			return
   459  		case <-time.After(time.Duration(ttl) * time.Second):
   460  			// 上传会话到期,仍未完成上传,创建占位符
   461  			client.DeleteUploadSession(context.Background(), uploadURL)
   462  			_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
   463  			if err != nil {
   464  				util.Log().Debug("Failed to create placeholder file: %s", err)
   465  			}
   466  			return
   467  		case <-time.After(time.Duration(timeout) * time.Second):
   468  			util.Log().Debug("Checking OneDrive upload status.")
   469  			status, err := client.GetUploadSessionStatus(context.Background(), uploadURL)
   470  
   471  			if err != nil {
   472  				if resErr, ok := err.(*RespError); ok {
   473  					if resErr.APIError.Code == notFoundError {
   474  						util.Log().Debug("Upload completed, will check upload callback later.")
   475  						select {
   476  						case <-time.After(time.Duration(interval) * time.Second):
   477  							util.Log().Warning("No callback is made, file will be deleted.")
   478  							cache.Deletes([]string{callbackKey}, "callback_")
   479  							_, err = client.Delete(context.Background(), []string{path})
   480  							if err != nil {
   481  								util.Log().Warning("Failed to delete file without callback: %s", err)
   482  							}
   483  						case <-callbackChan:
   484  							util.Log().Debug("Client finished callback.")
   485  						}
   486  						return
   487  					}
   488  				}
   489  				util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error())
   490  				continue
   491  			}
   492  
   493  			// 成功获取分片上传状态,检查文件大小
   494  			if len(status.NextExpectedRanges) == 0 {
   495  				continue
   496  			}
   497  			sizeRange := strings.Split(
   498  				status.NextExpectedRanges[len(status.NextExpectedRanges)-1],
   499  				"-",
   500  			)
   501  			if len(sizeRange) != 2 {
   502  				continue
   503  			}
   504  			uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64)
   505  			if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size {
   506  				util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...")
   507  				// 取消上传会话,实测OneDrive取消上传会话后,客户端还是可以上传,
   508  				// 所以上传一个空文件占位,阻止客户端上传
   509  				client.DeleteUploadSession(context.Background(), uploadURL)
   510  				_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
   511  				if err != nil {
   512  					util.Log().Debug("无法创建占位文件,%s", err)
   513  				}
   514  				return
   515  			}
   516  
   517  		}
   518  	}
   519  }
   520  
   521  func sysError(err error) *RespError {
   522  	return &RespError{APIError: APIError{
   523  		Code:    "system",
   524  		Message: err.Error(),
   525  	}}
   526  }
   527  
   528  func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
   529  	// 获取凭证
   530  	err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
   531  	if err != nil {
   532  		return "", sysError(err)
   533  	}
   534  
   535  	option = append(option,
   536  		request.WithHeader(http.Header{
   537  			"Authorization": {"Bearer " + client.Credential.AccessToken},
   538  			"Content-Type":  {"application/json"},
   539  		}),
   540  		request.WithContext(ctx),
   541  		request.WithTPSLimit(
   542  			fmt.Sprintf("policy_%d", client.Policy.ID),
   543  			client.Policy.OptionsSerialized.TPSLimit,
   544  			client.Policy.OptionsSerialized.TPSLimitBurst,
   545  		),
   546  	)
   547  
   548  	// 发送请求
   549  	res := client.Request.Request(
   550  		method,
   551  		url,
   552  		body,
   553  		option...,
   554  	)
   555  
   556  	if res.Err != nil {
   557  		return "", sysError(res.Err)
   558  	}
   559  
   560  	respBody, err := res.GetResponse()
   561  	if err != nil {
   562  		return "", sysError(err)
   563  	}
   564  
   565  	// 解析请求响应
   566  	var (
   567  		errResp   RespError
   568  		decodeErr error
   569  	)
   570  	// 如果有错误
   571  	if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
   572  		decodeErr = json.Unmarshal([]byte(respBody), &errResp)
   573  		if decodeErr != nil {
   574  			util.Log().Debug("Onedrive returns unknown response: %s", respBody)
   575  			return "", sysError(decodeErr)
   576  		}
   577  
   578  		if res.Response.StatusCode == 429 {
   579  			util.Log().Warning("OneDrive request is throttled.")
   580  			return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
   581  		}
   582  
   583  		return "", &errResp
   584  	}
   585  
   586  	return respBody, nil
   587  }
   588  
   589  func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
   590  	// 发送请求
   591  	bodyReader := io.NopCloser(strings.NewReader(body))
   592  	return client.request(ctx, method, url, bodyReader,
   593  		request.WithContentLength(int64(len(body))),
   594  	)
   595  }