github.com/ratrocket/u-root@v0.0.0-20180201221235-1cf9f48ee2cf/cmds/cp/cp.go (about)

     1  // Copyright 2016-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Copy files.
     6  //
     7  // Synopsis:
     8  //     cp [-rRfivwP] FROM... TO
     9  //
    10  // Options:
    11  //     -w n: number of worker goroutines
    12  //     -R: copy file hierarchies
    13  //     -r: alias to -R recursive mode
    14  //     -i: prompt about overwriting file
    15  //     -f: force overwrite files
    16  //     -v: verbose copy mode
    17  //     -P: don't follow symlinks
    18  package main
    19  
    20  import (
    21  	"bufio"
    22  	"flag"
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"log"
    27  	"os"
    28  	"path/filepath"
    29  	"runtime"
    30  	"strings"
    31  )
    32  
    33  // buffSize is the length of buffer during
    34  // the parallel copy using worker function
    35  const buffSize = 8192
    36  
    37  var (
    38  	recursive bool
    39  	ask       bool
    40  	force     bool
    41  	verbose   bool
    42  	symlink   bool
    43  	nwork     int
    44  	input     = bufio.NewReader(os.Stdin)
    45  	// offchan is a channel used for indicate the nextbuffer to read with worker()
    46  	offchan = make(chan int64, 0)
    47  	// zerochan is a channel used for indicate the start of a new read file
    48  	zerochan = make(chan int, 0)
    49  )
    50  
    51  func init() {
    52  	defUsage := flag.Usage
    53  	flag.Usage = func() {
    54  		os.Args[0] = "cp [-wRrifvP] file[s] ... dest"
    55  		defUsage()
    56  	}
    57  	flag.IntVar(&nwork, "w", runtime.NumCPU(), "number of worker goroutines")
    58  	flag.BoolVar(&recursive, "R", false, "copy file hierarchies")
    59  	flag.BoolVar(&recursive, "r", false, "alias to -R recursive mode")
    60  	flag.BoolVar(&ask, "i", false, "prompt about overwriting file")
    61  	flag.BoolVar(&force, "f", false, "force overwrite files")
    62  	flag.BoolVar(&verbose, "v", false, "verbose copy mode")
    63  	flag.BoolVar(&symlink, "P", false, "don't follow symlinks")
    64  	flag.Parse()
    65  	go nextOff()
    66  }
    67  
    68  // promptOverwrite ask if the user wants overwrite file
    69  func promptOverwrite(dst string) (bool, error) {
    70  	fmt.Printf("cp: overwrite %q? ", dst)
    71  	answer, err := input.ReadString('\n')
    72  	if err != nil {
    73  		return false, err
    74  	}
    75  
    76  	if strings.ToLower(answer)[0] != 'y' {
    77  		return false, nil
    78  	}
    79  
    80  	return true, nil
    81  }
    82  
    83  // copyFile copies file between src (source) and dst (destination)
    84  // todir: if true insert src INTO dir dst
    85  func copyFile(src, dst string, todir bool) error {
    86  	if todir {
    87  		file := filepath.Base(src)
    88  		dst = filepath.Join(dst, file)
    89  	}
    90  
    91  	srcb, err := os.Lstat(src)
    92  	if err != nil {
    93  		return fmt.Errorf("can't stat %v: %v", src, err)
    94  	}
    95  
    96  	// don't follow symlinks, copy symlink
    97  	if L := os.ModeSymlink; symlink && srcb.Mode()&L == L {
    98  		linkPath, err := filepath.EvalSymlinks(src)
    99  		if err != nil {
   100  			return fmt.Errorf("can't eval symlink %v: %v", src, err)
   101  		}
   102  		return os.Symlink(linkPath, dst)
   103  	}
   104  
   105  	if srcb.IsDir() {
   106  		if recursive {
   107  			return copyDir(src, dst)
   108  		}
   109  		return fmt.Errorf("%q is a directory, try use recursive option", src)
   110  	}
   111  
   112  	dstb, err := os.Stat(dst)
   113  	if !os.IsNotExist(err) {
   114  		if sameFile(srcb.Sys(), dstb.Sys()) {
   115  			return fmt.Errorf("%q and %q are the same file", src, dst)
   116  		}
   117  		if ask && !force {
   118  			overwrite, err := promptOverwrite(dst)
   119  			if err != nil {
   120  				return err
   121  			}
   122  			if !overwrite {
   123  				return nil
   124  			}
   125  		}
   126  	}
   127  
   128  	mode := srcb.Mode() & 0777
   129  	s, err := os.Open(src)
   130  	if err != nil {
   131  		return fmt.Errorf("can't open %q: %v", src, err)
   132  	}
   133  	defer s.Close()
   134  
   135  	d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
   136  	if err != nil {
   137  		return fmt.Errorf("can't create %q: %v", dst, err)
   138  	}
   139  	defer d.Close()
   140  
   141  	return copyOneFile(s, d, src, dst)
   142  }
   143  
   144  // copyOneFile copy the content between two files
   145  func copyOneFile(s *os.File, d *os.File, src, dst string) error {
   146  	zerochan <- 0
   147  	fail := make(chan error, nwork)
   148  	for i := 0; i < nwork; i++ {
   149  		go worker(s, d, fail)
   150  	}
   151  
   152  	// iterate the errors from channel
   153  	for i := 0; i < nwork; i++ {
   154  		err := <-fail
   155  		if err != nil {
   156  			return err
   157  		}
   158  	}
   159  
   160  	if verbose {
   161  		fmt.Printf("%q -> %q\n", src, dst)
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  // createDir populate dir destination if not exists
   168  // if exists verify is not a dir: return error if is file
   169  // cannot overwrite: dir -> file
   170  func createDir(src, dst string) error {
   171  	dstInfo, err := os.Stat(dst)
   172  	if err != nil && !os.IsNotExist(err) {
   173  		return err
   174  	}
   175  
   176  	if err == nil {
   177  		if !dstInfo.IsDir() {
   178  			return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src)
   179  		}
   180  		return nil
   181  	}
   182  
   183  	srcInfo, err := os.Stat(src)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	if err := os.Mkdir(dst, srcInfo.Mode()); err != nil {
   188  		return err
   189  	}
   190  	if verbose {
   191  		fmt.Printf("%q -> %q\n", src, dst)
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  // copyDir copy the file hierarchies
   198  // used at cp when -r or -R flag is true
   199  func copyDir(src, dst string) error {
   200  	if err := createDir(src, dst); err != nil {
   201  		return err
   202  	}
   203  
   204  	// list files from destination
   205  	files, err := ioutil.ReadDir(src)
   206  	if err != nil {
   207  		return fmt.Errorf("can't list files from %q: %q", src, err)
   208  	}
   209  
   210  	// copy recursively the src -> dst
   211  	for _, file := range files {
   212  		fname := file.Name()
   213  		fpath := filepath.Join(src, fname)
   214  		newDst := filepath.Join(dst, fname)
   215  		copyFile(fpath, newDst, false)
   216  	}
   217  
   218  	return err
   219  }
   220  
   221  // worker is a concurrent copy, used to copy part of the files
   222  // in parallel
   223  func worker(s *os.File, d *os.File, fail chan error) {
   224  	var buf [buffSize]byte
   225  	var bp []byte
   226  
   227  	l := len(buf)
   228  	bp = buf[0:]
   229  	o := <-offchan
   230  	for {
   231  		n, err := s.ReadAt(bp, o)
   232  		if err != nil && err != io.EOF {
   233  			fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err)
   234  			return
   235  		}
   236  		if n == 0 {
   237  			break
   238  		}
   239  
   240  		nb := bp[0:n]
   241  		n, err = d.WriteAt(nb, o)
   242  		if err != nil {
   243  			fail <- fmt.Errorf("writing %s: %v", d.Name(), err)
   244  			return
   245  		}
   246  		bp = buf[n:]
   247  		o += int64(n)
   248  		l -= n
   249  		if l == 0 {
   250  			l = len(buf)
   251  			bp = buf[0:]
   252  			o = <-offchan
   253  		}
   254  	}
   255  	fail <- nil
   256  }
   257  
   258  // nextOff handler for next buffers and sync goroutines
   259  // zerochan imply the init of file
   260  // offchan is the next buffer part to read
   261  func nextOff() {
   262  	off := int64(0)
   263  	for {
   264  		select {
   265  		case <-zerochan:
   266  			off = 0
   267  		case offchan <- off:
   268  			off += buffSize
   269  		}
   270  	}
   271  }
   272  
   273  // cp is a function whose eval the args
   274  // and make decisions for copyfiles
   275  func cp(args []string) (lastErr error) {
   276  	todir := false
   277  	from, to := args[:len(args)-1], args[len(args)-1]
   278  	toStat, err := os.Stat(to)
   279  	if err == nil {
   280  		todir = toStat.IsDir()
   281  	}
   282  	if flag.NArg() > 2 && todir == false {
   283  		log.Fatalf("is not a directory: %s\n", to)
   284  	}
   285  
   286  	for _, file := range from {
   287  		if err := copyFile(file, to, todir); err != nil {
   288  			log.Printf("cp: %v\n", err)
   289  			lastErr = err
   290  		}
   291  	}
   292  
   293  	return err
   294  }
   295  
   296  func main() {
   297  	if flag.NArg() < 2 {
   298  		flag.Usage()
   299  		os.Exit(1)
   300  	}
   301  
   302  	if err := cp(flag.Args()); err != nil {
   303  		os.Exit(1)
   304  	}
   305  
   306  }