github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/cp/cmd.go (about)

     1  // Copyright 2016-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // cp copies files.
     6  //
     7  // Synopsis:
     8  //
//	cp [-rRfivP] FROM... TO
//
// Options:
//
//	-R: copy file hierarchies
//	-r: alias to -R recursive mode
//	-i: prompt before overwriting an existing file
//	-f: never prompt before overwriting (overrides -i)
//	-v: verbose mode: print each source and destination as it is copied
//	-P: don't follow symlinks
    20  package cp
    21  
    22  import (
    23  	"bufio"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"io/fs"
    28  	"os"
    29  	"path/filepath"
    30  	"strings"
    31  
    32  	flag "github.com/spf13/pflag"
    33  )
    34  
    35  type flags struct {
    36  	recursive        bool
    37  	ask              bool
    38  	force            bool
    39  	verbose          bool
    40  	noFollowSymlinks bool
    41  }
    42  
    43  // promptOverwrite ask if the user wants overwrite file
    44  func promptOverwrite(dst string, out io.Writer, in *bufio.Reader) (bool, error) {
    45  	fmt.Fprintf(out, "cp: overwrite %q? ", dst)
    46  	answer, err := in.ReadString('\n')
    47  	if err != nil {
    48  		return false, err
    49  	}
    50  
    51  	if strings.ToLower(answer)[0] != 'y' {
    52  		return false, nil
    53  	}
    54  
    55  	return true, nil
    56  }
    57  
    58  func setupPreCallback(recursive, ask, force bool, writer io.Writer, reader bufio.Reader) func(string, string, os.FileInfo) error {
    59  	return func(src, dst string, srcfi os.FileInfo) error {
    60  		// check if src is dir
    61  		if !recursive && srcfi.IsDir() {
    62  			fmt.Fprintf(writer, "cp: -r not specified, omitting directory %s\n", src)
    63  			return ErrSkip
    64  		}
    65  
    66  		dstfi, err := os.Stat(dst)
    67  		if err != nil && !os.IsNotExist(err) {
    68  			fmt.Fprintf(writer, "cp: %q: can't handle error %v\n", dst, err)
    69  			return ErrSkip
    70  		} else if err != nil {
    71  			// dst does not exist.
    72  			return nil
    73  		}
    74  
    75  		// dst does exist.
    76  
    77  		if os.SameFile(srcfi, dstfi) {
    78  			fmt.Fprintf(writer, "cp: %q and %q are the same file\n", src, dst)
    79  			return ErrSkip
    80  		}
    81  		if ask && !force {
    82  			overwrite, err := promptOverwrite(dst, writer, &reader)
    83  			if err != nil {
    84  				return err
    85  			}
    86  			if !overwrite {
    87  				return ErrSkip
    88  			}
    89  		}
    90  		return nil
    91  	}
    92  }
    93  
    94  func setupPostCallback(verbose bool, w io.Writer) func(src, dst string) {
    95  	return func(src, dst string) {
    96  		if verbose {
    97  			fmt.Fprintf(w, "%q -> %q\n", src, dst)
    98  		}
    99  	}
   100  }
   101  
   102  // run evaluates the args and makes decisions for copyfiles
   103  func run(params RunParams, args []string, f flags) error {
   104  	todir := false
   105  	from, to := args[:len(args)-1], args[len(args)-1]
   106  	toStat, err := os.Stat(to)
   107  	if err == nil {
   108  		todir = toStat.IsDir()
   109  	}
   110  	if len(args) > 2 && !todir {
   111  		return eNotDir
   112  	}
   113  
   114  	i := bufio.NewReader(params.Stdin)
   115  	w := params.Stderr
   116  
   117  	opts := Options{
   118  		NoFollowSymlinks: f.noFollowSymlinks,
   119  
   120  		// cp the command makes sure that
   121  		//
   122  		// (1) the files it's copying aren't already the same,
   123  		// (2) the user is asked about overwriting an existing file if
   124  		//     one is already there.
   125  		PreCallback: setupPreCallback(f.recursive, f.ask, f.force, w, *i),
   126  
   127  		PostCallback: setupPostCallback(f.verbose, w),
   128  	}
   129  
   130  	var lastErr error
   131  	for _, file := range from {
   132  		dst := to
   133  		if todir {
   134  			dst = filepath.Join(dst, filepath.Base(file))
   135  		}
   136  		absFile := params.MkAbs(file)
   137  		absDst := params.MkAbs(dst)
   138  		if f.recursive {
   139  			lastErr = opts.CopyTree(absFile, absDst)
   140  		} else {
   141  			lastErr = opts.Copy(absFile, absDst)
   142  		}
   143  		var pathError *fs.PathError
   144  		// Use the original path in errors.
   145  		if errors.As(lastErr, &pathError) {
   146  			switch pathError.Path {
   147  			case absFile:
   148  				pathError.Path = file
   149  			case absDst:
   150  				pathError.Path = dst
   151  			}
   152  		}
   153  	}
   154  	return lastErr
   155  }
   156  
   157  type RunParams struct {
   158  	Dir string
   159  	Env []string
   160  
   161  	Stdin          io.Reader
   162  	Stdout, Stderr io.Writer
   163  }
   164  
   165  func (p RunParams) MkAbs(name string) string {
   166  	if filepath.IsAbs(name) {
   167  		return name
   168  	}
   169  	return filepath.Join(p.Dir, name)
   170  }
   171  
   172  func RunMain(params RunParams, args ...string) (exit int) {
   173  	flagSet := flag.NewFlagSet("cp", flag.ContinueOnError)
   174  	flagSet.Usage = func() {
   175  		fmt.Fprintln(params.Stderr, "Usage: cp [-wRrifvP] file[s] ... dest")
   176  		flagSet.PrintDefaults()
   177  	}
   178  	var f flags
   179  	flagSet.BoolVarP(&f.recursive, "RECURSIVE", "R", false, "copy file hierarchies")
   180  	flagSet.BoolVarP(&f.recursive, "recursive", "r", false, "alias to -R recursive mode")
   181  	flagSet.BoolVarP(&f.ask, "interactive", "i", false, "prompt about overwriting file")
   182  	flagSet.BoolVarP(&f.force, "force", "f", false, "force overwrite files")
   183  	flagSet.BoolVarP(&f.verbose, "verbose", "v", false, "verbose copy mode")
   184  	flagSet.BoolVarP(&f.noFollowSymlinks, "no-dereference", "P", false, "don't follow symlinks")
   185  
   186  	if err := flagSet.Parse(args); err != nil {
   187  		if err == flag.ErrHelp {
   188  			return 0
   189  		}
   190  		return 2
   191  	}
   192  	if flagSet.NArg() < 2 {
   193  		// TODO: print usage to the stderr parameter
   194  		flagSet.Usage()
   195  		return 0
   196  	}
   197  
   198  	if err := run(params, flagSet.Args(), f); err != nil {
   199  		fmt.Fprintln(params.Stderr, err)
   200  		return 1
   201  	}
   202  	return 0
   203  }