github.com/ratrocket/u-root@v0.0.0-20180201221235-1cf9f48ee2cf/cmds/cp/cp.go (about) 1 // Copyright 2016-2017 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Copy files. 6 // 7 // Synopsis: 8 // cp [-rRfivwP] FROM... TO 9 // 10 // Options: 11 // -w n: number of worker goroutines 12 // -R: copy file hierarchies 13 // -r: alias to -R recursive mode 14 // -i: prompt about overwriting file 15 // -f: force overwrite files 16 // -v: verbose copy mode 17 // -P: don't follow symlinks 18 package main 19 20 import ( 21 "bufio" 22 "flag" 23 "fmt" 24 "io" 25 "io/ioutil" 26 "log" 27 "os" 28 "path/filepath" 29 "runtime" 30 "strings" 31 ) 32 33 // buffSize is the length of buffer during 34 // the parallel copy using worker function 35 const buffSize = 8192 36 37 var ( 38 recursive bool 39 ask bool 40 force bool 41 verbose bool 42 symlink bool 43 nwork int 44 input = bufio.NewReader(os.Stdin) 45 // offchan is a channel used for indicate the nextbuffer to read with worker() 46 offchan = make(chan int64, 0) 47 // zerochan is a channel used for indicate the start of a new read file 48 zerochan = make(chan int, 0) 49 ) 50 51 func init() { 52 defUsage := flag.Usage 53 flag.Usage = func() { 54 os.Args[0] = "cp [-wRrifvP] file[s] ... dest" 55 defUsage() 56 } 57 flag.IntVar(&nwork, "w", runtime.NumCPU(), "number of worker goroutines") 58 flag.BoolVar(&recursive, "R", false, "copy file hierarchies") 59 flag.BoolVar(&recursive, "r", false, "alias to -R recursive mode") 60 flag.BoolVar(&ask, "i", false, "prompt about overwriting file") 61 flag.BoolVar(&force, "f", false, "force overwrite files") 62 flag.BoolVar(&verbose, "v", false, "verbose copy mode") 63 flag.BoolVar(&symlink, "P", false, "don't follow symlinks") 64 flag.Parse() 65 go nextOff() 66 } 67 68 // promptOverwrite ask if the user wants overwrite file 69 func promptOverwrite(dst string) (bool, error) { 70 fmt.Printf("cp: overwrite %q? ", dst) 71 answer, err := input.ReadString('\n') 72 if err != nil { 73 return false, err 74 } 75 76 if strings.ToLower(answer)[0] != 'y' { 77 return false, nil 78 } 79 80 return true, nil 81 } 82 83 // copyFile copies file between src (source) and dst (destination) 84 // todir: if true insert src INTO dir dst 85 func copyFile(src, dst string, todir bool) error { 86 if todir { 87 file := filepath.Base(src) 88 dst = filepath.Join(dst, file) 89 } 90 91 srcb, err := os.Lstat(src) 92 if err != nil { 93 return fmt.Errorf("can't stat %v: %v", src, err) 94 } 95 96 // don't follow symlinks, copy symlink 97 if L := os.ModeSymlink; symlink && srcb.Mode()&L == L { 98 linkPath, err := filepath.EvalSymlinks(src) 99 if err != nil { 100 return fmt.Errorf("can't eval symlink %v: %v", src, err) 101 } 102 return os.Symlink(linkPath, dst) 103 } 104 105 if srcb.IsDir() { 106 if recursive { 107 return copyDir(src, dst) 108 } 109 return fmt.Errorf("%q is a directory, try use recursive option", src) 110 } 111 112 dstb, err := os.Stat(dst) 113 if !os.IsNotExist(err) { 114 if sameFile(srcb.Sys(), dstb.Sys()) { 115 return fmt.Errorf("%q and %q are the same file", src, dst) 116 } 117 if ask && !force { 118 overwrite, err := promptOverwrite(dst) 119 if err != nil { 120 return err 121 } 122 if !overwrite { 123 return nil 124 } 125 } 126 } 127 128 mode := srcb.Mode() & 0777 129 s, err := os.Open(src) 130 if err != nil { 131 return fmt.Errorf("can't open %q: %v", src, err) 132 } 133 defer s.Close() 134 135 d, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) 136 if err != nil { 137 return fmt.Errorf("can't create %q: %v", dst, err) 138 } 139 defer d.Close() 140 141 return copyOneFile(s, d, src, dst) 142 } 143 144 // copyOneFile copy the content between two files 145 func copyOneFile(s *os.File, d *os.File, src, dst string) error { 146 zerochan <- 0 147 fail := make(chan error, nwork) 148 for i := 0; i < nwork; i++ { 149 go worker(s, d, fail) 150 } 151 152 // iterate the errors from channel 153 for i := 0; i < nwork; i++ { 154 err := <-fail 155 if err != nil { 156 return err 157 } 158 } 159 160 if verbose { 161 fmt.Printf("%q -> %q\n", src, dst) 162 } 163 164 return nil 165 } 166 167 // createDir populate dir destination if not exists 168 // if exists verify is not a dir: return error if is file 169 // cannot overwrite: dir -> file 170 func createDir(src, dst string) error { 171 dstInfo, err := os.Stat(dst) 172 if err != nil && !os.IsNotExist(err) { 173 return err 174 } 175 176 if err == nil { 177 if !dstInfo.IsDir() { 178 return fmt.Errorf("can't overwrite non-dir %q with dir %q", dst, src) 179 } 180 return nil 181 } 182 183 srcInfo, err := os.Stat(src) 184 if err != nil { 185 return err 186 } 187 if err := os.Mkdir(dst, srcInfo.Mode()); err != nil { 188 return err 189 } 190 if verbose { 191 fmt.Printf("%q -> %q\n", src, dst) 192 } 193 194 return nil 195 } 196 197 // copyDir copy the file hierarchies 198 // used at cp when -r or -R flag is true 199 func copyDir(src, dst string) error { 200 if err := createDir(src, dst); err != nil { 201 return err 202 } 203 204 // list files from destination 205 files, err := ioutil.ReadDir(src) 206 if err != nil { 207 return fmt.Errorf("can't list files from %q: %q", src, err) 208 } 209 210 // copy recursively the src -> dst 211 for _, file := range files { 212 fname := file.Name() 213 fpath := filepath.Join(src, fname) 214 newDst := filepath.Join(dst, fname) 215 copyFile(fpath, newDst, false) 216 } 217 218 return err 219 } 220 221 // worker is a concurrent copy, used to copy part of the files 222 // in parallel 223 func worker(s *os.File, d *os.File, fail chan error) { 224 var buf [buffSize]byte 225 var bp []byte 226 227 l := len(buf) 228 bp = buf[0:] 229 o := <-offchan 230 for { 231 n, err := s.ReadAt(bp, o) 232 if err != nil && err != io.EOF { 233 fail <- fmt.Errorf("reading %s at %v: %v", s.Name(), o, err) 234 return 235 } 236 if n == 0 { 237 break 238 } 239 240 nb := bp[0:n] 241 n, err = d.WriteAt(nb, o) 242 if err != nil { 243 fail <- fmt.Errorf("writing %s: %v", d.Name(), err) 244 return 245 } 246 bp = buf[n:] 247 o += int64(n) 248 l -= n 249 if l == 0 { 250 l = len(buf) 251 bp = buf[0:] 252 o = <-offchan 253 } 254 } 255 fail <- nil 256 } 257 258 // nextOff handler for next buffers and sync goroutines 259 // zerochan imply the init of file 260 // offchan is the next buffer part to read 261 func nextOff() { 262 off := int64(0) 263 for { 264 select { 265 case <-zerochan: 266 off = 0 267 case offchan <- off: 268 off += buffSize 269 } 270 } 271 } 272 273 // cp is a function whose eval the args 274 // and make decisions for copyfiles 275 func cp(args []string) (lastErr error) { 276 todir := false 277 from, to := args[:len(args)-1], args[len(args)-1] 278 toStat, err := os.Stat(to) 279 if err == nil { 280 todir = toStat.IsDir() 281 } 282 if flag.NArg() > 2 && todir == false { 283 log.Fatalf("is not a directory: %s\n", to) 284 } 285 286 for _, file := range from { 287 if err := copyFile(file, to, todir); err != nil { 288 log.Printf("cp: %v\n", err) 289 lastErr = err 290 } 291 } 292 293 return err 294 } 295 296 func main() { 297 if flag.NArg() < 2 { 298 flag.Usage() 299 os.Exit(1) 300 } 301 302 if err := cp(flag.Args()); err != nil { 303 os.Exit(1) 304 } 305 306 }