gopkg.in/hugelgupf/u-root.v4@v4.0.0-20180831060141-1d761fb73d50/cmds/cp/cp.go (about) 1 // Copyright 2016-2017 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Copy files. 6 // 7 // Synopsis: 8 // cp [-rRfivwP] FROM... TO 9 // 10 // Options: 11 // -w n: number of worker goroutines 12 // -R: copy file hierarchies 13 // -r: alias to -R recursive mode 14 // -i: prompt about overwriting file 15 // -f: force overwrite files 16 // -v: verbose copy mode 17 // -P: don't follow symlinks 18 package main 19 20 import ( 21 "bufio" 22 "flag" 23 "fmt" 24 "io" 25 "io/ioutil" 26 "log" 27 "os" 28 "path/filepath" 29 "runtime" 30 "strings" 31 ) 32 33 // buffSize is the length of buffer during 34 // the parallel copy using worker function 35 const buffSize = 8192 36 37 var ( 38 recursive bool 39 ask bool 40 force bool 41 verbose bool 42 symlink bool 43 nwork int 44 input = bufio.NewReader(os.Stdin) 45 // offchan is a channel used for indicate the nextbuffer to read with worker() 46 offchan = make(chan int64, 0) 47 // zerochan is a channel used for indicate the start of a new read file 48 zerochan = make(chan int, 0) 49 ) 50 51 func init() { 52 defUsage := flag.Usage 53 flag.Usage = func() { 54 os.Args[0] = "cp [-wRrifvP] file[s] ... dest" 55 defUsage() 56 } 57 flag.IntVar(&nwork, "w", runtime.NumCPU(), "number of worker goroutines") 58 flag.BoolVar(&recursive, "R", false, "copy file hierarchies") 59 flag.BoolVar(&recursive, "r", false, "alias to -R recursive mode") 60 flag.BoolVar(&ask, "i", false, "prompt about overwriting file") 61 flag.BoolVar(&force, "f", false, "force overwrite files") 62 flag.BoolVar(&verbose, "v", false, "verbose copy mode") 63 flag.BoolVar(&symlink, "P", false, "don't follow symlinks") 64 flag.Parse() 65 go nextOff() 66 } 67 68 // promptOverwrite ask if the user wants overwrite file 69 func promptOverwrite(dst string) (bool, error) { 70 fmt.Printf("cp: overwrite %q? ", dst) 71 answer, err := input.ReadString('\n') 72 if err != nil { 73 return false, err 74 } 75 76 if strings.ToLower(answer)[0] != 'y' { 77 return false, nil 78 } 79 80 return true, nil 81 } 82 83 // copyFile copies file between src (source) and dst (destination) 84 // todir: if true insert src INTO dir dst 85 func copyFile(src, dst string, todir bool) error { 86 if todir { 87 file := filepath.Base(src) 88 dst = filepath.Join(dst, file) 89 } 90 91 srcb, err := os.Lstat(src) 92 if err != nil { 93 return fmt.Errorf("can't stat %v: %v", src, err) 94 } 95 96 // don't follow symlinks, copy symlink 97 if L := os.ModeSymlink; symlink && srcb.Mode()&L == L { 98 linkPath, err := filepath.EvalSymlinks(src) 99 if err != nil { 100 return fmt.Errorf("can't eval symlink %v: %v", src, err) 101 } 102 return os.Symlink(linkPath, dst) 103 } 104 105 if srcb.IsDir() { 106 if recursive { 107 return copyDir(src, dst) 108 } 109 return fmt.Errorf("%q is a directory, try use recursive option", src) 110 } 111 112 dstb, err := os.Stat(dst) 113 if err != nil && !os.IsNotExist(err) { 114 return fmt.Errorf("%q: can't handle error %v", dst, err) 115 } 116 117 if dstb != nil { 118 if sameFile(srcb.Sys(), dstb.Sys()) { 119 return fmt.Errorf("%q and %q are the same file", src, dst) 120 } 121 if ask && !force { 122 overwrite, err := promptOverwrite(dst) 123 if err != nil { 124 return err 125 } 126 if !overwrite { 127 return nil 128 } 129 } 130 } 131 132 mode := srcb.Mode() & 0777 133 s, err := os.Open(src) 134 if err != nil { 135 return fmt.Errorf("can't open %q: %v", src, err) 136 } 137 defer s.Close() 138 139 d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) 140 if err != nil { 141 return fmt.Errorf("can't create %q: %v", dst, err) 142 } 143 defer d.Close() 144 145 return copyOneFile(s, d, src, dst) 146 } 147 148 // copyOneFile copy the content between two files 149 func copyOneFile(s *os.File, d *os.File, src, dst string) error { 150 zerochan <- 0 151 fail := make(chan error, nwork) 152 for i := 0; i < nwork; i++ { 153 go worker(s, d, fail) 154 } 155 156 // iterate the errors from channel 157 for i := 0; i < nwork; i++ { 158 err := <-fail 159 if err != nil { 160 return err 161 } 162 } 163 164 if verbose { 165 fmt.Printf("%q -> %q\n", src, dst) 166 } 167 168 return nil 169 } 170 171 // createDir populate dir destination if not exists 172 // if exists verify is not a dir: return error if is file 173 // cannot overwrite: dir -> file 174 func createDir(src, dst string) error { 175 dstInfo, err := os.Stat(dst) 176 if err != nil && !os.IsNotExist(err) { 177 return err 178 } 179 180 if err == nil { 181 if !dstInfo.IsDir() { 182 return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src) 183 } 184 return nil 185 } 186 187 srcInfo, err := os.Stat(src) 188 if err != nil { 189 return err 190 } 191 if err := os.Mkdir(dst, srcInfo.Mode()); err != nil { 192 return err 193 } 194 if verbose { 195 fmt.Printf("%q -> %q\n", src, dst) 196 } 197 198 return nil 199 } 200 201 // copyDir copy the file hierarchies 202 // used at cp when -r or -R flag is true 203 func copyDir(src, dst string) error { 204 if err := createDir(src, dst); err != nil { 205 return err 206 } 207 208 // list files from destination 209 files, err := ioutil.ReadDir(src) 210 if err != nil { 211 return fmt.Errorf("can't list files from %q: %q", src, err) 212 } 213 214 // copy recursively the src -> dst 215 for _, file := range files { 216 fname := file.Name() 217 fpath := filepath.Join(src, fname) 218 newDst := filepath.Join(dst, fname) 219 copyFile(fpath, newDst, false) 220 } 221 222 return err 223 } 224 225 // worker is a concurrent copy, used to copy part of the files 226 // in parallel 227 func worker(s *os.File, d *os.File, fail chan error) { 228 var buf [buffSize]byte 229 var bp []byte 230 231 l := len(buf) 232 bp = buf[0:] 233 o := <-offchan 234 for { 235 n, err := s.ReadAt(bp, o) 236 if err != nil && err != io.EOF { 237 fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err) 238 return 239 } 240 if n == 0 { 241 break 242 } 243 244 nb := bp[0:n] 245 n, err = d.WriteAt(nb, o) 246 if err != nil { 247 fail <- fmt.Errorf("writing %s: %v", d.Name(), err) 248 return 249 } 250 bp = buf[n:] 251 o += int64(n) 252 l -= n 253 if l == 0 { 254 l = len(buf) 255 bp = buf[0:] 256 o = <-offchan 257 } 258 } 259 fail <- nil 260 } 261 262 // nextOff handler for next buffers and sync goroutines 263 // zerochan imply the init of file 264 // offchan is the next buffer part to read 265 func nextOff() { 266 off := int64(0) 267 for { 268 select { 269 case <-zerochan: 270 off = 0 271 case offchan <- off: 272 off += buffSize 273 } 274 } 275 } 276 277 // cp is a function whose eval the args 278 // and make decisions for copyfiles 279 func cp(args []string) (lastErr error) { 280 todir := false 281 from, to := args[:len(args)-1], args[len(args)-1] 282 toStat, err := os.Stat(to) 283 if err == nil { 284 todir = toStat.IsDir() 285 } 286 if flag.NArg() > 2 && todir == false { 287 log.Fatalf("is not a directory: %s\n", to) 288 } 289 290 for _, file := range from { 291 if err := copyFile(file, to, todir); err != nil { 292 log.Printf("cp: %v\n", err) 293 lastErr = err 294 } 295 } 296 297 return err 298 } 299 300 func main() { 301 if flag.NArg() < 2 { 302 flag.Usage() 303 os.Exit(1) 304 } 305 306 if err := cp(flag.Args()); err != nil { 307 os.Exit(1) 308 } 309 310 }