github.com/oweisse/u-root@v0.0.0-20181109060735-d005ad25fef1/cmds/cp/cp.go (about) 1 // Copyright 2016-2017 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Copy files. 6 // 7 // Synopsis: 8 // cp [-rRfivwP] FROM... TO 9 // 10 // Options: 11 // -w n: number of worker goroutines 12 // -R: copy file hierarchies 13 // -r: alias to -R recursive mode 14 // -i: prompt about overwriting file 15 // -f: force overwrite files 16 // -v: verbose copy mode 17 // -P: don't follow symlinks 18 package main 19 20 import ( 21 "bufio" 22 "fmt" 23 "io" 24 "io/ioutil" 25 "log" 26 "os" 27 "path/filepath" 28 "runtime" 29 "strings" 30 31 flag "github.com/spf13/pflag" 32 ) 33 34 // buffSize is the length of buffer during 35 // the parallel copy using worker function 36 const buffSize = 8192 37 38 var ( 39 flags struct { 40 recursive bool 41 ask bool 42 force bool 43 verbose bool 44 symlink bool 45 nwork int 46 } 47 input = bufio.NewReader(os.Stdin) 48 // offchan is a channel used for indicate the nextbuffer to read with worker() 49 offchan = make(chan int64, 0) 50 // zerochan is a channel used for indicate the start of a new read file 51 zerochan = make(chan int, 0) 52 ) 53 54 func init() { 55 defUsage := flag.Usage 56 flag.Usage = func() { 57 os.Args[0] = "cp [-wRrifvP] file[s] ... dest" 58 defUsage() 59 } 60 flag.IntVarP(&flags.nwork, "workers", "w", runtime.NumCPU(), "number of worker goroutines") 61 flag.BoolVarP(&flags.recursive, "RECURSIVE", "R", false, "copy file hierarchies") 62 flag.BoolVarP(&flags.recursive, "recursive", "r", false, "alias to -R recursive mode") 63 flag.BoolVarP(&flags.ask, "interactive", "i", false, "prompt about overwriting file") 64 flag.BoolVarP(&flags.force, "force", "f", false, "force overwrite files") 65 flag.BoolVarP(&flags.verbose, "verbose", "v", false, "verbose copy mode") 66 flag.BoolVarP(&flags.symlink, "no-dereference", "P", false, "don't follow symlinks") 67 go nextOff() 68 } 69 70 // promptOverwrite ask if the user wants overwrite file 71 func promptOverwrite(dst string) (bool, error) { 72 fmt.Printf("cp: overwrite %q? ", dst) 73 answer, err := input.ReadString('\n') 74 if err != nil { 75 return false, err 76 } 77 78 if strings.ToLower(answer)[0] != 'y' { 79 return false, nil 80 } 81 82 return true, nil 83 } 84 85 // copyFile copies file between src (source) and dst (destination) 86 // todir: if true insert src INTO dir dst 87 func copyFile(src, dst string, todir bool) error { 88 if todir { 89 file := filepath.Base(src) 90 dst = filepath.Join(dst, file) 91 } 92 93 srcb, err := os.Lstat(src) 94 if err != nil { 95 return fmt.Errorf("can't stat %v: %v", src, err) 96 } 97 98 // don't follow symlinks, copy symlink 99 if L := os.ModeSymlink; flags.symlink && srcb.Mode()&L == L { 100 linkPath, err := filepath.EvalSymlinks(src) 101 if err != nil { 102 return fmt.Errorf("can't eval symlink %v: %v", src, err) 103 } 104 return os.Symlink(linkPath, dst) 105 } 106 107 if srcb.IsDir() { 108 if flags.recursive { 109 return copyDir(src, dst) 110 } 111 return fmt.Errorf("%q is a directory, try use recursive option", src) 112 } 113 114 dstb, err := os.Stat(dst) 115 if err != nil && !os.IsNotExist(err) { 116 return fmt.Errorf("%q: can't handle error %v", dst, err) 117 } 118 119 if dstb != nil { 120 if sameFile(srcb.Sys(), dstb.Sys()) { 121 return fmt.Errorf("%q and %q are the same file", src, dst) 122 } 123 if flags.ask && !flags.force { 124 overwrite, err := promptOverwrite(dst) 125 if err != nil { 126 return err 127 } 128 if !overwrite { 129 return nil 130 } 131 } 132 } 133 134 mode := srcb.Mode() & 0777 135 s, err := os.Open(src) 136 if err != nil { 137 return fmt.Errorf("can't open %q: %v", src, err) 138 } 139 defer s.Close() 140 141 d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) 142 if err != nil { 143 return fmt.Errorf("can't create %q: %v", dst, err) 144 } 145 defer d.Close() 146 147 return copyOneFile(s, d, src, dst) 148 } 149 150 // copyOneFile copy the content between two files 151 func copyOneFile(s *os.File, d *os.File, src, dst string) error { 152 zerochan <- 0 153 fail := make(chan error, flags.nwork) 154 for i := 0; i < flags.nwork; i++ { 155 go worker(s, d, fail) 156 } 157 158 // iterate the errors from channel 159 for i := 0; i < flags.nwork; i++ { 160 err := <-fail 161 if err != nil { 162 return err 163 } 164 } 165 166 if flags.verbose { 167 fmt.Printf("%q -> %q\n", src, dst) 168 } 169 170 return nil 171 } 172 173 // createDir populate dir destination if not exists 174 // if exists verify is not a dir: return error if is file 175 // cannot overwrite: dir -> file 176 func createDir(src, dst string) error { 177 dstInfo, err := os.Stat(dst) 178 if err != nil && !os.IsNotExist(err) { 179 return err 180 } 181 182 if err == nil { 183 if !dstInfo.IsDir() { 184 return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src) 185 } 186 return nil 187 } 188 189 srcInfo, err := os.Stat(src) 190 if err != nil { 191 return err 192 } 193 if err := os.Mkdir(dst, srcInfo.Mode()); err != nil { 194 return err 195 } 196 if flags.verbose { 197 fmt.Printf("%q -> %q\n", src, dst) 198 } 199 200 return nil 201 } 202 203 // copyDir copy the file hierarchies 204 // used at cp when -r or -R flag is true 205 func copyDir(src, dst string) error { 206 if err := createDir(src, dst); err != nil { 207 return err 208 } 209 210 // list files from destination 211 files, err := ioutil.ReadDir(src) 212 if err != nil { 213 return fmt.Errorf("can't list files from %q: %q", src, err) 214 } 215 216 // copy recursively the src -> dst 217 for _, file := range files { 218 fname := file.Name() 219 fpath := filepath.Join(src, fname) 220 newDst := filepath.Join(dst, fname) 221 copyFile(fpath, newDst, false) 222 } 223 224 return err 225 } 226 227 // worker is a concurrent copy, used to copy part of the files 228 // in parallel 229 func worker(s *os.File, d *os.File, fail chan error) { 230 var buf [buffSize]byte 231 var bp []byte 232 233 l := len(buf) 234 bp = buf[0:] 235 o := <-offchan 236 for { 237 n, err := s.ReadAt(bp, o) 238 if err != nil && err != io.EOF { 239 fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err) 240 return 241 } 242 if n == 0 { 243 break 244 } 245 246 nb := bp[0:n] 247 n, err = d.WriteAt(nb, o) 248 if err != nil { 249 fail <- fmt.Errorf("writing %s: %v", d.Name(), err) 250 return 251 } 252 bp = buf[n:] 253 o += int64(n) 254 l -= n 255 if l == 0 { 256 l = len(buf) 257 bp = buf[0:] 258 o = <-offchan 259 } 260 } 261 fail <- nil 262 } 263 264 // nextOff handler for next buffers and sync goroutines 265 // zerochan imply the init of file 266 // offchan is the next buffer part to read 267 func nextOff() { 268 off := int64(0) 269 for { 270 select { 271 case <-zerochan: 272 off = 0 273 case offchan <- off: 274 off += buffSize 275 } 276 } 277 } 278 279 // cp is a function whose eval the args 280 // and make decisions for copyfiles 281 func cp(args []string) (lastErr error) { 282 todir := false 283 from, to := args[:len(args)-1], args[len(args)-1] 284 toStat, err := os.Stat(to) 285 if err == nil { 286 todir = toStat.IsDir() 287 } 288 if flag.NArg() > 2 && todir == false { 289 log.Fatalf("is not a directory: %s\n", to) 290 } 291 292 for _, file := range from { 293 if err := copyFile(file, to, todir); err != nil { 294 log.Printf("cp: %v\n", err) 295 lastErr = err 296 } 297 } 298 299 return err 300 } 301 302 func main() { 303 flag.Parse() 304 if flag.NArg() < 2 { 305 flag.Usage() 306 os.Exit(1) 307 } 308 309 if err := cp(flag.Args()); err != nil { 310 os.Exit(1) 311 } 312 313 }