github.com/shaardie/u-root@v4.0.1-0.20190127173353-f24a1c26aa2e+incompatible/cmds/cp/cp.go (about)

     1  // Copyright 2016-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Copy files.
     6  //
     7  // Synopsis:
     8  //     cp [-rRfivwP] FROM... TO
     9  //
    10  // Options:
    11  //     -w n: number of worker goroutines
    12  //     -R: copy file hierarchies
    13  //     -r: alias to -R recursive mode
    14  //     -i: prompt about overwriting file
    15  //     -f: force overwrite files
    16  //     -v: verbose copy mode
    17  //     -P: don't follow symlinks
    18  package main
    19  
    20  import (
    21  	"bufio"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"log"
    26  	"os"
    27  	"path/filepath"
    28  	"runtime"
    29  	"strings"
    30  
    31  	flag "github.com/spf13/pflag"
    32  )
    33  
    34  // buffSize is the length of buffer during
    35  // the parallel copy using worker function
    36  const buffSize = 8192
    37  
    38  var (
    39  	flags struct {
    40  		recursive bool
    41  		ask       bool
    42  		force     bool
    43  		verbose   bool
    44  		symlink   bool
    45  		nwork     int
    46  	}
    47  	input = bufio.NewReader(os.Stdin)
    48  	// offchan is a channel used for indicate the nextbuffer to read with worker()
    49  	offchan = make(chan int64, 0)
    50  	// zerochan is a channel used for indicate the start of a new read file
    51  	zerochan = make(chan int, 0)
    52  )
    53  
    54  func init() {
    55  	defUsage := flag.Usage
    56  	flag.Usage = func() {
    57  		os.Args[0] = "cp [-wRrifvP] file[s] ... dest"
    58  		defUsage()
    59  	}
    60  	flag.IntVarP(&flags.nwork, "workers", "w", runtime.NumCPU(), "number of worker goroutines")
    61  	flag.BoolVarP(&flags.recursive, "RECURSIVE", "R", false, "copy file hierarchies")
    62  	flag.BoolVarP(&flags.recursive, "recursive", "r", false, "alias to -R recursive mode")
    63  	flag.BoolVarP(&flags.ask, "interactive", "i", false, "prompt about overwriting file")
    64  	flag.BoolVarP(&flags.force, "force", "f", false, "force overwrite files")
    65  	flag.BoolVarP(&flags.verbose, "verbose", "v", false, "verbose copy mode")
    66  	flag.BoolVarP(&flags.symlink, "no-dereference", "P", false, "don't follow symlinks")
    67  	go nextOff()
    68  }
    69  
    70  // promptOverwrite ask if the user wants overwrite file
    71  func promptOverwrite(dst string) (bool, error) {
    72  	fmt.Printf("cp: overwrite %q? ", dst)
    73  	answer, err := input.ReadString('\n')
    74  	if err != nil {
    75  		return false, err
    76  	}
    77  
    78  	if strings.ToLower(answer)[0] != 'y' {
    79  		return false, nil
    80  	}
    81  
    82  	return true, nil
    83  }
    84  
    85  // copyFile copies file between src (source) and dst (destination)
    86  // todir: if true insert src INTO dir dst
    87  func copyFile(src, dst string, todir bool) error {
    88  	if todir {
    89  		file := filepath.Base(src)
    90  		dst = filepath.Join(dst, file)
    91  	}
    92  
    93  	srcb, err := os.Lstat(src)
    94  	if err != nil {
    95  		return fmt.Errorf("can't stat %v: %v", src, err)
    96  	}
    97  
    98  	// don't follow symlinks, copy symlink
    99  	if L := os.ModeSymlink; flags.symlink && srcb.Mode()&L == L {
   100  		linkPath, err := filepath.EvalSymlinks(src)
   101  		if err != nil {
   102  			return fmt.Errorf("can't eval symlink %v: %v", src, err)
   103  		}
   104  		return os.Symlink(linkPath, dst)
   105  	}
   106  
   107  	if srcb.IsDir() {
   108  		if flags.recursive {
   109  			return copyDir(src, dst)
   110  		}
   111  		return fmt.Errorf("%q is a directory, try use recursive option", src)
   112  	}
   113  
   114  	dstb, err := os.Stat(dst)
   115  	if err != nil && !os.IsNotExist(err) {
   116  		return fmt.Errorf("%q: can't handle error %v", dst, err)
   117  	}
   118  
   119  	if dstb != nil {
   120  		if sameFile(srcb.Sys(), dstb.Sys()) {
   121  			return fmt.Errorf("%q and %q are the same file", src, dst)
   122  		}
   123  		if flags.ask && !flags.force {
   124  			overwrite, err := promptOverwrite(dst)
   125  			if err != nil {
   126  				return err
   127  			}
   128  			if !overwrite {
   129  				return nil
   130  			}
   131  		}
   132  	}
   133  
   134  	mode := srcb.Mode() & 0777
   135  	s, err := os.Open(src)
   136  	if err != nil {
   137  		return fmt.Errorf("can't open %q: %v", src, err)
   138  	}
   139  	defer s.Close()
   140  
   141  	d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
   142  	if err != nil {
   143  		return fmt.Errorf("can't create %q: %v", dst, err)
   144  	}
   145  	defer d.Close()
   146  
   147  	return copyOneFile(s, d, src, dst)
   148  }
   149  
   150  // copyOneFile copy the content between two files
   151  func copyOneFile(s *os.File, d *os.File, src, dst string) error {
   152  	zerochan <- 0
   153  	fail := make(chan error, flags.nwork)
   154  	for i := 0; i < flags.nwork; i++ {
   155  		go worker(s, d, fail)
   156  	}
   157  
   158  	// iterate the errors from channel
   159  	for i := 0; i < flags.nwork; i++ {
   160  		err := <-fail
   161  		if err != nil {
   162  			return err
   163  		}
   164  	}
   165  
   166  	if flags.verbose {
   167  		fmt.Printf("%q -> %q\n", src, dst)
   168  	}
   169  
   170  	return nil
   171  }
   172  
   173  // createDir populate dir destination if not exists
   174  // if exists verify is not a dir: return error if is file
   175  // cannot overwrite: dir -> file
   176  func createDir(src, dst string) error {
   177  	dstInfo, err := os.Stat(dst)
   178  	if err != nil && !os.IsNotExist(err) {
   179  		return err
   180  	}
   181  
   182  	if err == nil {
   183  		if !dstInfo.IsDir() {
   184  			return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src)
   185  		}
   186  		return nil
   187  	}
   188  
   189  	srcInfo, err := os.Stat(src)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	if err := os.Mkdir(dst, srcInfo.Mode()); err != nil {
   194  		return err
   195  	}
   196  	if flags.verbose {
   197  		fmt.Printf("%q -> %q\n", src, dst)
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  // copyDir copy the file hierarchies
   204  // used at cp when -r or -R flag is true
   205  func copyDir(src, dst string) error {
   206  	if err := createDir(src, dst); err != nil {
   207  		return err
   208  	}
   209  
   210  	// list files from destination
   211  	files, err := ioutil.ReadDir(src)
   212  	if err != nil {
   213  		return fmt.Errorf("can't list files from %q: %q", src, err)
   214  	}
   215  
   216  	// copy recursively the src -> dst
   217  	for _, file := range files {
   218  		fname := file.Name()
   219  		fpath := filepath.Join(src, fname)
   220  		newDst := filepath.Join(dst, fname)
   221  		copyFile(fpath, newDst, false)
   222  	}
   223  
   224  	return err
   225  }
   226  
   227  // worker is a concurrent copy, used to copy part of the files
   228  // in parallel
   229  func worker(s *os.File, d *os.File, fail chan error) {
   230  	var buf [buffSize]byte
   231  	var bp []byte
   232  
   233  	l := len(buf)
   234  	bp = buf[0:]
   235  	o := <-offchan
   236  	for {
   237  		n, err := s.ReadAt(bp, o)
   238  		if err != nil && err != io.EOF {
   239  			fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err)
   240  			return
   241  		}
   242  		if n == 0 {
   243  			break
   244  		}
   245  
   246  		nb := bp[0:n]
   247  		n, err = d.WriteAt(nb, o)
   248  		if err != nil {
   249  			fail <- fmt.Errorf("writing %s: %v", d.Name(), err)
   250  			return
   251  		}
   252  		bp = buf[n:]
   253  		o += int64(n)
   254  		l -= n
   255  		if l == 0 {
   256  			l = len(buf)
   257  			bp = buf[0:]
   258  			o = <-offchan
   259  		}
   260  	}
   261  	fail <- nil
   262  }
   263  
   264  // nextOff handler for next buffers and sync goroutines
   265  // zerochan imply the init of file
   266  // offchan is the next buffer part to read
   267  func nextOff() {
   268  	off := int64(0)
   269  	for {
   270  		select {
   271  		case <-zerochan:
   272  			off = 0
   273  		case offchan <- off:
   274  			off += buffSize
   275  		}
   276  	}
   277  }
   278  
   279  // cp is a function whose eval the args
   280  // and make decisions for copyfiles
   281  func cp(args []string) (lastErr error) {
   282  	todir := false
   283  	from, to := args[:len(args)-1], args[len(args)-1]
   284  	toStat, err := os.Stat(to)
   285  	if err == nil {
   286  		todir = toStat.IsDir()
   287  	}
   288  	if flag.NArg() > 2 && todir == false {
   289  		log.Fatalf("is not a directory: %s\n", to)
   290  	}
   291  
   292  	for _, file := range from {
   293  		if err := copyFile(file, to, todir); err != nil {
   294  			log.Printf("cp: %v\n", err)
   295  			lastErr = err
   296  		}
   297  	}
   298  
   299  	return err
   300  }
   301  
   302  func main() {
   303  	flag.Parse()
   304  	if flag.NArg() < 2 {
   305  		flag.Usage()
   306  		os.Exit(1)
   307  	}
   308  
   309  	if err := cp(flag.Args()); err != nil {
   310  		os.Exit(1)
   311  	}
   312  
   313  }