github.com/oweisse/u-root@v0.0.0-20181109060735-d005ad25fef1/cmds/dd/dd.go (about)

     1  // Copyright 2013-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Convert and copy a file.
     6  //
     7  // Synopsis:
//     dd [OPTIONS...] [-if FILE] [-of FILE]
     9  //
    10  // Description:
    11  //     dd is modeled after dd(1).
    12  //
    13  // Options:
    14  //     -ibs n:   input block size (default=1)
    15  //     -obs n:   output block size (default=1)
    16  //     -bs n:    input and output block size (default=0)
    17  //     -skip n:  skip n ibs-sized input blocks before reading (default=0)
    18  //     -seek n:  seek n obs-sized output blocks before writing (default=0)
    19  //     -conv s:  comma separated list of conversions (none|notrunc)
    20  //     -count n: copy only n ibs-sized input blocks
    21  //     -if:      defaults to stdin
    22  //     -of:      defaults to stdout
    23  //     -oflag:   comma separated list of out flags (none|sync|dsync)
    24  //     -status:  print transfer stats to stderr, can be one of:
    25  //         none:     do not display
    26  //         xfer:     print on completion (default)
    27  //         progress: print throughout transfer (GNU)
    28  //
    29  // Notes:
    30  //     Because UTF-8 clashes with block-oriented copying, `conv=lcase` and
    31  //     `conv=ucase` will not be supported. Additionally, research showed these
    32  //     arguments are rarely useful. Use tr instead.
    33  package main
    34  
    35  import (
    36  	"flag"
    37  	"fmt"
    38  	"io"
    39  	"io/ioutil"
    40  	"log"
    41  	"math"
    42  	"os"
    43  	"strings"
    44  	"sync"
    45  	"sync/atomic"
    46  	"syscall"
    47  	"time"
    48  
    49  	"github.com/rck/unit"
    50  )
    51  
var (
	// ibs, obs and bs are the block sizes; they are registered as flags
	// in init so they accept dd's unit suffixes. bs, when set, overrides
	// both ibs and obs (see main).
	ibs, obs, bs *unit.Value
	skip         = flag.Int64("skip", 0, "skip N ibs-sized blocks before reading")
	seek         = flag.Int64("seek", 0, "seek N obs-sized blocks before writing")
	conv         = flag.String("conv", "none", "comma separated list of conversions (none|notrunc)")
	count        = flag.Int64("count", math.MaxInt64, "copy only N input blocks")
	inName       = flag.String("if", "", "Input file")
	outName      = flag.String("of", "", "Output file")
	oFlag        = flag.String("oflag", "none", "comma separated list of out flags (none|sync|dsync)")
	status       = flag.String("status", "xfer", "display status of transfer (none|xfer|progress)")

	bytesWritten int64 // access atomically, must be global for correct alignedness
)
    65  
// Bit flags for the -conv and -oflag options; they are OR'd into a
// single int that is handed to outFile and parallelChunkedCopy.
const (
	convNoTrunc = 1 << iota // conv=notrunc: do not truncate the output file
	oFlagSync               // oflag=sync: open the output with O_SYNC
	oFlagDSync              // oflag=dsync: open the output with O_DSYNC
)

// convMap translates -conv keywords into their flag bits.
var convMap = map[string]int{
	"notrunc": convNoTrunc,
}

// flagMap translates -oflag keywords into their flag bits.
var flagMap = map[string]int{
	"sync":  oFlagSync,
	"dsync": oFlagDSync,
}
    80  
// intermediateBuffer is a buffer that one can write to and read from.
type intermediateBuffer interface {
	io.ReaderFrom
	io.WriterTo
}

// chunkedBuffer is an intermediateBuffer with a specific size.
//
// It holds up to one input block of data (len(data) bytes) and writes
// it back out in outChunk-sized pieces.
type chunkedBuffer struct {
	outChunk int64  // maximum size of a single write in WriteTo
	length   int64  // number of valid bytes currently held in data
	data     []byte // backing storage, sized to the input block size
	flags    int    // conv/oflag bit set; not consulted by ReadFrom/WriteTo here
}
    94  
    95  func init() {
    96  	ddUnits := unit.DefaultUnits
    97  	ddUnits["c"] = 1
    98  	ddUnits["w"] = 2
    99  	ddUnits["b"] = 512
   100  	delete(ddUnits, "B")
   101  
   102  	ibs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   103  	obs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   104  	bs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
   105  
   106  	flag.Var(ibs, "ibs", "Default input block size")
   107  	flag.Var(obs, "obs", "Default output block size")
   108  	flag.Var(bs, "bs", "Default input and output block size")
   109  }
   110  
   111  // newChunkedBuffer returns an intermediateBuffer that stores inChunkSize-sized
   112  // chunks of data and writes them to writers in outChunkSize-sized chunks.
   113  func newChunkedBuffer(inChunkSize int64, outChunkSize int64, flags int) intermediateBuffer {
   114  	return &chunkedBuffer{
   115  		outChunk: outChunkSize,
   116  		length:   0,
   117  		data:     make([]byte, inChunkSize),
   118  		flags:    flags,
   119  	}
   120  }
   121  
   122  // ReadFrom reads an inChunkSize-sized chunk from r into the buffer.
   123  func (cb *chunkedBuffer) ReadFrom(r io.Reader) (int64, error) {
   124  	n, err := r.Read(cb.data)
   125  	cb.length = int64(n)
   126  
   127  	// Convert to EOF explicitly.
   128  	if n == 0 && err == nil {
   129  		return 0, io.EOF
   130  	}
   131  	return int64(n), err
   132  }
   133  
   134  // WriteTo writes from the buffer to w in outChunkSize-sized chunks.
   135  func (cb *chunkedBuffer) WriteTo(w io.Writer) (int64, error) {
   136  	var i int64
   137  	for i = 0; i < int64(cb.length); {
   138  		chunk := cb.outChunk
   139  		if i+chunk > cb.length {
   140  			chunk = cb.length - i
   141  		}
   142  		block := cb.data[i : i+chunk]
   143  		got, err := w.Write(block)
   144  
   145  		// Ugh, Go cruft: io.Writer.Write returns (int, error).
   146  		// io.WriterTo.WriteTo returns (int64, error). So we have to
   147  		// cast.
   148  		i += int64(got)
   149  		if err != nil {
   150  			return i, err
   151  		}
   152  		if int64(got) != chunk {
   153  			return 0, io.ErrShortWrite
   154  		}
   155  	}
   156  	return i, nil
   157  }
   158  
// bufferPool is a pool of intermediateBuffers.
//
// It is a simple free list built on a buffered channel: Get prefers a
// recycled buffer but falls back to allocating a fresh one, and Put
// silently drops buffers once the pool is full.
type bufferPool struct {
	f func() intermediateBuffer // allocator used when the pool is empty
	c chan intermediateBuffer   // recycled buffers
}

// newBufferPool returns a bufferPool holding at most size buffers and
// using f to allocate new ones on demand.
func newBufferPool(size int64, f func() intermediateBuffer) bufferPool {
	return bufferPool{
		f: f,
		c: make(chan intermediateBuffer, size),
	}
}
   171  
// Put returns a buffer to the pool for later use.
//
// Put never blocks: if the pool is already full the buffer is simply
// dropped and left for the garbage collector.
func (bp bufferPool) Put(b intermediateBuffer) {
	// Non-blocking write in case pool has filled up (too many buffers
	// returned, none being used).
	select {
	case bp.c <- b:
	default:
	}
}
   181  
// Get returns a buffer from the pool or allocates a new buffer if none is
// available. It never blocks.
func (bp bufferPool) Get() intermediateBuffer {
	select {
	case buf := <-bp.c:
		return buf
	default:
		// Pool is empty; allocate a fresh buffer.
		return bp.f()
	}
}
   192  
// Destroy closes the pool's channel. The pool must not be used
// afterwards: Put would panic (send on closed channel) and Get would
// hand out nil buffers.
func (bp bufferPool) Destroy() {
	close(bp.c)
}
   196  
   197  func parallelChunkedCopy(r io.Reader, w io.Writer, inBufSize, outBufSize int64, flags int) error {
   198  	// Make the channels deep enough to hold a total of 1GiB of data.
   199  	depth := (1024 * 1024 * 1024) / inBufSize
   200  	// But keep it reasonable!
   201  	if depth > 8192 {
   202  		depth = 8192
   203  	}
   204  
   205  	readyBufs := make(chan intermediateBuffer, depth)
   206  	pool := newBufferPool(depth, func() intermediateBuffer {
   207  		return newChunkedBuffer(inBufSize, outBufSize, flags)
   208  	})
   209  	defer pool.Destroy()
   210  
   211  	// Closing quit makes both goroutines below exit.
   212  	quit := make(chan struct{})
   213  
   214  	// errs contains the error value to be returned.
   215  	errs := make(chan error, 1)
   216  	defer close(errs)
   217  
   218  	var wg sync.WaitGroup
   219  	wg.Add(1)
   220  	go func() {
   221  		// Closing this unblocks the writing for-loop below.
   222  		defer close(readyBufs)
   223  		defer wg.Done()
   224  
   225  		for {
   226  			select {
   227  			case <-quit:
   228  				return
   229  			default:
   230  				buf := pool.Get()
   231  				n, err := buf.ReadFrom(r)
   232  				if n > 0 {
   233  					readyBufs <- buf
   234  				}
   235  				if err == io.EOF {
   236  					return
   237  				}
   238  				if n == 0 || err != nil {
   239  					errs <- fmt.Errorf("input error: %v", err)
   240  					return
   241  				}
   242  			}
   243  		}
   244  	}()
   245  
   246  	var writeErr error
   247  	for buf := range readyBufs {
   248  		if n, err := buf.WriteTo(w); err != nil {
   249  			writeErr = fmt.Errorf("output error: %v", err)
   250  			break
   251  		} else {
   252  			atomic.AddInt64(&bytesWritten, n)
   253  		}
   254  		pool.Put(buf)
   255  	}
   256  
   257  	// This will force the goroutine to quit if an error occurred writing.
   258  	close(quit)
   259  
   260  	// Wait for goroutine to exit.
   261  	wg.Wait()
   262  
   263  	select {
   264  	case readErr := <-errs:
   265  		return readErr
   266  	default:
   267  		return writeErr
   268  	}
   269  }
   270  
   271  // sectionReader implements a SectionReader on an underlying implementation of
   272  // io.Reader (as opposed to io.SectionReader which uses io.ReaderAt).
   273  type sectionReader struct {
   274  	base   int64
   275  	offset int64
   276  	limit  int64
   277  	io.Reader
   278  }
   279  
   280  // newStreamSectionReader uses an io.Reader to implement an io.Reader that
   281  // seeks to offset and reads at most n bytes.
   282  //
   283  // This is useful if you want to use a NewSectionReader with stdin or other
   284  // types of pipes (things that can't be seek'd or pread from).
   285  func newStreamSectionReader(r io.Reader, offset int64, n int64) io.Reader {
   286  	limit := offset + n
   287  	if limit < 0 {
   288  		limit = math.MaxInt64
   289  	}
   290  	return &sectionReader{offset, 0, limit, r}
   291  }
   292  
   293  // Read implements io.Reader.
   294  func (s *sectionReader) Read(p []byte) (int, error) {
   295  	if s.offset == 0 && s.base != 0 {
   296  		if n, err := io.CopyN(ioutil.Discard, s.Reader, s.base); err != nil {
   297  			return 0, err
   298  		} else if n != s.base {
   299  			// Can't happen.
   300  			return 0, fmt.Errorf("error skipping input bytes, short write")
   301  		}
   302  		s.offset = s.base
   303  	}
   304  
   305  	if s.offset >= s.limit {
   306  		return 0, io.EOF
   307  	}
   308  
   309  	if max := s.limit - s.offset; int64(len(p)) > max {
   310  		p = p[0:max]
   311  	}
   312  
   313  	n, err := s.Reader.Read(p)
   314  	s.offset += int64(n)
   315  
   316  	// Convert to io.EOF explicitly.
   317  	if n == 0 && err == nil {
   318  		return 0, io.EOF
   319  	}
   320  	return n, err
   321  }
   322  
   323  // inFile opens the input file and seeks to the right position.
   324  func inFile(name string, inputBytes int64, skip int64, count int64) (io.Reader, error) {
   325  	maxRead := int64(math.MaxInt64)
   326  	if count != math.MaxInt64 {
   327  		maxRead = count * inputBytes
   328  	}
   329  
   330  	if name == "" {
   331  		// os.Stdin is an io.ReaderAt, but you can't actually call
   332  		// pread(2) on it, so use the copying section reader.
   333  		return newStreamSectionReader(os.Stdin, inputBytes*skip, maxRead), nil
   334  	}
   335  
   336  	in, err := os.Open(name)
   337  	if err != nil {
   338  		return nil, fmt.Errorf("error opening input file %q: %v", name, err)
   339  	}
   340  	return io.NewSectionReader(in, inputBytes*skip, maxRead), nil
   341  }
   342  
   343  // outFile opens the output file and seeks to the right position.
   344  func outFile(name string, outputBytes int64, seek int64, flags int) (io.Writer, error) {
   345  	var out io.WriteSeeker
   346  	var err error
   347  	if name == "" {
   348  		out = os.Stdout
   349  	} else {
   350  		perm := os.O_CREATE | os.O_WRONLY
   351  		if flags&convNoTrunc == 0 {
   352  			perm |= os.O_TRUNC
   353  		}
   354  		if flags&oFlagSync != 0 {
   355  			perm |= os.O_SYNC
   356  		}
   357  		if flags&oFlagDSync != 0 {
   358  			perm |= syscall.O_DSYNC
   359  		}
   360  		if out, err = os.OpenFile(name, perm, 0666); err != nil {
   361  			return nil, fmt.Errorf("error opening output file %q: %v", name, err)
   362  		}
   363  	}
   364  	if seek*outputBytes != 0 {
   365  		if _, err := out.Seek(seek*outputBytes, io.SeekCurrent); err != nil {
   366  			return nil, fmt.Errorf("error seeking output file: %v", err)
   367  		}
   368  	}
   369  	return out, nil
   370  }
   371  
   372  type progressData struct {
   373  	mode     string // one of: none, xfer, progress
   374  	start    time.Time
   375  	variable *int64 // must be aligned for atomic operations
   376  	quit     chan struct{}
   377  }
   378  
   379  func progressBegin(mode string, variable *int64) (ProgressData *progressData) {
   380  	p := &progressData{
   381  		mode:     mode,
   382  		start:    time.Now(),
   383  		variable: variable,
   384  	}
   385  	if p.mode == "progress" {
   386  		p.print()
   387  		// Print progress in a separate goroutine.
   388  		p.quit = make(chan struct{}, 1)
   389  		go func() {
   390  			ticker := time.NewTicker(1 * time.Second)
   391  			defer ticker.Stop()
   392  			for {
   393  				select {
   394  				case <-ticker.C:
   395  					p.print()
   396  				case <-p.quit:
   397  					return
   398  				}
   399  			}
   400  		}()
   401  	}
   402  	return p
   403  }
   404  
// end stops the periodic reporter (if any) and prints the final
// transfer statistics for modes "progress" and "xfer".
func (p *progressData) end() {
	if p.mode == "progress" {
		// Two sends through the 1-deep quit channel: the first delivers
		// the stop signal; the second cannot complete until the ticker
		// goroutine has consumed the first, so once it returns the
		// goroutine is guaranteed to be out of its print loop and will
		// not race with the final print below.
		p.quit <- struct{}{}
		p.quit <- struct{}{}
	}
	if p.mode == "progress" || p.mode == "xfer" {
		// Print grand total.
		p.print()
		fmt.Fprint(os.Stderr, "\n")
	}
}
   417  
   418  // With "status=progress", this is called from 3 places:
   419  // - Once at the beginning to appear responsive
   420  // - Every 1s afterwards
   421  // - Once at the end so the final value is accurate
   422  func (p *progressData) print() {
   423  	elapse := time.Since(p.start)
   424  	n := atomic.LoadInt64(p.variable)
   425  	d := float64(n)
   426  	const mib = 1024 * 1024
   427  	const mb = 1000 * 1000
   428  	// The ANSI escape may be undesirable to some eyes.
   429  	if p.mode == "progress" {
   430  		os.Stderr.Write([]byte("\033[2K\r"))
   431  	}
   432  	fmt.Fprintf(os.Stderr, "%d bytes (%.3f MB, %.3f MiB) copied, %.3f s, %.3f MB/s",
   433  		n, d/mb, d/mib, elapse.Seconds(), float64(d)/elapse.Seconds()/mb)
   434  }
   435  
// usage prints dd's help text to stderr and exits the process via
// log.Fatal.
func usage() {
	log.Fatal(`Usage: dd [if=file] [of=file] [conv=none|notrunc] [seek=#] [skip=#]
			     [count=#] [bs=#] [ibs=#] [obs=#] [status=none|xfer|progress] [oflag=none|sync|dsync]
		options may also be invoked Go-style as -opt value or -opt=value
		bs, if specified, overrides ibs and obs`)
}
   442  
   443  func convertArgs(osArgs []string) []string {
   444  	// EVERYTHING in dd follows x=y. So blindly split and convert.
   445  	var args []string
   446  	for _, v := range osArgs {
   447  		l := strings.SplitN(v, "=", 2)
   448  
   449  		// We only fix the exact case for x=y.
   450  		if len(l) == 2 {
   451  			l[0] = "-" + l[0]
   452  		}
   453  
   454  		args = append(args, l...)
   455  	}
   456  	return args
   457  }
   458  
// main implements dd: translate dd-style key=value arguments into Go
// flags, open the input and output at the requested offsets, and
// stream blocks from one to the other, optionally reporting progress.
func main() {
	// rather than, in essence, recreating all the apparatus of flag.xxxx
	// with the if= bits, including dup checking, conversion, etc. we just
	// convert the arguments and then run flag.Parse. Gross, but hey, it
	// works.
	os.Args = convertArgs(os.Args)
	flag.Parse()

	// dd takes no positional arguments; anything left over is a typo.
	if len(flag.Args()) > 0 {
		usage()
	}

	// Convert conv argument to bit set.
	flags := 0
	if *conv != "none" {
		for _, c := range strings.Split(*conv, ",") {
			if v, ok := convMap[c]; ok {
				flags |= v
			} else {
				log.Printf("unknown argument conv=%s", c)
				usage()
			}
		}
	}

	// Convert oflag argument to bit set.
	if *oFlag != "none" {
		for _, f := range strings.Split(*oFlag, ",") {
			if v, ok := flagMap[f]; ok {
				flags |= v
			} else {
				log.Printf("unknown argument oflag=%s", f)
				usage()
			}
		}
	}

	if *status != "none" && *status != "xfer" && *status != "progress" {
		usage()
	}
	progress := progressBegin(*status, &bytesWritten)

	// bs = both 'ibs' and 'obs' (IEEE Std 1003.1 - 2013)
	if bs.IsSet {
		ibs = bs
		obs = bs
	}

	in, err := inFile(*inName, ibs.Value, *skip, *count)
	if err != nil {
		log.Fatal(err)
	}
	out, err := outFile(*outName, obs.Value, *seek, flags)
	if err != nil {
		log.Fatal(err)
	}
	if err := parallelChunkedCopy(in, out, ibs.Value, obs.Value, flags); err != nil {
		log.Fatal(err)
	}

	// Print the final transfer statistics.
	progress.end()
}