github.com/hugelgupf/u-root@v0.0.0-20191023214958-4807c632154c/cmds/core/cp/cp.go (about)

     1  // Copyright 2016-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Copy files.
     6  //
     7  // Synopsis:
//     cp [-rRfivP] FROM... TO
//
// Options:
    12  //     -R: copy file hierarchies
    13  //     -r: alias to -R recursive mode
    14  //     -i: prompt about overwriting file
    15  //     -f: force overwrite files
    16  //     -v: verbose copy mode
    17  //     -P: don't follow symlinks
    18  package main
    19  
import (
	"bufio"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"strings"

	flag "github.com/spf13/pflag"
	"github.com/u-root/u-root/pkg/cp"
)
    31  
    32  var (
    33  	flags struct {
    34  		recursive        bool
    35  		ask              bool
    36  		force            bool
    37  		verbose          bool
    38  		noFollowSymlinks bool
    39  	}
    40  	input = bufio.NewReader(os.Stdin)
    41  )
    42  
    43  func init() {
    44  	defUsage := flag.Usage
    45  	flag.Usage = func() {
    46  		os.Args[0] = "cp [-wRrifvP] file[s] ... dest"
    47  		defUsage()
    48  	}
    49  	flag.BoolVarP(&flags.recursive, "RECURSIVE", "R", false, "copy file hierarchies")
    50  	flag.BoolVarP(&flags.recursive, "recursive", "r", false, "alias to -R recursive mode")
    51  	flag.BoolVarP(&flags.ask, "interactive", "i", false, "prompt about overwriting file")
    52  	flag.BoolVarP(&flags.force, "force", "f", false, "force overwrite files")
    53  	flag.BoolVarP(&flags.verbose, "verbose", "v", false, "verbose copy mode")
    54  	flag.BoolVarP(&flags.noFollowSymlinks, "no-dereference", "P", false, "don't follow symlinks")
    55  }
    56  
    57  // promptOverwrite ask if the user wants overwrite file
    58  func promptOverwrite(dst string) (bool, error) {
    59  	fmt.Printf("cp: overwrite %q? ", dst)
    60  	answer, err := input.ReadString('\n')
    61  	if err != nil {
    62  		return false, err
    63  	}
    64  
    65  	if strings.ToLower(answer)[0] != 'y' {
    66  		return false, nil
    67  	}
    68  
    69  	return true, nil
    70  }
    71  
    72  // cpArgs is a function whose eval the args
    73  // and make decisions for copyfiles
    74  func cpArgs(args []string) error {
    75  	todir := false
    76  	from, to := args[:len(args)-1], args[len(args)-1]
    77  	toStat, err := os.Stat(to)
    78  	if err == nil {
    79  		todir = toStat.IsDir()
    80  	}
    81  	if flag.NArg() > 2 && !todir {
    82  		log.Fatalf("is not a directory: %s\n", to)
    83  	}
    84  
    85  	opts := cp.Options{
    86  		NoFollowSymlinks: flags.noFollowSymlinks,
    87  
    88  		// cp the command makes sure that
    89  		//
    90  		// (1) the files it's copying aren't already the same,
    91  		// (2) the user is asked about overwriting an existing file if
    92  		//     one is already there.
    93  		PreCallback: func(src, dst string, srcfi os.FileInfo) error {
    94  			// check if src is dir
    95  			if !flags.recursive && srcfi.IsDir() {
    96  				log.Printf("cp: -r not specified, omitting directory %s", src)
    97  				return cp.ErrSkip
    98  			}
    99  
   100  			dstfi, err := os.Stat(dst)
   101  			if err != nil && !os.IsNotExist(err) {
   102  				log.Printf("cp: %q: can't handle error %v", dst, err)
   103  				return cp.ErrSkip
   104  			} else if err != nil {
   105  				// dst does not exist.
   106  				return nil
   107  			}
   108  
   109  			// dst does exist.
   110  
   111  			if os.SameFile(srcfi, dstfi) {
   112  				log.Printf("cp: %q and %q are the same file", src, dst)
   113  				return cp.ErrSkip
   114  			}
   115  			if flags.ask && !flags.force {
   116  				overwrite, err := promptOverwrite(dst)
   117  				if err != nil {
   118  					return err
   119  				}
   120  				if !overwrite {
   121  					return cp.ErrSkip
   122  				}
   123  			}
   124  			return nil
   125  		},
   126  
   127  		PostCallback: func(src, dst string) {
   128  			if flags.verbose {
   129  				fmt.Printf("%q -> %q\n", src, dst)
   130  			}
   131  		},
   132  	}
   133  
   134  	var lastErr error
   135  	for _, file := range from {
   136  		dst := to
   137  		if todir {
   138  			dst = filepath.Join(dst, filepath.Base(file))
   139  		}
   140  		if flags.recursive {
   141  			err = opts.CopyTree(file, dst)
   142  		} else {
   143  			err = opts.Copy(file, dst)
   144  		}
   145  		if err != nil {
   146  			log.Printf("cp: %v\n", err)
   147  			lastErr = err
   148  		}
   149  	}
   150  	return lastErr
   151  }
   152  
   153  func main() {
   154  	flag.Parse()
   155  	if flag.NArg() < 2 {
   156  		flag.Usage()
   157  		os.Exit(1)
   158  	}
   159  
   160  	if err := cpArgs(flag.Args()); err != nil {
   161  		os.Exit(1)
   162  	}
   163  }