github.com/avfs/avfs@v0.33.1-0.20240303173310-c6ba67c33eb7/copy.go (about)

     1  //
     2  //  Copyright 2021 The AVFS authors
     3  //
     4  //  Licensed under the Apache License, Version 2.0 (the "License");
     5  //  you may not use this file except in compliance with the License.
     6  //  You may obtain a copy of the License at
     7  //
     8  //  	http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  //  Unless required by applicable law or agreed to in writing, software
    11  //  distributed under the License is distributed on an "AS IS" BASIS,
    12  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  //  See the License for the specific language governing permissions and
    14  //  limitations under the License.
    15  //
    16  
    17  package avfs
    18  
    19  import (
    20  	"hash"
    21  	"io"
    22  	"os"
    23  	"sync"
    24  )
    25  
    26  var copyPool = newCopyPool() //nolint:gochecknoglobals // copyPool is the buffer pool used to copy files.
    27  
    28  // newCopyPool initialize the copy buffer pool.
    29  func newCopyPool() *sync.Pool {
    30  	const bufSize = 32 * 1024
    31  
    32  	pool := &sync.Pool{New: func() any {
    33  		buf := make([]byte, bufSize)
    34  
    35  		return &buf
    36  	}}
    37  
    38  	return pool
    39  }
    40  
    41  // CopyFile copies a file between file systems and returns an error if any.
    42  func CopyFile(dstFs, srcFs VFSBase, dstPath, srcPath string) error {
    43  	_, err := CopyFileHash(dstFs, srcFs, dstPath, srcPath, nil)
    44  
    45  	return err
    46  }
    47  
    48  // CopyFileHash copies a file between file systems and returns the hash sum of the source file.
    49  func CopyFileHash(dstFs, srcFs VFSBase, dstPath, srcPath string, hasher hash.Hash) (sum []byte, err error) {
    50  	src, err := srcFs.OpenFile(srcPath, os.O_RDONLY, 0)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	defer src.Close()
    56  
    57  	dst, err := dstFs.Create(dstPath)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	defer func() {
    63  		cerr := dst.Close()
    64  		if cerr == nil {
    65  			err = cerr
    66  		}
    67  	}()
    68  
    69  	var out io.Writer
    70  
    71  	if hasher == nil {
    72  		out = dst
    73  	} else {
    74  		hasher.Reset()
    75  		out = io.MultiWriter(dst, hasher)
    76  	}
    77  
    78  	_, err = copyBufPool(out, src)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	err = dst.Sync()
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	info, err := srcFs.Stat(srcPath)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	err = dstFs.Chmod(dstPath, info.Mode())
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	if hasher == nil {
    99  		return nil, nil
   100  	}
   101  
   102  	return hasher.Sum(nil), nil
   103  }
   104  
   105  // HashFile hashes a file and returns the hash sum.
   106  func HashFile(vfs VFSBase, name string, hasher hash.Hash) (sum []byte, err error) {
   107  	f, err := vfs.OpenFile(name, os.O_RDONLY, 0)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	defer f.Close()
   113  
   114  	hasher.Reset()
   115  
   116  	_, err = copyBufPool(hasher, f)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	sum = hasher.Sum(nil)
   122  
   123  	return sum, nil
   124  }
   125  
   126  // copyBufPool copies a source reader to a writer using a buffer from the buffer pool.
   127  func copyBufPool(dst io.Writer, src io.Reader) (written int64, err error) { //nolint:unparam // unparam shoudn't check return values.
   128  	buf := copyPool.Get().(*[]byte) //nolint:forcetypeassert // Get() always returns a pointer to a byte slice.
   129  	defer copyPool.Put(buf)
   130  
   131  	written, err = io.CopyBuffer(dst, src, *buf)
   132  
   133  	return
   134  }