github.com/rck/u-root@v0.0.0-20180106144920-7eb602e381bb/cmds/dd/dd.go (about)

     1  // Copyright 2013-2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Convert and copy a file.
     6  //
     7  // Synopsis:
     8  //     dd [OPTIONS...] [-if FILE] [-of FILE]
     9  //
    10  // Description:
    11  //     dd is modeled after dd(1).
    12  //
    13  // Options:
    14  //     -ibs n:   input block size (default=512)
    15  //     -obs n:   output block size (default=512)
    16  //     -bs n:    input and output block size; overrides ibs and obs when given
    17  //     -skip n:  skip n ibs-sized input blocks before reading (default=0)
    18  //     -seek n:  seek n obs-sized output blocks before writing (default=0)
    19  //     -conv s:  convert the data while copying; one of none, ucase, lcase (default=none)
    20  //     -count n: copy only n ibs-sized input blocks
    21  //     -if:      input file (defaults to stdin)
    22  //     -of:      output file (defaults to stdout)
    23  //     -status:  print transfer stats to stderr, can be one of:
    24  //         none:     do not display
    25  //         xfer:     print on completion (default)
    26  //         progress: print throughout transfer (GNU)
    27  package main
    28  
    29  import (
    30  	"bytes"
    31  	"flag"
    32  	"fmt"
    33  	"io"
    34  	"io/ioutil"
    35  	"log"
    36  	"math"
    37  	"os"
    38  	"strings"
    39  	"sync"
    40  	"sync/atomic"
    41  	"time"
    42  
    43  	"github.com/rck/unit"
    44  )
    45  
    46  var (
    47  	ibs, obs, bs *unit.Value
    48  	skip         = flag.Int64("skip", 0, "skip N ibs-sized blocks before reading")
    49  	seek         = flag.Int64("seek", 0, "seek N obs-sized blocks before writing")
    50  	conv         = flag.String("conv", "none", "Convert the file on a specific way, like notrunc")
    51  	count        = flag.Int64("count", math.MaxInt64, "copy only N input blocks")
    52  	inName       = flag.String("if", "", "Input file")
    53  	outName      = flag.String("of", "", "Output file")
    54  	status       = flag.String("status", "xfer", "display status of transfer (none|xfer|progress)")
    55  
    56  	bytesWritten int64 // access atomically, must be global for correct alignedness
    57  )
    58  
    59  // intermediateBuffer is a buffer that one can write to and read from.
    60  type intermediateBuffer interface {
    61  	io.ReaderFrom
    62  	io.WriterTo
    63  }
    64  
    65  // chunkedBuffer is an intermediateBuffer with a specific size.
    66  type chunkedBuffer struct {
    67  	outChunk  int64
    68  	length    int64
    69  	data      []byte
    70  	transform func([]byte) []byte
    71  }
    72  
    73  func init() {
    74  	ddUnits := unit.DefaultUnits
    75  	ddUnits["c"] = 1
    76  	ddUnits["w"] = 2
    77  	ddUnits["b"] = 512
    78  	delete(ddUnits, "B")
    79  
    80  	ibs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
    81  	obs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
    82  	bs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
    83  
    84  	flag.Var(ibs, "ibs", "Default input block size")
    85  	flag.Var(obs, "obs", "Default output block size")
    86  	flag.Var(bs, "bs", "Default input and output block size")
    87  }
    88  
    89  // newChunkedBuffer returns an intermediateBuffer that stores inChunkSize-sized
    90  // chunks of data and writes them to writers in outChunkSize-sized chunks.
    91  func newChunkedBuffer(inChunkSize int64, outChunkSize int64, transform func([]byte) []byte) intermediateBuffer {
    92  	return &chunkedBuffer{
    93  		outChunk:  outChunkSize,
    94  		length:    0,
    95  		data:      make([]byte, inChunkSize),
    96  		transform: transform,
    97  	}
    98  }
    99  
   100  // ReadFrom reads an inChunkSize-sized chunk from r into the buffer.
   101  func (cb *chunkedBuffer) ReadFrom(r io.Reader) (int64, error) {
   102  	n, err := r.Read(cb.data)
   103  	cb.length = int64(n)
   104  
   105  	// Convert to EOF explicitly.
   106  	if n == 0 && err == nil {
   107  		return 0, io.EOF
   108  	}
   109  	return int64(n), err
   110  }
   111  
   112  // WriteTo writes from the buffer to w in outChunkSize-sized chunks.
   113  func (cb *chunkedBuffer) WriteTo(w io.Writer) (int64, error) {
   114  	var i int64
   115  	for i = 0; i < int64(cb.length); {
   116  		chunk := cb.outChunk
   117  		if i+chunk > cb.length {
   118  			chunk = cb.length - i
   119  		}
   120  		got, err := w.Write(cb.transform(cb.data[i : i+chunk]))
   121  		// Ugh, Go cruft: io.Writer.Write returns (int, error).
   122  		// io.WriterTo.WriteTo returns (int64, error). So we have to
   123  		// cast.
   124  		i += int64(got)
   125  		if err != nil {
   126  			return i, err
   127  		}
   128  		if int64(got) != chunk {
   129  			return 0, io.ErrShortWrite
   130  		}
   131  	}
   132  	return i, nil
   133  }
   134  
   135  // bufferPool is a pool of intermediateBuffers.
   136  type bufferPool struct {
   137  	f func() intermediateBuffer
   138  	c chan intermediateBuffer
   139  }
   140  
   141  func newBufferPool(size int64, f func() intermediateBuffer) bufferPool {
   142  	return bufferPool{
   143  		f: f,
   144  		c: make(chan intermediateBuffer, size),
   145  	}
   146  }
   147  
   148  // Put returns a buffer to the pool for later use.
   149  func (bp bufferPool) Put(b intermediateBuffer) {
   150  	// Non-blocking write in case pool has filled up (too many buffers
   151  	// returned, none being used).
   152  	select {
   153  	case bp.c <- b:
   154  	default:
   155  	}
   156  }
   157  
   158  // Get returns a buffer from the pool or allocates a new buffer if none is
   159  // available.
   160  func (bp bufferPool) Get() intermediateBuffer {
   161  	select {
   162  	case buf := <-bp.c:
   163  		return buf
   164  	default:
   165  		return bp.f()
   166  	}
   167  }
   168  
   169  func (bp bufferPool) Destroy() {
   170  	close(bp.c)
   171  }
   172  
   173  func parallelChunkedCopy(r io.Reader, w io.Writer, inBufSize, outBufSize int64, transform func([]byte) []byte) error {
   174  	// Make the channels deep enough to hold a total of 1GiB of data.
   175  	depth := (1024 * 1024 * 1024) / inBufSize
   176  	// But keep it reasonable!
   177  	if depth > 8192 {
   178  		depth = 8192
   179  	}
   180  
   181  	readyBufs := make(chan intermediateBuffer, depth)
   182  	pool := newBufferPool(depth, func() intermediateBuffer {
   183  		return newChunkedBuffer(inBufSize, outBufSize, transform)
   184  	})
   185  	defer pool.Destroy()
   186  
   187  	// Closing quit makes both goroutines below exit.
   188  	quit := make(chan struct{})
   189  
   190  	// errs contains the error value to be returned.
   191  	errs := make(chan error, 1)
   192  	defer close(errs)
   193  
   194  	var wg sync.WaitGroup
   195  	wg.Add(1)
   196  	go func() {
   197  		// Closing this unblocks the writing for-loop below.
   198  		defer close(readyBufs)
   199  		defer wg.Done()
   200  
   201  		for {
   202  			select {
   203  			case <-quit:
   204  				return
   205  			default:
   206  				buf := pool.Get()
   207  				n, err := buf.ReadFrom(r)
   208  				if n > 0 {
   209  					readyBufs <- buf
   210  				}
   211  				if err == io.EOF {
   212  					return
   213  				}
   214  				if n == 0 || err != nil {
   215  					errs <- fmt.Errorf("input error: %v", err)
   216  					return
   217  				}
   218  			}
   219  		}
   220  	}()
   221  
   222  	var writeErr error
   223  	for buf := range readyBufs {
   224  		if n, err := buf.WriteTo(w); err != nil {
   225  			writeErr = fmt.Errorf("output error: %v", err)
   226  			break
   227  		} else {
   228  			atomic.AddInt64(&bytesWritten, n)
   229  		}
   230  		pool.Put(buf)
   231  	}
   232  
   233  	// This will force the goroutine to quit if an error occurred writing.
   234  	close(quit)
   235  
   236  	// Wait for goroutine to exit.
   237  	wg.Wait()
   238  
   239  	select {
   240  	case readErr := <-errs:
   241  		return readErr
   242  	default:
   243  		return writeErr
   244  	}
   245  }
   246  
   247  // sectionReader implements a SectionReader on an underlying implementation of
   248  // io.Reader (as opposed to io.SectionReader which uses io.ReaderAt).
   249  type sectionReader struct {
   250  	base   int64
   251  	offset int64
   252  	limit  int64
   253  	io.Reader
   254  }
   255  
   256  // newStreamSectionReader uses an io.Reader to implement an io.Reader that
   257  // seeks to offset and reads at most n bytes.
   258  //
   259  // This is useful if you want to use a NewSectionReader with stdin or other
   260  // types of pipes (things that can't be seek'd or pread from).
   261  func newStreamSectionReader(r io.Reader, offset int64, n int64) io.Reader {
   262  	limit := offset + n
   263  	if limit < 0 {
   264  		limit = math.MaxInt64
   265  	}
   266  	return &sectionReader{offset, 0, limit, r}
   267  }
   268  
   269  // Read implements io.Reader.
   270  func (s *sectionReader) Read(p []byte) (int, error) {
   271  	if s.offset == 0 && s.base != 0 {
   272  		if n, err := io.CopyN(ioutil.Discard, s.Reader, s.base); err != nil {
   273  			return 0, err
   274  		} else if n != s.base {
   275  			// Can't happen.
   276  			return 0, fmt.Errorf("error skipping input bytes, short write")
   277  		}
   278  		s.offset = s.base
   279  	}
   280  
   281  	if s.offset >= s.limit {
   282  		return 0, io.EOF
   283  	}
   284  
   285  	if max := s.limit - s.offset; int64(len(p)) > max {
   286  		p = p[0:max]
   287  	}
   288  
   289  	n, err := s.Reader.Read(p)
   290  	s.offset += int64(n)
   291  
   292  	// Convert to io.EOF explicitly.
   293  	if n == 0 && err == nil {
   294  		return 0, io.EOF
   295  	}
   296  	return n, err
   297  }
   298  
   299  // inFile opens the input file and seeks to the right position.
   300  func inFile(name string, inputBytes int64, skip int64, count int64) (io.Reader, error) {
   301  	maxRead := int64(math.MaxInt64)
   302  	if count != math.MaxInt64 {
   303  		maxRead = count * inputBytes
   304  	}
   305  
   306  	if name == "" {
   307  		// os.Stdin is an io.ReaderAt, but you can't actually call
   308  		// pread(2) on it, so use the copying section reader.
   309  		return newStreamSectionReader(os.Stdin, inputBytes*skip, maxRead), nil
   310  	}
   311  
   312  	in, err := os.Open(name)
   313  	if err != nil {
   314  		return nil, fmt.Errorf("error opening input file %q: %v", name, err)
   315  	}
   316  	return io.NewSectionReader(in, inputBytes*skip, maxRead), nil
   317  }
   318  
   319  // outFile opens the output file and seeks to the right position.
   320  func outFile(name string, outputBytes int64, seek int64) (io.Writer, error) {
   321  	var out io.WriteSeeker
   322  	var err error
   323  	if name == "" {
   324  		out = os.Stdout
   325  	} else {
   326  		if out, err = os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666); err != nil {
   327  			return nil, fmt.Errorf("error opening output file %q: %v", name, err)
   328  		}
   329  	}
   330  	if seek*outputBytes != 0 {
   331  		if _, err := out.Seek(seek*outputBytes, io.SeekCurrent); err != nil {
   332  			return nil, fmt.Errorf("error seeking output file: %v", err)
   333  		}
   334  	}
   335  	return out, nil
   336  }
   337  
   338  type progressData struct {
   339  	mode     string // one of: none, xfer, progress
   340  	start    time.Time
   341  	variable *int64 // must be aligned for atomic operations
   342  	quit     chan struct{}
   343  }
   344  
   345  func progressBegin(mode string, variable *int64) (ProgressData *progressData) {
   346  	p := &progressData{
   347  		mode:     mode,
   348  		start:    time.Now(),
   349  		variable: variable,
   350  	}
   351  	if p.mode == "progress" {
   352  		p.print()
   353  		// Print progress in a separate goroutine.
   354  		p.quit = make(chan struct{}, 1)
   355  		go func() {
   356  			ticker := time.NewTicker(1 * time.Second)
   357  			defer ticker.Stop()
   358  			for {
   359  				select {
   360  				case <-ticker.C:
   361  					p.print()
   362  				case <-p.quit:
   363  					return
   364  				}
   365  			}
   366  		}()
   367  	}
   368  	return p
   369  }
   370  
   371  func (p *progressData) end() {
   372  	if p.mode == "progress" {
   373  		// Properly synchronize goroutine.
   374  		p.quit <- struct{}{}
   375  		p.quit <- struct{}{}
   376  	}
   377  	if p.mode == "progress" || p.mode == "xfer" {
   378  		// Print grand total.
   379  		p.print()
   380  		fmt.Fprint(os.Stderr, "\n")
   381  	}
   382  }
   383  
   384  // With "status=progress", this is called from 3 places:
   385  // - Once at the beginning to appear responsive
   386  // - Every 1s afterwards
   387  // - Once at the end so the final value is accurate
   388  func (p *progressData) print() {
   389  	elapse := time.Since(p.start)
   390  	n := atomic.LoadInt64(p.variable)
   391  	d := float64(n)
   392  	const mib = 1024 * 1024
   393  	const mb = 1000 * 1000
   394  	// The ANSI escape may be undesirable to some eyes.
   395  	if p.mode == "progress" {
   396  		os.Stderr.Write([]byte("\033[2K\r"))
   397  	}
   398  	fmt.Fprintf(os.Stderr, "%d bytes (%.3f MB, %.3f MiB) copied, %.3f s, %.3f MB/s",
   399  		n, d/mb, d/mib, elapse.Seconds(), float64(d)/elapse.Seconds()/mb)
   400  }
   401  
   402  func usage() {
   403  	// If the conversions get more complex we can dump
   404  	// the convs map. For now, it's not really worth it.
   405  	log.Fatal(`Usage: dd [if=file] [of=file] [conv=lcase|ucase] [seek=#] [skip=#] [count=#] [bs=#] [ibs=#] [obs=#] [status=none|xfer|progress]
   406  		options may also be invoked Go-style as -opt value or -opt=value
   407  		bs, if specified, overrides ibs and obs`)
   408  }
   409  
   410  func convertArgs(osArgs []string) []string {
   411  	// EVERYTHING in dd follows x=y. So blindly split and convert.
   412  	var args []string
   413  	for _, v := range osArgs {
   414  		l := strings.SplitN(v, "=", 2)
   415  
   416  		// We only fix the exact case for x=y.
   417  		if len(l) == 2 {
   418  			l[0] = "-" + l[0]
   419  		}
   420  
   421  		args = append(args, l...)
   422  	}
   423  	return args
   424  }
   425  
   426  func main() {
   427  	// rather than, in essence, recreating all the apparatus of flag.xxxx
   428  	// with the if= bits, including dup checking, conversion, etc. we just
   429  	// convert the arguments and then run flag.Parse. Gross, but hey, it
   430  	// works.
   431  	os.Args = convertArgs(os.Args)
   432  	flag.Parse()
   433  
   434  	if len(flag.Args()) > 0 {
   435  		usage()
   436  	}
   437  
   438  	convs := map[string]func([]byte) []byte{
   439  		"none":  func(b []byte) []byte { return b },
   440  		"ucase": bytes.ToUpper,
   441  		"lcase": bytes.ToLower,
   442  	}
   443  	convert, ok := convs[*conv]
   444  	if !ok {
   445  		usage()
   446  	}
   447  
   448  	if *status != "none" && *status != "xfer" && *status != "progress" {
   449  		usage()
   450  	}
   451  	progress := progressBegin(*status, &bytesWritten)
   452  
   453  	// bs = both 'ibs' and 'obs' (IEEE Std 1003.1 - 2013)
   454  	if bs.IsSet {
   455  		ibs = bs
   456  		obs = bs
   457  	}
   458  
   459  	in, err := inFile(*inName, ibs.Value, *skip, *count)
   460  	if err != nil {
   461  		log.Fatal(err)
   462  	}
   463  	out, err := outFile(*outName, obs.Value, *seek)
   464  	if err != nil {
   465  		log.Fatal(err)
   466  	}
   467  	if err := parallelChunkedCopy(in, out, ibs.Value, obs.Value, convert); err != nil {
   468  		log.Fatal(err)
   469  	}
   470  
   471  	progress.end()
   472  }