github.com/richardwilkes/toolbox@v1.121.0/xio/fs/zip/unzip.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  // Package zip provides simple zip extraction.
    11  package zip
    12  
    13  import (
    14  	"archive/zip"
    15  	"fmt"
    16  	"io"
    17  	"os"
    18  	"path/filepath"
    19  	"strings"
    20  
    21  	"github.com/richardwilkes/toolbox/errs"
    22  	"github.com/richardwilkes/toolbox/xio"
    23  )
    24  
    25  // ExtractArchive extracts the contents of a zip archive at 'src' into the 'dst' directory.
    26  func ExtractArchive(src, dst string) error {
    27  	return ExtractArchiveWithMask(src, dst, 0o777)
    28  }
    29  
    30  // ExtractArchiveWithMask extracts the contents of a zip archive at 'src' into the 'dst' directory.
    31  func ExtractArchiveWithMask(src, dst string, mask os.FileMode) error {
    32  	r, err := zip.OpenReader(src)
    33  	if err != nil {
    34  		return errs.Wrap(err)
    35  	}
    36  	defer xio.CloseIgnoringErrors(r)
    37  	return ExtractWithMask(&r.Reader, dst, mask)
    38  }
    39  
    40  // Extract the contents of a zip reader into the 'dst' directory.
    41  func Extract(zr *zip.Reader, dst string) error {
    42  	return ExtractWithMask(zr, dst, 0o777)
    43  }
    44  
    45  // ExtractWithMask the contents of a zip reader into the 'dst' directory.
    46  func ExtractWithMask(zr *zip.Reader, dst string, mask os.FileMode) error {
    47  	root, err := filepath.Abs(dst)
    48  	if err != nil {
    49  		return errs.Wrap(err)
    50  	}
    51  	rootWithTrailingSep := fmt.Sprintf("%s%c", root, filepath.Separator)
    52  	for _, f := range zr.File {
    53  		path := filepath.Join(root, f.Name) //nolint:gosec // We check for path traversal below
    54  		if !strings.HasPrefix(path, rootWithTrailingSep) {
    55  			return errs.Newf("Path outside of root is not permitted: %s", f.Name)
    56  		}
    57  		fi := f.FileInfo()
    58  		mode := fi.Mode()
    59  		switch {
    60  		case mode&os.ModeSymlink != 0:
    61  			if err = extractSymLink(f, path, mask); err != nil {
    62  				return err
    63  			}
    64  		case fi.IsDir():
    65  			if err = os.MkdirAll(path, mode.Perm()&mask); err != nil {
    66  				return errs.Wrap(err)
    67  			}
    68  		default:
    69  			if err = extractFile(f, path, mask); err != nil {
    70  				return err
    71  			}
    72  		}
    73  	}
    74  	return nil
    75  }
    76  
    77  func extractSymLink(f *zip.File, dst string, mask os.FileMode) error {
    78  	r, err := f.Open()
    79  	if err != nil {
    80  		return errs.Wrap(err)
    81  	}
    82  	defer xio.CloseIgnoringErrors(r)
    83  	var buffer []byte
    84  	if buffer, err = io.ReadAll(r); err != nil {
    85  		return errs.Wrap(err)
    86  	}
    87  	if err = os.MkdirAll(filepath.Dir(dst), 0o755&mask); err != nil {
    88  		return errs.Wrap(err)
    89  	}
    90  	if err = os.Symlink(string(buffer), dst); err != nil {
    91  		return errs.Wrap(err)
    92  	}
    93  	return nil
    94  }
    95  
    96  func extractFile(f *zip.File, dst string, mask os.FileMode) (err error) {
    97  	var r io.ReadCloser
    98  	if r, err = f.Open(); err != nil {
    99  		return errs.Wrap(err)
   100  	}
   101  	defer xio.CloseIgnoringErrors(r)
   102  	if err = os.MkdirAll(filepath.Dir(dst), 0o755&mask); err != nil {
   103  		return errs.Wrap(err)
   104  	}
   105  	var file *os.File
   106  	if file, err = os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, f.FileInfo().Mode().Perm()&mask); err != nil {
   107  		return errs.Wrap(err)
   108  	}
   109  	defer func() {
   110  		if closeErr := file.Close(); closeErr != nil && err == nil {
   111  			err = errs.Wrap(closeErr)
   112  		}
   113  	}()
   114  	if _, err = io.Copy(file, r); err != nil { //nolint:gosec // We'll take the risk
   115  		err = errs.Wrap(err)
   116  	}
   117  	return
   118  }