github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/archive.go (about)

     1  package filesystem
     2  
     3  import (
     4  	"archive/zip"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"path"
    10  	"path/filepath"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	model "github.com/cloudreve/Cloudreve/v3/models"
    16  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
    17  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    18  	"github.com/gin-gonic/gin"
    19  	"github.com/mholt/archiver/v4"
    20  )
    21  
    22  /* ===============
    23       压缩/解压缩
    24     ===============
    25  */
    26  
    27  // Compress 创建给定目录和文件的压缩文件
    28  func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error {
    29  	// 查找待压缩目录
    30  	folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
    31  	if err != nil && len(folderIDs) != 0 {
    32  		return ErrDBListObjects
    33  	}
    34  
    35  	// 查找待压缩文件
    36  	files, err := model.GetFilesByIDs(fileIDs, fs.User.ID)
    37  	if err != nil && len(fileIDs) != 0 {
    38  		return ErrDBListObjects
    39  	}
    40  
    41  	// 如果上下文限制了父目录,则进行检查
    42  	if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
    43  		// 检查目录
    44  		for _, folder := range folders {
    45  			if *folder.ParentID != parent.ID {
    46  				return ErrObjectNotExist
    47  			}
    48  		}
    49  
    50  		// 检查文件
    51  		for _, file := range files {
    52  			if file.FolderID != parent.ID {
    53  				return ErrObjectNotExist
    54  			}
    55  		}
    56  	}
    57  
    58  	// 尝试获取请求上下文,以便于后续检查用户取消任务
    59  	reqContext := ctx
    60  	ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context)
    61  	if ok {
    62  		reqContext = ginCtx.Request.Context()
    63  	}
    64  
    65  	// 将顶级待处理对象的路径设为根路径
    66  	for i := 0; i < len(folders); i++ {
    67  		folders[i].Position = ""
    68  	}
    69  	for i := 0; i < len(files); i++ {
    70  		files[i].Position = ""
    71  	}
    72  
    73  	// 创建压缩文件Writer
    74  	zipWriter := zip.NewWriter(writer)
    75  	defer zipWriter.Close()
    76  
    77  	ctx = reqContext
    78  
    79  	// 压缩各个目录及文件
    80  	for i := 0; i < len(folders); i++ {
    81  		select {
    82  		case <-reqContext.Done():
    83  			// 取消压缩请求
    84  			return ErrClientCanceled
    85  		default:
    86  			fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive)
    87  		}
    88  
    89  	}
    90  	for i := 0; i < len(files); i++ {
    91  		select {
    92  		case <-reqContext.Done():
    93  			// 取消压缩请求
    94  			return ErrClientCanceled
    95  		default:
    96  			fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive)
    97  		}
    98  	}
    99  
   100  	return nil
   101  }
   102  
   103  func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
   104  	// 如果对象是文件
   105  	if file != nil {
   106  		// 切换上传策略
   107  		fs.Policy = file.GetPolicy()
   108  		err := fs.DispatchHandler()
   109  		if err != nil {
   110  			util.Log().Warning("Failed to compress file %q: %s", file.Name, err)
   111  			return
   112  		}
   113  
   114  		// 获取文件内容
   115  		fileToZip, err := fs.Handler.Get(
   116  			context.WithValue(ctx, fsctx.FileModelCtx, *file),
   117  			file.SourceName,
   118  		)
   119  		if err != nil {
   120  			util.Log().Debug("Failed to open %q: %s", file.Name, err)
   121  			return
   122  		}
   123  		if closer, ok := fileToZip.(io.Closer); ok {
   124  			defer closer.Close()
   125  		}
   126  
   127  		// 创建压缩文件头
   128  		header := &zip.FileHeader{
   129  			Name:               filepath.FromSlash(path.Join(file.Position, file.Name)),
   130  			Modified:           file.UpdatedAt,
   131  			UncompressedSize64: file.Size,
   132  		}
   133  
   134  		// 指定是压缩还是归档
   135  		if isArchive {
   136  			header.Method = zip.Store
   137  		} else {
   138  			header.Method = zip.Deflate
   139  		}
   140  
   141  		writer, err := zipWriter.CreateHeader(header)
   142  		if err != nil {
   143  			return
   144  		}
   145  
   146  		_, err = io.Copy(writer, fileToZip)
   147  	} else if folder != nil {
   148  		// 对象是目录
   149  		// 获取子文件
   150  		subFiles, err := folder.GetChildFiles()
   151  		if err == nil && len(subFiles) > 0 {
   152  			for i := 0; i < len(subFiles); i++ {
   153  				fs.doCompress(ctx, &subFiles[i], nil, zipWriter, isArchive)
   154  			}
   155  
   156  		}
   157  		// 获取子目录,继续递归遍历
   158  		subFolders, err := folder.GetChildFolder()
   159  		if err == nil && len(subFolders) > 0 {
   160  			for i := 0; i < len(subFolders); i++ {
   161  				fs.doCompress(ctx, nil, &subFolders[i], zipWriter, isArchive)
   162  			}
   163  		}
   164  	}
   165  }
   166  
   167  // Decompress 解压缩给定压缩文件到dst目录
   168  func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error {
   169  	err := fs.ResetFileIfNotExist(ctx, src)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	tempZipFilePath := ""
   175  	defer func() {
   176  		// 结束时删除临时压缩文件
   177  		if tempZipFilePath != "" {
   178  			if err := os.Remove(tempZipFilePath); err != nil {
   179  				util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err)
   180  			}
   181  		}
   182  	}()
   183  
   184  	// 下载压缩文件到临时目录
   185  	fileStream, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	defer fileStream.Close()
   191  
   192  	tempZipFilePath = filepath.Join(
   193  		util.RelativePath(model.GetSettingByName("temp_path")),
   194  		"decompress",
   195  		fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
   196  	)
   197  
   198  	zipFile, err := util.CreatNestedFile(tempZipFilePath)
   199  	if err != nil {
   200  		util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err)
   201  		tempZipFilePath = ""
   202  		return err
   203  	}
   204  	defer zipFile.Close()
   205  
   206  	// 下载前先判断是否是可解压的格式
   207  	format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream)
   208  	if err != nil {
   209  		util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err)
   210  		return err
   211  	}
   212  
   213  	extractor, ok := format.(archiver.Extractor)
   214  	if !ok {
   215  		return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName)
   216  	}
   217  
   218  	// 只有zip格式可以多个文件同时上传
   219  	var isZip bool
   220  	switch extractor.(type) {
   221  	case archiver.Zip:
   222  		extractor = archiver.Zip{TextEncoding: encoding}
   223  		isZip = true
   224  	}
   225  
   226  	// 除了zip必须下载到本地,其余的可以边下载边解压
   227  	reader := readStream
   228  	if isZip {
   229  		_, err = io.Copy(zipFile, readStream)
   230  		if err != nil {
   231  			util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err)
   232  			return err
   233  		}
   234  
   235  		fileStream.Close()
   236  
   237  		// 设置文件偏移量
   238  		zipFile.Seek(0, io.SeekStart)
   239  		reader = zipFile
   240  	}
   241  
   242  	// 重设存储策略
   243  	fs.Policy = &fs.User.Policy
   244  	err = fs.DispatchHandler()
   245  	if err != nil {
   246  		return err
   247  	}
   248  
   249  	var wg sync.WaitGroup
   250  	parallel := model.GetIntSetting("max_parallel_transfer", 4)
   251  	worker := make(chan int, parallel)
   252  	for i := 0; i < parallel; i++ {
   253  		worker <- i
   254  	}
   255  
   256  	// 上传文件函数
   257  	uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) {
   258  		defer func() {
   259  			if isZip {
   260  				worker <- 1
   261  				wg.Done()
   262  			}
   263  			if err := recover(); err != nil {
   264  				util.Log().Warning("Error while uploading files inside of archive file.")
   265  				fmt.Println(err)
   266  			}
   267  		}()
   268  
   269  		err := fs.UploadFromStream(ctx, &fsctx.FileStream{
   270  			File:        fileStream,
   271  			Size:        uint64(size),
   272  			Name:        path.Base(savePath),
   273  			VirtualPath: path.Dir(savePath),
   274  		}, true)
   275  		fileStream.Close()
   276  		if err != nil {
   277  			util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err)
   278  		}
   279  	}
   280  
   281  	// 解压缩文件,回调函数如果出错会停止解压的下一步进行,全部return nil
   282  	err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error {
   283  		rawPath := util.FormSlash(f.NameInArchive)
   284  		savePath := path.Join(dst, rawPath)
   285  		// 路径是否合法
   286  		if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) {
   287  			util.Log().Warning("%s: illegal file path", f.NameInArchive)
   288  			return nil
   289  		}
   290  
   291  		// 如果是目录
   292  		if f.FileInfo.IsDir() {
   293  			fs.CreateDirectory(ctx, savePath)
   294  			return nil
   295  		}
   296  
   297  		// 上传文件
   298  		fileStream, err := f.Open()
   299  		if err != nil {
   300  			util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err)
   301  			return nil
   302  		}
   303  
   304  		if !isZip {
   305  			uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
   306  		} else {
   307  			<-worker
   308  			wg.Add(1)
   309  			go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
   310  		}
   311  		return nil
   312  	})
   313  	wg.Wait()
   314  	return err
   315  
   316  }