gopkg.in/hugelgupf/u-root.v4@v4.0.0-20180831060141-1d761fb73d50/cmds/cp/cp.go (about)

     1  // Copyright 2016-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Copy files.
     6  //
     7  // Synopsis:
     8  //     cp [-rRfivwP] FROM... TO
     9  //
    10  // Options:
    11  //     -w n: number of worker goroutines
    12  //     -R: copy file hierarchies
    13  //     -r: alias to -R recursive mode
    14  //     -i: prompt about overwriting file
    15  //     -f: force overwrite files
    16  //     -v: verbose copy mode
    17  //     -P: don't follow symlinks
    18  package main
    19  
    20  import (
    21  	"bufio"
    22  	"flag"
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"log"
    27  	"os"
    28  	"path/filepath"
    29  	"runtime"
    30  	"strings"
    31  )
    32  
    33  // buffSize is the length of buffer during
    34  // the parallel copy using worker function
    35  const buffSize = 8192
    36  
    37  var (
    38  	recursive bool
    39  	ask       bool
    40  	force     bool
    41  	verbose   bool
    42  	symlink   bool
    43  	nwork     int
    44  	input     = bufio.NewReader(os.Stdin)
    45  	// offchan is a channel used for indicate the nextbuffer to read with worker()
    46  	offchan = make(chan int64, 0)
    47  	// zerochan is a channel used for indicate the start of a new read file
    48  	zerochan = make(chan int, 0)
    49  )
    50  
    51  func init() {
    52  	defUsage := flag.Usage
    53  	flag.Usage = func() {
    54  		os.Args[0] = "cp [-wRrifvP] file[s] ... dest"
    55  		defUsage()
    56  	}
    57  	flag.IntVar(&nwork, "w", runtime.NumCPU(), "number of worker goroutines")
    58  	flag.BoolVar(&recursive, "R", false, "copy file hierarchies")
    59  	flag.BoolVar(&recursive, "r", false, "alias to -R recursive mode")
    60  	flag.BoolVar(&ask, "i", false, "prompt about overwriting file")
    61  	flag.BoolVar(&force, "f", false, "force overwrite files")
    62  	flag.BoolVar(&verbose, "v", false, "verbose copy mode")
    63  	flag.BoolVar(&symlink, "P", false, "don't follow symlinks")
    64  	flag.Parse()
    65  	go nextOff()
    66  }
    67  
    68  // promptOverwrite ask if the user wants overwrite file
    69  func promptOverwrite(dst string) (bool, error) {
    70  	fmt.Printf("cp: overwrite %q? ", dst)
    71  	answer, err := input.ReadString('\n')
    72  	if err != nil {
    73  		return false, err
    74  	}
    75  
    76  	if strings.ToLower(answer)[0] != 'y' {
    77  		return false, nil
    78  	}
    79  
    80  	return true, nil
    81  }
    82  
    83  // copyFile copies file between src (source) and dst (destination)
    84  // todir: if true insert src INTO dir dst
    85  func copyFile(src, dst string, todir bool) error {
    86  	if todir {
    87  		file := filepath.Base(src)
    88  		dst = filepath.Join(dst, file)
    89  	}
    90  
    91  	srcb, err := os.Lstat(src)
    92  	if err != nil {
    93  		return fmt.Errorf("can't stat %v: %v", src, err)
    94  	}
    95  
    96  	// don't follow symlinks, copy symlink
    97  	if L := os.ModeSymlink; symlink && srcb.Mode()&L == L {
    98  		linkPath, err := filepath.EvalSymlinks(src)
    99  		if err != nil {
   100  			return fmt.Errorf("can't eval symlink %v: %v", src, err)
   101  		}
   102  		return os.Symlink(linkPath, dst)
   103  	}
   104  
   105  	if srcb.IsDir() {
   106  		if recursive {
   107  			return copyDir(src, dst)
   108  		}
   109  		return fmt.Errorf("%q is a directory, try use recursive option", src)
   110  	}
   111  
   112  	dstb, err := os.Stat(dst)
   113  	if err != nil && !os.IsNotExist(err) {
   114  		return fmt.Errorf("%q: can't handle error %v", dst, err)
   115  	}
   116  
   117  	if dstb != nil {
   118  		if sameFile(srcb.Sys(), dstb.Sys()) {
   119  			return fmt.Errorf("%q and %q are the same file", src, dst)
   120  		}
   121  		if ask && !force {
   122  			overwrite, err := promptOverwrite(dst)
   123  			if err != nil {
   124  				return err
   125  			}
   126  			if !overwrite {
   127  				return nil
   128  			}
   129  		}
   130  	}
   131  
   132  	mode := srcb.Mode() & 0777
   133  	s, err := os.Open(src)
   134  	if err != nil {
   135  		return fmt.Errorf("can't open %q: %v", src, err)
   136  	}
   137  	defer s.Close()
   138  
   139  	d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
   140  	if err != nil {
   141  		return fmt.Errorf("can't create %q: %v", dst, err)
   142  	}
   143  	defer d.Close()
   144  
   145  	return copyOneFile(s, d, src, dst)
   146  }
   147  
   148  // copyOneFile copy the content between two files
   149  func copyOneFile(s *os.File, d *os.File, src, dst string) error {
   150  	zerochan <- 0
   151  	fail := make(chan error, nwork)
   152  	for i := 0; i < nwork; i++ {
   153  		go worker(s, d, fail)
   154  	}
   155  
   156  	// iterate the errors from channel
   157  	for i := 0; i < nwork; i++ {
   158  		err := <-fail
   159  		if err != nil {
   160  			return err
   161  		}
   162  	}
   163  
   164  	if verbose {
   165  		fmt.Printf("%q -> %q\n", src, dst)
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  // createDir populate dir destination if not exists
   172  // if exists verify is not a dir: return error if is file
   173  // cannot overwrite: dir -> file
   174  func createDir(src, dst string) error {
   175  	dstInfo, err := os.Stat(dst)
   176  	if err != nil && !os.IsNotExist(err) {
   177  		return err
   178  	}
   179  
   180  	if err == nil {
   181  		if !dstInfo.IsDir() {
   182  			return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src)
   183  		}
   184  		return nil
   185  	}
   186  
   187  	srcInfo, err := os.Stat(src)
   188  	if err != nil {
   189  		return err
   190  	}
   191  	if err := os.Mkdir(dst, srcInfo.Mode()); err != nil {
   192  		return err
   193  	}
   194  	if verbose {
   195  		fmt.Printf("%q -> %q\n", src, dst)
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  // copyDir copy the file hierarchies
   202  // used at cp when -r or -R flag is true
   203  func copyDir(src, dst string) error {
   204  	if err := createDir(src, dst); err != nil {
   205  		return err
   206  	}
   207  
   208  	// list files from destination
   209  	files, err := ioutil.ReadDir(src)
   210  	if err != nil {
   211  		return fmt.Errorf("can't list files from %q: %q", src, err)
   212  	}
   213  
   214  	// copy recursively the src -> dst
   215  	for _, file := range files {
   216  		fname := file.Name()
   217  		fpath := filepath.Join(src, fname)
   218  		newDst := filepath.Join(dst, fname)
   219  		copyFile(fpath, newDst, false)
   220  	}
   221  
   222  	return err
   223  }
   224  
   225  // worker is a concurrent copy, used to copy part of the files
   226  // in parallel
   227  func worker(s *os.File, d *os.File, fail chan error) {
   228  	var buf [buffSize]byte
   229  	var bp []byte
   230  
   231  	l := len(buf)
   232  	bp = buf[0:]
   233  	o := <-offchan
   234  	for {
   235  		n, err := s.ReadAt(bp, o)
   236  		if err != nil && err != io.EOF {
   237  			fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err)
   238  			return
   239  		}
   240  		if n == 0 {
   241  			break
   242  		}
   243  
   244  		nb := bp[0:n]
   245  		n, err = d.WriteAt(nb, o)
   246  		if err != nil {
   247  			fail <- fmt.Errorf("writing %s: %v", d.Name(), err)
   248  			return
   249  		}
   250  		bp = buf[n:]
   251  		o += int64(n)
   252  		l -= n
   253  		if l == 0 {
   254  			l = len(buf)
   255  			bp = buf[0:]
   256  			o = <-offchan
   257  		}
   258  	}
   259  	fail <- nil
   260  }
   261  
   262  // nextOff handler for next buffers and sync goroutines
   263  // zerochan imply the init of file
   264  // offchan is the next buffer part to read
   265  func nextOff() {
   266  	off := int64(0)
   267  	for {
   268  		select {
   269  		case <-zerochan:
   270  			off = 0
   271  		case offchan <- off:
   272  			off += buffSize
   273  		}
   274  	}
   275  }
   276  
   277  // cp is a function whose eval the args
   278  // and make decisions for copyfiles
   279  func cp(args []string) (lastErr error) {
   280  	todir := false
   281  	from, to := args[:len(args)-1], args[len(args)-1]
   282  	toStat, err := os.Stat(to)
   283  	if err == nil {
   284  		todir = toStat.IsDir()
   285  	}
   286  	if flag.NArg() > 2 && todir == false {
   287  		log.Fatalf("is not a directory: %s\n", to)
   288  	}
   289  
   290  	for _, file := range from {
   291  		if err := copyFile(file, to, todir); err != nil {
   292  			log.Printf("cp: %v\n", err)
   293  			lastErr = err
   294  		}
   295  	}
   296  
   297  	return err
   298  }
   299  
   300  func main() {
   301  	if flag.NArg() < 2 {
   302  		flag.Usage()
   303  		os.Exit(1)
   304  	}
   305  
   306  	if err := cp(flag.Args()); err != nil {
   307  		os.Exit(1)
   308  	}
   309  
   310  }