github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/cmds/core/dd/dd.go (about)

     1  // Copyright 2013-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // dd converts and copies a file.
     6  //
     7  // Synopsis:
//     dd [OPTIONS...] [-if FILE] [-of FILE]
     9  //
    10  // Description:
    11  //     dd is modeled after dd(1).
    12  //
    13  // Options:
//     -ibs n:   input block size (default=512)
//     -obs n:   output block size (default=512)
//     -bs n:    input and output block size; when given it overrides ibs and obs (default=unset)
    17  //     -skip n:  skip n ibs-sized input blocks before reading (default=0)
    18  //     -seek n:  seek n obs-sized output blocks before writing (default=0)
    19  //     -conv s:  comma separated list of conversions (none|notrunc)
    20  //     -count n: copy only n ibs-sized input blocks
    21  //     -if:      defaults to stdin
    22  //     -of:      defaults to stdout
    23  //     -oflag:   comma separated list of out flags (none|sync|dsync)
    24  //     -status:  print transfer stats to stderr, can be one of:
    25  //         none:     do not display
    26  //         xfer:     print on completion (default)
    27  //         progress: print throughout transfer (GNU)
    28  //
    29  // Notes:
    30  //     Because UTF-8 clashes with block-oriented copying, `conv=lcase` and
    31  //     `conv=ucase` will not be supported. Additionally, research showed these
    32  //     arguments are rarely useful. Use tr instead.
    33  package main
    34  
    35  import (
    36  	"flag"
    37  	"fmt"
    38  	"io"
    39  	"io/ioutil"
    40  	"log"
    41  	"math"
    42  	"os"
    43  	"strings"
    44  	"sync"
    45  	"sync/atomic"
    46  	"time"
    47  
    48  	"github.com/rck/unit"
    49  )
    50  
// Command-line options. ibs, obs and bs are allocated and registered with the
// flag package in init (they are flag.Var values); every other option is
// registered right here. count defaults to math.MaxInt64, which inFile treats
// as "no limit".
var (
	ibs, obs, bs *unit.Value
	skip         = flag.Int64("skip", 0, "skip N ibs-sized blocks before reading")
	seek         = flag.Int64("seek", 0, "seek N obs-sized blocks before writing")
	conv         = flag.String("conv", "none", "comma separated list of conversions (none|notrunc)")
	count        = flag.Int64("count", math.MaxInt64, "copy only N input blocks")
	inName       = flag.String("if", "", "Input file")
	outName      = flag.String("of", "", "Output file")
	oFlag        = flag.String("oflag", "none", "comma separated list of out flags (none|sync|dsync)")
	status       = flag.String("status", "xfer", "display status of transfer (none|xfer|progress)")

	// bytesWritten is the running total of bytes copied to the output. It is
	// updated by parallelChunkedCopy and read by the progress printer.
	bytesWritten int64 // access atomically, must be global for correct alignedness
)
    64  
// bitClearAndSet describes how one conv= or oflag= keyword changes the flag
// word that is later handed to os.OpenFile: main first removes the bits in
// clear and then adds the bits in set.
type bitClearAndSet struct {
	clear int // bits removed from the flag word
	set   int // bits added to the flag word
}
    69  
// N.B. The flags in os derive from syscall.
// They are, hence, mutually exclusive, on target
// kernels.
//
// convMap maps each supported conv= keyword to the flag bits it clears or
// sets. Keywords that are not in this map are rejected by main.
var convMap = map[string]bitClearAndSet{
	"notrunc": {clear: os.O_TRUNC},
}

// flagMap maps each supported oflag= keyword to the flag bits it clears or
// sets. Keywords that are not in this map are rejected by main.
var flagMap = map[string]bitClearAndSet{
	"sync": {set: os.O_SYNC},
}

// allowedFlags is the mask of bits that outFile lets through from the
// conv/oflag-derived flag word into the os.OpenFile call.
var allowedFlags = os.O_TRUNC | os.O_SYNC
    82  
// intermediateBuffer is a buffer that one can write to and read from.
//
// ReadFrom fills the buffer from a reader and WriteTo drains it into a
// writer. parallelChunkedCopy hands values of this type from its reading
// goroutine to its writing loop.
type intermediateBuffer interface {
	io.ReaderFrom
	io.WriterTo
}
    88  
// chunkedBuffer is an intermediateBuffer with a specific size.
type chunkedBuffer struct {
	outChunk int64  // largest slice handed to a single Write call by WriteTo
	length   int64  // number of valid bytes in data, set by ReadFrom
	data     []byte // backing storage, allocated by newChunkedBuffer
	flags    int    // flag bits given to newChunkedBuffer; not read by the methods of this type
}
    96  
    97  func init() {
    98  	ddUnits := unit.DefaultUnits
    99  	ddUnits["c"] = 1
   100  	ddUnits["w"] = 2
   101  	ddUnits["b"] = 512
   102  	delete(ddUnits, "B")
   103  
   104  	ibs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   105  	obs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   106  	bs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   107  
   108  	flag.Var(ibs, "ibs", "Default input block size")
   109  	flag.Var(obs, "obs", "Default output block size")
   110  	flag.Var(bs, "bs", "Default input and output block size")
   111  }
   112  
   113  // newChunkedBuffer returns an intermediateBuffer that stores inChunkSize-sized
   114  // chunks of data and writes them to writers in outChunkSize-sized chunks.
   115  func newChunkedBuffer(inChunkSize int64, outChunkSize int64, flags int) intermediateBuffer {
   116  	return &chunkedBuffer{
   117  		outChunk: outChunkSize,
   118  		length:   0,
   119  		data:     make([]byte, inChunkSize),
   120  		flags:    flags,
   121  	}
   122  }
   123  
   124  // ReadFrom reads an inChunkSize-sized chunk from r into the buffer.
   125  func (cb *chunkedBuffer) ReadFrom(r io.Reader) (int64, error) {
   126  	n, err := r.Read(cb.data)
   127  	cb.length = int64(n)
   128  
   129  	// Convert to EOF explicitly.
   130  	if n == 0 && err == nil {
   131  		return 0, io.EOF
   132  	}
   133  	return int64(n), err
   134  }
   135  
   136  // WriteTo writes from the buffer to w in outChunkSize-sized chunks.
   137  func (cb *chunkedBuffer) WriteTo(w io.Writer) (int64, error) {
   138  	var i int64
   139  	for i = 0; i < int64(cb.length); {
   140  		chunk := cb.outChunk
   141  		if i+chunk > cb.length {
   142  			chunk = cb.length - i
   143  		}
   144  		block := cb.data[i : i+chunk]
   145  		got, err := w.Write(block)
   146  
   147  		// Ugh, Go cruft: io.Writer.Write returns (int, error).
   148  		// io.WriterTo.WriteTo returns (int64, error). So we have to
   149  		// cast.
   150  		i += int64(got)
   151  		if err != nil {
   152  			return i, err
   153  		}
   154  		if int64(got) != chunk {
   155  			return 0, io.ErrShortWrite
   156  		}
   157  	}
   158  	return i, nil
   159  }
   160  
   161  // bufferPool is a pool of intermediateBuffers.
   162  type bufferPool struct {
   163  	f func() intermediateBuffer
   164  	c chan intermediateBuffer
   165  }
   166  
   167  func newBufferPool(size int64, f func() intermediateBuffer) bufferPool {
   168  	return bufferPool{
   169  		f: f,
   170  		c: make(chan intermediateBuffer, size),
   171  	}
   172  }
   173  
   174  // Put returns a buffer to the pool for later use.
   175  func (bp bufferPool) Put(b intermediateBuffer) {
   176  	// Non-blocking write in case pool has filled up (too many buffers
   177  	// returned, none being used).
   178  	select {
   179  	case bp.c <- b:
   180  	default:
   181  	}
   182  }
   183  
   184  // Get returns a buffer from the pool or allocates a new buffer if none is
   185  // available.
   186  func (bp bufferPool) Get() intermediateBuffer {
   187  	select {
   188  	case buf := <-bp.c:
   189  		return buf
   190  	default:
   191  		return bp.f()
   192  	}
   193  }
   194  
   195  func (bp bufferPool) Destroy() {
   196  	close(bp.c)
   197  }
   198  
   199  func parallelChunkedCopy(r io.Reader, w io.Writer, inBufSize, outBufSize int64, flags int) error {
   200  	// Make the channels deep enough to hold a total of 1GiB of data.
   201  	depth := (1024 * 1024 * 1024) / inBufSize
   202  	// But keep it reasonable!
   203  	if depth > 8192 {
   204  		depth = 8192
   205  	}
   206  
   207  	readyBufs := make(chan intermediateBuffer, depth)
   208  	pool := newBufferPool(depth, func() intermediateBuffer {
   209  		return newChunkedBuffer(inBufSize, outBufSize, flags)
   210  	})
   211  	defer pool.Destroy()
   212  
   213  	// Closing quit makes both goroutines below exit.
   214  	quit := make(chan struct{})
   215  
   216  	// errs contains the error value to be returned.
   217  	errs := make(chan error, 1)
   218  	defer close(errs)
   219  
   220  	var wg sync.WaitGroup
   221  	wg.Add(1)
   222  	go func() {
   223  		// Closing this unblocks the writing for-loop below.
   224  		defer close(readyBufs)
   225  		defer wg.Done()
   226  
   227  		for {
   228  			select {
   229  			case <-quit:
   230  				return
   231  			default:
   232  				buf := pool.Get()
   233  				n, err := buf.ReadFrom(r)
   234  				if n > 0 {
   235  					readyBufs <- buf
   236  				}
   237  				if err == io.EOF {
   238  					return
   239  				}
   240  				if n == 0 || err != nil {
   241  					errs <- fmt.Errorf("input error: %v", err)
   242  					return
   243  				}
   244  			}
   245  		}
   246  	}()
   247  
   248  	var writeErr error
   249  	for buf := range readyBufs {
   250  		if n, err := buf.WriteTo(w); err != nil {
   251  			writeErr = fmt.Errorf("output error: %v", err)
   252  			break
   253  		} else {
   254  			atomic.AddInt64(&bytesWritten, n)
   255  		}
   256  		pool.Put(buf)
   257  	}
   258  
   259  	// This will force the goroutine to quit if an error occurred writing.
   260  	close(quit)
   261  
   262  	// Wait for goroutine to exit.
   263  	wg.Wait()
   264  
   265  	select {
   266  	case readErr := <-errs:
   267  		return readErr
   268  	default:
   269  		return writeErr
   270  	}
   271  }
   272  
   273  // sectionReader implements a SectionReader on an underlying implementation of
   274  // io.Reader (as opposed to io.SectionReader which uses io.ReaderAt).
   275  type sectionReader struct {
   276  	base   int64
   277  	offset int64
   278  	limit  int64
   279  	io.Reader
   280  }
   281  
   282  // newStreamSectionReader uses an io.Reader to implement an io.Reader that
   283  // seeks to offset and reads at most n bytes.
   284  //
   285  // This is useful if you want to use a NewSectionReader with stdin or other
   286  // types of pipes (things that can't be seek'd or pread from).
   287  func newStreamSectionReader(r io.Reader, offset int64, n int64) io.Reader {
   288  	limit := offset + n
   289  	if limit < 0 {
   290  		limit = math.MaxInt64
   291  	}
   292  	return &sectionReader{offset, 0, limit, r}
   293  }
   294  
   295  // Read implements io.Reader.
   296  func (s *sectionReader) Read(p []byte) (int, error) {
   297  	if s.offset == 0 && s.base != 0 {
   298  		if n, err := io.CopyN(ioutil.Discard, s.Reader, s.base); err != nil {
   299  			return 0, err
   300  		} else if n != s.base {
   301  			// Can't happen.
   302  			return 0, fmt.Errorf("error skipping input bytes, short write")
   303  		}
   304  		s.offset = s.base
   305  	}
   306  
   307  	if s.offset >= s.limit {
   308  		return 0, io.EOF
   309  	}
   310  
   311  	if max := s.limit - s.offset; int64(len(p)) > max {
   312  		p = p[0:max]
   313  	}
   314  
   315  	n, err := s.Reader.Read(p)
   316  	s.offset += int64(n)
   317  
   318  	// Convert to io.EOF explicitly.
   319  	if n == 0 && err == nil {
   320  		return 0, io.EOF
   321  	}
   322  	return n, err
   323  }
   324  
   325  // inFile opens the input file and seeks to the right position.
   326  func inFile(name string, inputBytes int64, skip int64, count int64) (io.Reader, error) {
   327  	maxRead := int64(math.MaxInt64)
   328  	if count != math.MaxInt64 {
   329  		maxRead = count * inputBytes
   330  	}
   331  
   332  	if name == "" {
   333  		// os.Stdin is an io.ReaderAt, but you can't actually call
   334  		// pread(2) on it, so use the copying section reader.
   335  		return newStreamSectionReader(os.Stdin, inputBytes*skip, maxRead), nil
   336  	}
   337  
   338  	in, err := os.Open(name)
   339  	if err != nil {
   340  		return nil, fmt.Errorf("error opening input file %q: %v", name, err)
   341  	}
   342  	return io.NewSectionReader(in, inputBytes*skip, maxRead), nil
   343  }
   344  
   345  // outFile opens the output file and seeks to the right position.
   346  func outFile(name string, outputBytes int64, seek int64, flags int) (io.Writer, error) {
   347  	var out io.WriteSeeker
   348  	var err error
   349  	if name == "" {
   350  		out = os.Stdout
   351  	} else {
   352  		perm := os.O_CREATE | os.O_WRONLY | (flags & allowedFlags)
   353  		if out, err = os.OpenFile(name, perm, 0666); err != nil {
   354  			return nil, fmt.Errorf("error opening output file %q: %v", name, err)
   355  		}
   356  	}
   357  	if seek*outputBytes != 0 {
   358  		if _, err := out.Seek(seek*outputBytes, io.SeekCurrent); err != nil {
   359  			return nil, fmt.Errorf("error seeking output file: %v", err)
   360  		}
   361  	}
   362  	return out, nil
   363  }
   364  
   365  type progressData struct {
   366  	mode     string // one of: none, xfer, progress
   367  	start    time.Time
   368  	variable *int64 // must be aligned for atomic operations
   369  	quit     chan struct{}
   370  }
   371  
   372  func progressBegin(mode string, variable *int64) (ProgressData *progressData) {
   373  	p := &progressData{
   374  		mode:     mode,
   375  		start:    time.Now(),
   376  		variable: variable,
   377  	}
   378  	if p.mode == "progress" {
   379  		p.print()
   380  		// Print progress in a separate goroutine.
   381  		p.quit = make(chan struct{}, 1)
   382  		go func() {
   383  			ticker := time.NewTicker(1 * time.Second)
   384  			defer ticker.Stop()
   385  			for {
   386  				select {
   387  				case <-ticker.C:
   388  					p.print()
   389  				case <-p.quit:
   390  					return
   391  				}
   392  			}
   393  		}()
   394  	}
   395  	return p
   396  }
   397  
   398  func (p *progressData) end() {
   399  	if p.mode == "progress" {
   400  		// Properly synchronize goroutine.
   401  		p.quit <- struct{}{}
   402  		p.quit <- struct{}{}
   403  	}
   404  	if p.mode == "progress" || p.mode == "xfer" {
   405  		// Print grand total.
   406  		p.print()
   407  		fmt.Fprint(os.Stderr, "\n")
   408  	}
   409  }
   410  
   411  // With "status=progress", this is called from 3 places:
   412  // - Once at the beginning to appear responsive
   413  // - Every 1s afterwards
   414  // - Once at the end so the final value is accurate
   415  func (p *progressData) print() {
   416  	elapse := time.Since(p.start)
   417  	n := atomic.LoadInt64(p.variable)
   418  	d := float64(n)
   419  	const mib = 1024 * 1024
   420  	const mb = 1000 * 1000
   421  	// The ANSI escape may be undesirable to some eyes.
   422  	if p.mode == "progress" {
   423  		os.Stderr.Write([]byte("\033[2K\r"))
   424  	}
   425  	fmt.Fprintf(os.Stderr, "%d bytes (%.3f MB, %.3f MiB) copied, %.3f s, %.3f MB/s",
   426  		n, d/mb, d/mib, elapse.Seconds(), float64(d)/elapse.Seconds()/mb)
   427  }
   428  
   429  func usage() {
   430  	log.Fatal(`Usage: dd [if=file] [of=file] [conv=none|notrunc] [seek=#] [skip=#]
   431  			     [count=#] [bs=#] [ibs=#] [obs=#] [status=none|xfer|progress] [oflag=none|sync|dsync]
   432  		options may also be invoked Go-style as -opt value or -opt=value
   433  		bs, if specified, overrides ibs and obs`)
   434  }
   435  
   436  func convertArgs(osArgs []string) []string {
   437  	// EVERYTHING in dd follows x=y. So blindly split and convert.
   438  	var args []string
   439  	for _, v := range osArgs {
   440  		l := strings.SplitN(v, "=", 2)
   441  
   442  		// We only fix the exact case for x=y.
   443  		if len(l) == 2 {
   444  			l[0] = "-" + l[0]
   445  		}
   446  
   447  		args = append(args, l...)
   448  	}
   449  	return args
   450  }
   451  
   452  func main() {
   453  	// rather than, in essence, recreating all the apparatus of flag.xxxx
   454  	// with the if= bits, including dup checking, conversion, etc. we just
   455  	// convert the arguments and then run flag.Parse. Gross, but hey, it
   456  	// works.
   457  	os.Args = convertArgs(os.Args)
   458  	flag.Parse()
   459  
   460  	if len(flag.Args()) > 0 {
   461  		usage()
   462  	}
   463  
   464  	// Convert conv argument to bit set.
   465  	flags := os.O_TRUNC
   466  	if *conv != "none" {
   467  		for _, c := range strings.Split(*conv, ",") {
   468  			if v, ok := convMap[c]; ok {
   469  				flags &= ^v.clear
   470  				flags |= v.set
   471  			} else {
   472  				log.Printf("unknown argument conv=%s", c)
   473  				usage()
   474  			}
   475  		}
   476  	}
   477  
   478  	// Convert oflag argument to bit set.
   479  	if *oFlag != "none" {
   480  		for _, f := range strings.Split(*oFlag, ",") {
   481  			if v, ok := flagMap[f]; ok {
   482  				flags &= ^v.clear
   483  				flags |= v.set
   484  			} else {
   485  				log.Printf("unknown argument oflag=%s", f)
   486  				usage()
   487  			}
   488  		}
   489  	}
   490  
   491  	if *status != "none" && *status != "xfer" && *status != "progress" {
   492  		usage()
   493  	}
   494  	progress := progressBegin(*status, &bytesWritten)
   495  
   496  	// bs = both 'ibs' and 'obs' (IEEE Std 1003.1 - 2013)
   497  	if bs.IsSet {
   498  		ibs = bs
   499  		obs = bs
   500  	}
   501  
   502  	in, err := inFile(*inName, ibs.Value, *skip, *count)
   503  	if err != nil {
   504  		log.Fatal(err)
   505  	}
   506  	out, err := outFile(*outName, obs.Value, *seek, flags)
   507  	if err != nil {
   508  		log.Fatal(err)
   509  	}
   510  	if err := parallelChunkedCopy(in, out, ibs.Value, obs.Value, flags); err != nil {
   511  		log.Fatal(err)
   512  	}
   513  
   514  	progress.end()
   515  }