go.mway.dev/x@v0.0.0-20240520034138-950aede9a3fb/archive/extract/extract.go (about)

     1  // Copyright (c) 2024 Matt Way
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to
     5  // deal in the Software without restriction, including without limitation the
     6  // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
     7  // sell copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    18  // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
    19  // IN THE SOFTWARE.
    20  
    21  // Package extract provides archive extraction helpers.
    22  package extract
    23  
    24  import (
    25  	"bytes"
    26  	"context"
    27  	"fmt"
    28  	"io"
    29  	"os"
    30  	"path/filepath"
    31  	"strings"
    32  
    33  	"github.com/mholt/archiver/v4"
    34  	"go.mway.dev/color"
    35  	"go.mway.dev/errors"
    36  	xos "go.mway.dev/x/os"
    37  )
    38  
    39  const (
    40  	// ExtractToTempDir is a sentinel destination value that will cause
    41  	// [Extract] to generate a temporary directory for extraction, and remove
    42  	// the temporary directory once extraction (and callbacks) have completed.
    43  	// ExtractToTempDir is only useful if a [Callback] is passed to Extract.
    44  	ExtractToTempDir = ""
    45  
    46  	_sep = string(os.PathSeparator)
    47  )
    48  
    49  // Extract extracts the given archive to dst using any provided [Option]s.
    50  //
    51  //nolint:gocyclo
    52  func Extract(
    53  	ctx context.Context,
    54  	dst string,
    55  	archive string,
    56  	opts ...Option,
    57  ) (err error) {
    58  	var options Options
    59  	for _, opt := range opts {
    60  		opt.apply(&options)
    61  	}
    62  
    63  	if options.Output == nil {
    64  		options.Output = io.Discard
    65  	}
    66  
    67  	if dst == ExtractToTempDir {
    68  		if dst, err = os.MkdirTemp("", "extract"); err != nil {
    69  			return errors.Wrap(err, "failed to create temporary directory")
    70  		}
    71  
    72  		defer func() {
    73  			err = errors.Join(err, errors.Wrapf(
    74  				os.RemoveAll(dst),
    75  				"failed to remove temporary destination %q",
    76  				dst,
    77  			))
    78  		}()
    79  	}
    80  
    81  	dirs := make(map[string]struct{})
    82  	handler := archiver.FileHandler(func(
    83  		_ context.Context,
    84  		f archiver.File,
    85  	) (err error) {
    86  		// Ignore dirs; they are created lazily below.
    87  		if f.IsDir() {
    88  			return nil
    89  		}
    90  
    91  		fpath, stripped, stripErr := stripPrefix(
    92  			f.NameInArchive,
    93  			options.StripPrefix,
    94  		)
    95  		switch {
    96  		case stripErr != nil:
    97  			return errors.Wrap(stripErr, "failed to strip prefix")
    98  		case !stripped:
    99  			return nil
   100  		default:
   101  			// passthrough
   102  		}
   103  
   104  		for _, exclude := range options.ExcludePaths {
   105  			matched, matchErr := filepath.Match(exclude, fpath)
   106  			if matchErr != nil {
   107  				return errors.Wrapf(matchErr, "bad match pattern %q", exclude)
   108  			}
   109  			if matched {
   110  				return nil
   111  			}
   112  		}
   113  
   114  		dstpath := filepath.Join(dst, fpath)
   115  		if len(options.IncludePaths) > 0 {
   116  			var matched bool
   117  			for include, explicitDst := range options.IncludePaths {
   118  				if matched, err = filepath.Match(include, fpath); err != nil {
   119  					return errors.Wrapf(err, "bad match pattern %q", include)
   120  				}
   121  
   122  				switch {
   123  				case !matched:
   124  					continue
   125  				case len(explicitDst) == 0:
   126  					explicitDst = fpath
   127  				case !filepath.IsAbs(explicitDst):
   128  					explicitDst = filepath.Join(dst, explicitDst)
   129  				}
   130  
   131  				dstpath = explicitDst
   132  				break
   133  			}
   134  
   135  			if !matched {
   136  				return nil
   137  			}
   138  		}
   139  
   140  		parent := filepath.Dir(fpath)
   141  		if _, parentCreated := dirs[parent]; !parentCreated {
   142  			if options.Delete {
   143  				err = os.RemoveAll(parent)
   144  				if err != nil && !errors.Is(err, os.ErrNotExist) {
   145  					return errors.Wrapf(
   146  						err,
   147  						"failed to remove existing destination directory %q",
   148  						parent,
   149  					)
   150  				}
   151  			}
   152  
   153  			if err = xos.MkdirAllInherit(parent); err != nil {
   154  				return errors.Wrapf(
   155  					err,
   156  					"failed to create destination parent(s) %q",
   157  					parent,
   158  				)
   159  			}
   160  			dirs[fpath] = struct{}{}
   161  		}
   162  
   163  		_, err = xos.WriteReaderToFileWithFlags(
   164  			dstpath,
   165  			f.Open,
   166  			os.O_CREATE|os.O_TRUNC|os.O_WRONLY,
   167  			f.Mode(),
   168  		)
   169  		if err != nil {
   170  			return errors.Wrapf(
   171  				err,
   172  				"failed to extract %q to %q",
   173  				fpath,
   174  				dstpath,
   175  			)
   176  		}
   177  
   178  		_, err = color.FgHiGreen.Fprint(options.Output, "Extracting:")
   179  		if err != nil {
   180  			return errors.Wrap(err, "failed to write to output")
   181  		}
   182  
   183  		_, err = fmt.Fprintln(options.Output, "", fpath, "->", dstpath)
   184  		return errors.Wrap(err, "failed to write to output")
   185  	})
   186  
   187  	var src io.ReadCloser
   188  	if src, err = os.Open(archive); err != nil {
   189  		return errors.Wrapf(err, "failed to open source file %q", archive)
   190  	}
   191  	defer func() {
   192  		err = errors.Join(err, errors.Wrapf(
   193  			src.Close(),
   194  			"failed to close source file %q",
   195  			archive,
   196  		))
   197  	}()
   198  
   199  	return xos.WithCwd(dst, func() (err error) {
   200  		var (
   201  			format archiver.Format
   202  			reader io.Reader
   203  		)
   204  		if format, reader, err = archiver.Identify(archive, src); err != nil {
   205  			return errors.Wrapf(err, "failed to detect format of %q", archive)
   206  		}
   207  
   208  		ex, ok := format.(archiver.Extractor)
   209  		if !ok {
   210  			return errors.Newf(
   211  				"bug: identified format (%T) is not extractable",
   212  				format,
   213  			)
   214  		}
   215  
   216  		if _, ok := ex.(archiver.Zip); ok {
   217  			raw, readErr := io.ReadAll(reader)
   218  			if readErr != nil && !errors.Is(readErr, io.EOF) {
   219  				return errors.Wrap(readErr, "failed to read data buffer")
   220  			}
   221  			reader = bytes.NewReader(raw)
   222  		}
   223  
   224  		var extractAll []string
   225  		if err = ex.Extract(ctx, reader, extractAll, handler); err != nil {
   226  			return errors.Wrapf(err, "failed to extract %q", archive)
   227  		}
   228  
   229  		if options.Callback != nil {
   230  			return errors.Wrap(
   231  				options.Callback(ctx, dst),
   232  				"error during callback",
   233  			)
   234  		}
   235  
   236  		return nil
   237  	})
   238  }
   239  
   240  func stripPrefix(path string, prefix string) (string, bool, error) {
   241  	prefix = strings.Trim(prefix, _sep)
   242  	if len(prefix) == 0 {
   243  		return path, true, nil
   244  	}
   245  
   246  	path = filepath.Clean(path)
   247  	idx := strings.IndexByte(path, _sep[0])
   248  	for idx > 0 {
   249  		matched, err := filepath.Match(prefix, path[:idx])
   250  		if err != nil {
   251  			return "", false, errors.Wrapf(err, "bad prefix %q", prefix)
   252  		}
   253  
   254  		if matched {
   255  			return path[idx+1:], true, nil
   256  		}
   257  
   258  		offset := strings.IndexByte(path[idx+1:], _sep[0])
   259  		if offset < 0 {
   260  			break
   261  		}
   262  
   263  		idx += offset + 1
   264  	}
   265  
   266  	return path, false, nil
   267  }