github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/s3/handler.go (about)

     1  package s3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"path"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
    17  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
    18  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    19  
    20  	"github.com/aws/aws-sdk-go/aws"
    21  	"github.com/aws/aws-sdk-go/aws/credentials"
    22  	"github.com/aws/aws-sdk-go/aws/session"
    23  	"github.com/aws/aws-sdk-go/service/s3"
    24  	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    25  	model "github.com/cloudreve/Cloudreve/v3/models"
    26  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
    27  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
    28  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    29  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    30  )
    31  
    32  // Driver 适配器模板
    33  type Driver struct {
    34  	Policy *model.Policy
    35  	sess   *session.Session
    36  	svc    *s3.S3
    37  }
    38  
    39  // UploadPolicy S3上传策略
    40  type UploadPolicy struct {
    41  	Expiration string        `json:"expiration"`
    42  	Conditions []interface{} `json:"conditions"`
    43  }
    44  
    45  // MetaData 文件信息
    46  type MetaData struct {
    47  	Size uint64
    48  	Etag string
    49  }
    50  
    51  func NewDriver(policy *model.Policy) (*Driver, error) {
    52  	if policy.OptionsSerialized.ChunkSize == 0 {
    53  		policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
    54  	}
    55  
    56  	driver := &Driver{
    57  		Policy: policy,
    58  	}
    59  
    60  	return driver, driver.InitS3Client()
    61  }
    62  
    63  // InitS3Client 初始化S3会话
    64  func (handler *Driver) InitS3Client() error {
    65  	if handler.Policy == nil {
    66  		return errors.New("empty policy")
    67  	}
    68  
    69  	if handler.svc == nil {
    70  		// 初始化会话
    71  		sess, err := session.NewSession(&aws.Config{
    72  			Credentials:      credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
    73  			Endpoint:         &handler.Policy.Server,
    74  			Region:           &handler.Policy.OptionsSerialized.Region,
    75  			S3ForcePathStyle: &handler.Policy.OptionsSerialized.S3ForcePathStyle,
    76  		})
    77  
    78  		if err != nil {
    79  			return err
    80  		}
    81  		handler.sess = sess
    82  		handler.svc = s3.New(sess)
    83  	}
    84  	return nil
    85  }
    86  
    87  // List 列出给定路径下的文件
    88  func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
    89  	// 初始化列目录参数
    90  	base = strings.TrimPrefix(base, "/")
    91  	if base != "" {
    92  		base += "/"
    93  	}
    94  
    95  	opt := &s3.ListObjectsInput{
    96  		Bucket:  &handler.Policy.BucketName,
    97  		Prefix:  &base,
    98  		MaxKeys: aws.Int64(1000),
    99  	}
   100  
   101  	// 是否为递归列出
   102  	if !recursive {
   103  		opt.Delimiter = aws.String("/")
   104  	}
   105  
   106  	var (
   107  		objects []*s3.Object
   108  		commons []*s3.CommonPrefix
   109  	)
   110  
   111  	for {
   112  		res, err := handler.svc.ListObjectsWithContext(ctx, opt)
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		objects = append(objects, res.Contents...)
   117  		commons = append(commons, res.CommonPrefixes...)
   118  
   119  		// 如果本次未列取完,则继续使用marker获取结果
   120  		if *res.IsTruncated {
   121  			opt.Marker = res.NextMarker
   122  		} else {
   123  			break
   124  		}
   125  	}
   126  
   127  	// 处理列取结果
   128  	res := make([]response.Object, 0, len(objects)+len(commons))
   129  
   130  	// 处理目录
   131  	for _, object := range commons {
   132  		rel, err := filepath.Rel(*opt.Prefix, *object.Prefix)
   133  		if err != nil {
   134  			continue
   135  		}
   136  		res = append(res, response.Object{
   137  			Name:         path.Base(*object.Prefix),
   138  			RelativePath: filepath.ToSlash(rel),
   139  			Size:         0,
   140  			IsDir:        true,
   141  			LastModify:   time.Now(),
   142  		})
   143  	}
   144  	// 处理文件
   145  	for _, object := range objects {
   146  		rel, err := filepath.Rel(*opt.Prefix, *object.Key)
   147  		if err != nil {
   148  			continue
   149  		}
   150  		res = append(res, response.Object{
   151  			Name:         path.Base(*object.Key),
   152  			Source:       *object.Key,
   153  			RelativePath: filepath.ToSlash(rel),
   154  			Size:         uint64(*object.Size),
   155  			IsDir:        false,
   156  			LastModify:   time.Now(),
   157  		})
   158  	}
   159  
   160  	return res, nil
   161  
   162  }
   163  
   164  // Get 获取文件
   165  func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
   166  	// 获取文件源地址
   167  	downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	// 获取文件数据流
   173  	client := request.NewClient()
   174  	resp, err := client.Request(
   175  		"GET",
   176  		downloadURL,
   177  		nil,
   178  		request.WithContext(ctx),
   179  		request.WithHeader(
   180  			http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
   181  		),
   182  		request.WithTimeout(time.Duration(0)),
   183  	).CheckHTTPResponse(200).GetRSCloser()
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	resp.SetFirstFakeChunk()
   189  
   190  	// 尝试自主获取文件大小
   191  	if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
   192  		resp.SetContentLength(int64(file.Size))
   193  	}
   194  
   195  	return resp, nil
   196  }
   197  
   198  // Put 将文件流保存到指定目录
   199  func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
   200  	defer file.Close()
   201  
   202  	// 初始化客户端
   203  	if err := handler.InitS3Client(); err != nil {
   204  		return err
   205  	}
   206  
   207  	uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) {
   208  		u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize)
   209  	})
   210  
   211  	dst := file.Info().SavePath
   212  	_, err := uploader.Upload(&s3manager.UploadInput{
   213  		Bucket: &handler.Policy.BucketName,
   214  		Key:    &dst,
   215  		Body:   io.LimitReader(file, int64(file.Info().Size)),
   216  	})
   217  
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  // Delete 删除一个或多个文件,
   226  // 返回未删除的文件,及遇到的最后一个错误
   227  func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
   228  	failed := make([]string, 0, len(files))
   229  	deleted := make([]string, 0, len(files))
   230  
   231  	keys := make([]*s3.ObjectIdentifier, 0, len(files))
   232  	for _, file := range files {
   233  		filePath := file
   234  		keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
   235  	}
   236  
   237  	// 发送异步删除请求
   238  	res, err := handler.svc.DeleteObjects(
   239  		&s3.DeleteObjectsInput{
   240  			Bucket: &handler.Policy.BucketName,
   241  			Delete: &s3.Delete{
   242  				Objects: keys,
   243  			},
   244  		})
   245  
   246  	if err != nil {
   247  		return files, err
   248  	}
   249  
   250  	// 统计未删除的文件
   251  	for _, deleteRes := range res.Deleted {
   252  		deleted = append(deleted, *deleteRes.Key)
   253  	}
   254  	failed = util.SliceDifference(files, deleted)
   255  
   256  	return failed, nil
   257  
   258  }
   259  
   260  // Thumb 获取文件缩略图
   261  func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
   262  	return nil, driver.ErrorThumbNotSupported
   263  }
   264  
   265  // Source 获取外链URL
   266  func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
   267  
   268  	// 尝试从上下文获取文件名
   269  	fileName := ""
   270  	if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
   271  		fileName = file.Name
   272  	}
   273  
   274  	// 初始化客户端
   275  	if err := handler.InitS3Client(); err != nil {
   276  		return "", err
   277  	}
   278  
   279  	contentDescription := aws.String("attachment; filename=\"" + url.PathEscape(fileName) + "\"")
   280  	if !isDownload {
   281  		contentDescription = nil
   282  	}
   283  	req, _ := handler.svc.GetObjectRequest(
   284  		&s3.GetObjectInput{
   285  			Bucket:                     &handler.Policy.BucketName,
   286  			Key:                        &path,
   287  			ResponseContentDisposition: contentDescription,
   288  		})
   289  
   290  	signedURL, err := req.Presign(time.Duration(ttl) * time.Second)
   291  	if err != nil {
   292  		return "", err
   293  	}
   294  
   295  	// 将最终生成的签名URL域名换成用户自定义的加速域名(如果有)
   296  	finalURL, err := url.Parse(signedURL)
   297  	if err != nil {
   298  		return "", err
   299  	}
   300  
   301  	// 公有空间替换掉Key及不支持的头
   302  	if !handler.Policy.IsPrivate {
   303  		finalURL.RawQuery = ""
   304  	}
   305  
   306  	if handler.Policy.BaseURL != "" {
   307  		cdnURL, err := url.Parse(handler.Policy.BaseURL)
   308  		if err != nil {
   309  			return "", err
   310  		}
   311  		finalURL.Host = cdnURL.Host
   312  		finalURL.Scheme = cdnURL.Scheme
   313  	}
   314  
   315  	return finalURL.String(), nil
   316  }
   317  
   318  // Token 获取上传策略和认证Token
   319  func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
   320  	// 检查文件是否存在
   321  	fileInfo := file.Info()
   322  	if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil {
   323  		return nil, fmt.Errorf("file already exist")
   324  	}
   325  
   326  	// 创建分片上传
   327  	expires := time.Now().Add(time.Duration(ttl) * time.Second)
   328  	res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
   329  		Bucket:      &handler.Policy.BucketName,
   330  		Key:         &fileInfo.SavePath,
   331  		Expires:     &expires,
   332  		ContentType: aws.String(fileInfo.DetectMimeType()),
   333  	})
   334  	if err != nil {
   335  		return nil, fmt.Errorf("failed to create multipart upload: %w", err)
   336  	}
   337  
   338  	uploadSession.UploadID = *res.UploadId
   339  
   340  	// 为每个分片签名上传 URL
   341  	chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
   342  	urls := make([]string, chunks.Num())
   343  	for chunks.Next() {
   344  		err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
   345  			signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{
   346  				Bucket:     &handler.Policy.BucketName,
   347  				Key:        &fileInfo.SavePath,
   348  				PartNumber: aws.Int64(int64(c.Index() + 1)),
   349  				UploadId:   res.UploadId,
   350  			})
   351  
   352  			signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
   353  			if err != nil {
   354  				return err
   355  			}
   356  
   357  			urls[c.Index()] = signedURL
   358  			return nil
   359  		})
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  	}
   364  
   365  	// 签名完成分片上传的请求URL
   366  	signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
   367  		Bucket:   &handler.Policy.BucketName,
   368  		Key:      &fileInfo.SavePath,
   369  		UploadId: res.UploadId,
   370  	})
   371  
   372  	signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
   373  	if err != nil {
   374  		return nil, err
   375  	}
   376  
   377  	// 生成上传凭证
   378  	return &serializer.UploadCredential{
   379  		SessionID:   uploadSession.Key,
   380  		ChunkSize:   handler.Policy.OptionsSerialized.ChunkSize,
   381  		UploadID:    *res.UploadId,
   382  		UploadURLs:  urls,
   383  		CompleteURL: signedURL,
   384  	}, nil
   385  }
   386  
   387  // Meta 获取文件信息
   388  func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
   389  	res, err := handler.svc.HeadObject(
   390  		&s3.HeadObjectInput{
   391  			Bucket: &handler.Policy.BucketName,
   392  			Key:    &path,
   393  		})
   394  
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  
   399  	return &MetaData{
   400  		Size: uint64(*res.ContentLength),
   401  		Etag: *res.ETag,
   402  	}, nil
   403  
   404  }
   405  
   406  // CORS 创建跨域策略
   407  func (handler *Driver) CORS() error {
   408  	rule := s3.CORSRule{
   409  		AllowedMethods: aws.StringSlice([]string{
   410  			"GET",
   411  			"POST",
   412  			"PUT",
   413  			"DELETE",
   414  			"HEAD",
   415  		}),
   416  		AllowedOrigins: aws.StringSlice([]string{"*"}),
   417  		AllowedHeaders: aws.StringSlice([]string{"*"}),
   418  		ExposeHeaders:  aws.StringSlice([]string{"ETag"}),
   419  		MaxAgeSeconds:  aws.Int64(3600),
   420  	}
   421  
   422  	_, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{
   423  		Bucket: &handler.Policy.BucketName,
   424  		CORSConfiguration: &s3.CORSConfiguration{
   425  			CORSRules: []*s3.CORSRule{&rule},
   426  		},
   427  	})
   428  
   429  	return err
   430  }
   431  
   432  // 取消上传凭证
   433  func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
   434  	_, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
   435  		UploadId: &uploadSession.UploadID,
   436  		Bucket:   &handler.Policy.BucketName,
   437  		Key:      &uploadSession.SavePath,
   438  	})
   439  	return err
   440  }