github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/cp/cmd.go (about) 1 // Copyright 2016-2017 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // cp copies files. 6 // 7 // Synopsis: 8 // 9 // cp [-rRfivwP] FROM... TO 10 // 11 // Options: 12 // 13 // -w n: number of worker goroutines 14 // -R: copy file hierarchies 15 // -r: alias to -R recursive mode 16 // -i: prompt about overwriting file 17 // -f: force overwrite files 18 // -v: verbose copy mode 19 // -P: don't follow symlinks 20 package cp 21 22 import ( 23 "bufio" 24 "errors" 25 "fmt" 26 "io" 27 "io/fs" 28 "os" 29 "path/filepath" 30 "strings" 31 32 flag "github.com/spf13/pflag" 33 ) 34 35 type flags struct { 36 recursive bool 37 ask bool 38 force bool 39 verbose bool 40 noFollowSymlinks bool 41 } 42 43 // promptOverwrite ask if the user wants overwrite file 44 func promptOverwrite(dst string, out io.Writer, in *bufio.Reader) (bool, error) { 45 fmt.Fprintf(out, "cp: overwrite %q? ", dst) 46 answer, err := in.ReadString('\n') 47 if err != nil { 48 return false, err 49 } 50 51 if strings.ToLower(answer)[0] != 'y' { 52 return false, nil 53 } 54 55 return true, nil 56 } 57 58 func setupPreCallback(recursive, ask, force bool, writer io.Writer, reader bufio.Reader) func(string, string, os.FileInfo) error { 59 return func(src, dst string, srcfi os.FileInfo) error { 60 // check if src is dir 61 if !recursive && srcfi.IsDir() { 62 fmt.Fprintf(writer, "cp: -r not specified, omitting directory %s\n", src) 63 return ErrSkip 64 } 65 66 dstfi, err := os.Stat(dst) 67 if err != nil && !os.IsNotExist(err) { 68 fmt.Fprintf(writer, "cp: %q: can't handle error %v\n", dst, err) 69 return ErrSkip 70 } else if err != nil { 71 // dst does not exist. 72 return nil 73 } 74 75 // dst does exist. 76 77 if os.SameFile(srcfi, dstfi) { 78 fmt.Fprintf(writer, "cp: %q and %q are the same file\n", src, dst) 79 return ErrSkip 80 } 81 if ask && !force { 82 overwrite, err := promptOverwrite(dst, writer, &reader) 83 if err != nil { 84 return err 85 } 86 if !overwrite { 87 return ErrSkip 88 } 89 } 90 return nil 91 } 92 } 93 94 func setupPostCallback(verbose bool, w io.Writer) func(src, dst string) { 95 return func(src, dst string) { 96 if verbose { 97 fmt.Fprintf(w, "%q -> %q\n", src, dst) 98 } 99 } 100 } 101 102 // run evaluates the args and makes decisions for copyfiles 103 func run(params RunParams, args []string, f flags) error { 104 todir := false 105 from, to := args[:len(args)-1], args[len(args)-1] 106 toStat, err := os.Stat(to) 107 if err == nil { 108 todir = toStat.IsDir() 109 } 110 if len(args) > 2 && !todir { 111 return eNotDir 112 } 113 114 i := bufio.NewReader(params.Stdin) 115 w := params.Stderr 116 117 opts := Options{ 118 NoFollowSymlinks: f.noFollowSymlinks, 119 120 // cp the command makes sure that 121 // 122 // (1) the files it's copying aren't already the same, 123 // (2) the user is asked about overwriting an existing file if 124 // one is already there. 125 PreCallback: setupPreCallback(f.recursive, f.ask, f.force, w, *i), 126 127 PostCallback: setupPostCallback(f.verbose, w), 128 } 129 130 var lastErr error 131 for _, file := range from { 132 dst := to 133 if todir { 134 dst = filepath.Join(dst, filepath.Base(file)) 135 } 136 absFile := params.MkAbs(file) 137 absDst := params.MkAbs(dst) 138 if f.recursive { 139 lastErr = opts.CopyTree(absFile, absDst) 140 } else { 141 lastErr = opts.Copy(absFile, absDst) 142 } 143 var pathError *fs.PathError 144 // Use the original path in errors. 145 if errors.As(lastErr, &pathError) { 146 switch pathError.Path { 147 case absFile: 148 pathError.Path = file 149 case absDst: 150 pathError.Path = dst 151 } 152 } 153 } 154 return lastErr 155 } 156 157 type RunParams struct { 158 Dir string 159 Env []string 160 161 Stdin io.Reader 162 Stdout, Stderr io.Writer 163 } 164 165 func (p RunParams) MkAbs(name string) string { 166 if filepath.IsAbs(name) { 167 return name 168 } 169 return filepath.Join(p.Dir, name) 170 } 171 172 func RunMain(params RunParams, args ...string) (exit int) { 173 flagSet := flag.NewFlagSet("cp", flag.ContinueOnError) 174 flagSet.Usage = func() { 175 fmt.Fprintln(params.Stderr, "Usage: cp [-wRrifvP] file[s] ... dest") 176 flagSet.PrintDefaults() 177 } 178 var f flags 179 flagSet.BoolVarP(&f.recursive, "RECURSIVE", "R", false, "copy file hierarchies") 180 flagSet.BoolVarP(&f.recursive, "recursive", "r", false, "alias to -R recursive mode") 181 flagSet.BoolVarP(&f.ask, "interactive", "i", false, "prompt about overwriting file") 182 flagSet.BoolVarP(&f.force, "force", "f", false, "force overwrite files") 183 flagSet.BoolVarP(&f.verbose, "verbose", "v", false, "verbose copy mode") 184 flagSet.BoolVarP(&f.noFollowSymlinks, "no-dereference", "P", false, "don't follow symlinks") 185 186 if err := flagSet.Parse(args); err != nil { 187 if err == flag.ErrHelp { 188 return 0 189 } 190 return 2 191 } 192 if flagSet.NArg() < 2 { 193 // TODO: print usage to the stderr parameter 194 flagSet.Usage() 195 return 0 196 } 197 198 if err := run(params, flagSet.Args(), f); err != nil { 199 fmt.Fprintln(params.Stderr, err) 200 return 1 201 } 202 return 0 203 }