github.com/richardwilkes/toolbox@v1.121.0/xio/fs/tar/untar.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  // Package tar provides simple tar extraction.
    11  package tar
    12  
    13  import (
    14  	"archive/tar"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"os"
    19  	"path/filepath"
    20  	"strings"
    21  
    22  	"github.com/richardwilkes/toolbox/errs"
    23  	"github.com/richardwilkes/toolbox/xio"
    24  )
    25  
    26  // ExtractArchive extracts the contents of a tar archive at 'src' into the 'dst' directory.
    27  func ExtractArchive(src, dst string) error {
    28  	return ExtractArchiveWithMask(src, dst, 0o777)
    29  }
    30  
    31  // ExtractArchiveWithMask extracts the contents of a tar archive at 'src' into the 'dst' directory.
    32  func ExtractArchiveWithMask(src, dst string, mask os.FileMode) error {
    33  	f, err := os.Open(src)
    34  	if err != nil {
    35  		return errs.Wrap(err)
    36  	}
    37  	r := tar.NewReader(f)
    38  	defer xio.CloseIgnoringErrors(f)
    39  	return ExtractWithMask(r, dst, mask)
    40  }
    41  
    42  // Extract the contents of a tar reader into the 'dst' directory.
    43  func Extract(tr *tar.Reader, dst string) error {
    44  	return ExtractWithMask(tr, dst, 0o777)
    45  }
    46  
    47  // ExtractWithMask the contents of a tar reader into the 'dst' directory.
    48  func ExtractWithMask(tr *tar.Reader, dst string, mask os.FileMode) error {
    49  	root, err := filepath.Abs(dst)
    50  	if err != nil {
    51  		return errs.Wrap(err)
    52  	}
    53  	rootWithTrailingSep := fmt.Sprintf("%s%c", root, filepath.Separator)
    54  	for {
    55  		var hdr *tar.Header
    56  		if hdr, err = tr.Next(); errors.Is(err, io.EOF) {
    57  			return nil
    58  		}
    59  		if err != nil {
    60  			return errs.Wrap(err)
    61  		}
    62  		path := filepath.Join(root, hdr.Name) //nolint:gosec // We check for path traversal below
    63  		if !strings.HasPrefix(path, rootWithTrailingSep) {
    64  			return errs.Newf("Path outside of root is not permitted: %s", hdr.Name)
    65  		}
    66  		switch hdr.Typeflag {
    67  		case tar.TypeReg:
    68  			if err = extractFile(tr, path, hdr.FileInfo().Mode().Perm(), mask); err != nil {
    69  				return err
    70  			}
    71  		case tar.TypeLink:
    72  			if err = os.MkdirAll(filepath.Dir(path), 0o755&mask); err != nil {
    73  				return errs.Wrap(err)
    74  			}
    75  			if err = os.Link(hdr.Linkname, path); err != nil {
    76  				return errs.Wrap(err)
    77  			}
    78  		case tar.TypeSymlink:
    79  			if err = os.MkdirAll(filepath.Dir(path), 0o755&mask); err != nil {
    80  				return errs.Wrap(err)
    81  			}
    82  			if err = os.Symlink(hdr.Linkname, path); err != nil {
    83  				return errs.Wrap(err)
    84  			}
    85  		case tar.TypeDir:
    86  			if err = os.MkdirAll(path, hdr.FileInfo().Mode().Perm()&mask); err != nil {
    87  				return errs.Wrap(err)
    88  			}
    89  		}
    90  	}
    91  }
    92  
    93  func extractFile(r io.Reader, dst string, mode, mask os.FileMode) (err error) {
    94  	if err = os.MkdirAll(filepath.Dir(dst), 0o755&mask); err != nil {
    95  		return errs.Wrap(err)
    96  	}
    97  	var file *os.File
    98  	if file, err = os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode&mask); err != nil {
    99  		return errs.Wrap(err)
   100  	}
   101  	defer func() {
   102  		if closeErr := file.Close(); closeErr != nil && err == nil {
   103  			err = errs.Wrap(closeErr)
   104  		}
   105  	}()
   106  	if _, err = io.Copy(file, r); err != nil {
   107  		err = errs.Wrap(err)
   108  	}
   109  	return nil
   110  }