github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/core/dd/dd.go

// Copyright 2013-2017 the u-root Authors. All rights reserved
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// dd converts and copies a file.
//
// Synopsis:
//
//	dd [OPTIONS...] [-if FILE] [-of FILE]
//
// Description:
//
//	dd is modeled after dd(1).
//
// Options:
//
//	-ibs n:   input block size (default=512)
//	-obs n:   output block size (default=512)
//	-bs n:    input and output block size; overrides -ibs and -obs (default=512)
//	-skip n:  skip n ibs-sized input blocks before reading (default=0)
//	-seek n:  seek n obs-sized output blocks before writing (default=0)
//	-conv s:  comma-separated list of conversions (none|notrunc)
//	-count n: copy only n ibs-sized input blocks
//	-if:      input file (defaults to stdin)
//	-of:      output file (defaults to stdout)
//	-oflag:   comma-separated list of output flags (none|sync|dsync)
//	-status:  print transfer stats to stderr, can be one of:
//	    none:     do not display
//	    xfer:     print on completion (default)
//	    progress: print throughout transfer (GNU)
//
// Notes:
//
//	Because UTF-8 clashes with block-oriented copying, `conv=lcase` and
//	`conv=ucase` will not be supported. Additionally, research showed these
//	arguments are rarely useful. Use tr instead.
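//
// Example (illustrative invocations):
//
//	dd -if=/dev/zero -of=/tmp/out -bs=4096 -count=16
//	dd if=/dev/zero of=/tmp/out bs=4096 count=16
//
// Both forms are equivalent: key=value arguments are rewritten into Go-style
// flags before parsing (see convertArgs).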
package main

import (
	"flag"
	"fmt"
	"io"
	"log"
	"math"
	"os"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/mvdan/u-root-coreutils/pkg/progress"
	"github.com/rck/unit"
)

var (
	ibs, obs, bs *unit.Value
	skip         = flag.Int64("skip", 0, "skip N ibs-sized blocks before reading")
	seek         = flag.Int64("seek", 0, "seek N obs-sized blocks before writing")
	conv         = flag.String("conv", "none", "comma-separated list of conversions (none|notrunc)")
	count        = flag.Int64("count", math.MaxInt64, "copy only N input blocks")
	inName       = flag.String("if", "", "Input file")
	outName      = flag.String("of", "", "Output file")
	oFlag        = flag.String("oflag", "none", "comma-separated list of output flags (none|sync|dsync)")
	status       = flag.String("status", "xfer", "display status of transfer (none|xfer|progress)")

	// bytesWritten is accessed atomically; it must stay at package scope so
	// that it is guaranteed to be 64-bit aligned on 32-bit platforms.
	bytesWritten int64
)

type bitClearAndSet struct {
	clear int
	set   int
}

// N.B. The flags in os derive from syscall and are hence mutually exclusive
// on target kernels.
var convMap = map[string]bitClearAndSet{
	"notrunc": {clear: os.O_TRUNC},
}

var flagMap = map[string]bitClearAndSet{
	"sync": {set: os.O_SYNC},
}

var allowedFlags = os.O_TRUNC | os.O_SYNC

// intermediateBuffer is a buffer that one can write to and read from.
type intermediateBuffer interface {
	io.ReaderFrom
	io.WriterTo
}

// chunkedBuffer is an intermediateBuffer with a specific size.
type chunkedBuffer struct {
	outChunk int64
	length   int64
	data     []byte
	flags    int
}

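// init registers the size suffixes accepted by -ibs, -obs and -bs: the unit
// package's defaults (with "B" removed) plus the classic dd multipliers
// c (1 byte), w (2 bytes) and b (512 bytes).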
func init() {
	ddUnits := unit.DefaultUnits
	ddUnits["c"] = 1
	ddUnits["w"] = 2
	ddUnits["b"] = 512
	delete(ddUnits, "B")

	ibs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
	obs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)
	bs = unit.MustNewUnit(ddUnits).MustNewValue(512, unit.None)

	flag.Var(ibs, "ibs", "Default input block size")
	flag.Var(obs, "obs", "Default output block size")
	flag.Var(bs, "bs", "Default input and output block size")
}

// newChunkedBuffer returns an intermediateBuffer that stores inChunkSize-sized
// chunks of data and writes them to writers in outChunkSize-sized chunks.
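//
// For example, with inChunkSize=4096 and outChunkSize=512, each ReadFrom
// fills at most 4096 bytes and WriteTo drains them in 512-byte writes (the
// final write may be shorter).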
func newChunkedBuffer(inChunkSize int64, outChunkSize int64, flags int) intermediateBuffer {
	return &chunkedBuffer{
		outChunk: outChunkSize,
		length:   0,
		data:     make([]byte, inChunkSize),
		flags:    flags,
	}
}

// ReadFrom reads an inChunkSize-sized chunk from r into the buffer.
func (cb *chunkedBuffer) ReadFrom(r io.Reader) (int64, error) {
	n, err := r.Read(cb.data)
	cb.length = int64(n)

	// Convert to EOF explicitly.
	if n == 0 && err == nil {
		return 0, io.EOF
	}
	return int64(n), err
}

// WriteTo writes from the buffer to w in outChunkSize-sized chunks.
func (cb *chunkedBuffer) WriteTo(w io.Writer) (int64, error) {
	var i int64
	for i = 0; i < cb.length; {
		chunk := cb.outChunk
		if i+chunk > cb.length {
			chunk = cb.length - i
		}
		block := cb.data[i : i+chunk]
		got, err := w.Write(block)

		// Ugh, Go cruft: io.Writer.Write returns (int, error) while
		// io.WriterTo.WriteTo returns (int64, error), so we have to cast.
		i += int64(got)
		if err != nil {
			return i, err
		}
		if int64(got) != chunk {
			// Report the bytes actually written so far along with the error.
			return i, io.ErrShortWrite
		}
	}
	return i, nil
}

// bufferPool is a pool of intermediateBuffers.
type bufferPool struct {
	f func() intermediateBuffer
	c chan intermediateBuffer
}

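// newBufferPool returns a bufferPool that caches up to size buffers, each
// allocated on demand by f.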
func newBufferPool(size int64, f func() intermediateBuffer) bufferPool {
	return bufferPool{
		f: f,
		c: make(chan intermediateBuffer, size),
	}
}

// Put returns a buffer to the pool for later use.
func (bp bufferPool) Put(b intermediateBuffer) {
	// Non-blocking write in case pool has filled up (too many buffers
	// returned, none being used).
	select {
	case bp.c <- b:
	default:
	}
}

// Get returns a buffer from the pool or allocates a new buffer if none is
// available.
func (bp bufferPool) Get() intermediateBuffer {
	select {
	case buf := <-bp.c:
		return buf
	default:
		return bp.f()
	}
}

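// Destroy closes the pool's channel; the pool must not be used afterwards.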
func (bp bufferPool) Destroy() {
	close(bp.c)
}

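// parallelChunkedCopy copies from r to w using a reader goroutine and a
// writer loop that exchange inBufSize-sized buffers over a channel; data is
// written in outBufSize-sized chunks and the total is added to bytesWritten.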
func parallelChunkedCopy(r io.Reader, w io.Writer, inBufSize, outBufSize int64, flags int) error {
	if inBufSize == 0 {
		return fmt.Errorf("inBufSize is not allowed to be zero")
	}

	// Make the channels deep enough to hold a total of 1GiB of data.
	depth := (1024 * 1024 * 1024) / inBufSize
	// But keep it reasonable!
	if depth > 8192 {
		depth = 8192
	}
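	// (With the default 512-byte ibs the uncapped depth would be 2^21
	// buffers; the cap limits the queue to 8192 buffers, i.e. 4 MiB of
	// buffered data at that block size.)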

	readyBufs := make(chan intermediateBuffer, depth)
	pool := newBufferPool(depth, func() intermediateBuffer {
		return newChunkedBuffer(inBufSize, outBufSize, flags)
	})
	defer pool.Destroy()

	// Closing quit makes the reader goroutine below exit.
	quit := make(chan struct{})

	// errs contains the error value to be returned.
	errs := make(chan error, 1)
	defer close(errs)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Closing this unblocks the writing for-loop below.
		defer close(readyBufs)
		defer wg.Done()

		for {
			select {
			case <-quit:
				return
			default:
				buf := pool.Get()
				n, err := buf.ReadFrom(r)
				if n > 0 {
					// Sending can block once readyBufs fills up; also watch
					// quit so that a write error does not leave this
					// goroutine stuck and wg.Wait deadlocked.
					select {
					case readyBufs <- buf:
					case <-quit:
						return
					}
				}
				if err == io.EOF {
					return
				}
				if n == 0 || err != nil {
					errs <- fmt.Errorf("input error: %v", err)
					return
				}
			}
		}
	}()

	var writeErr error
	for buf := range readyBufs {
		if n, err := buf.WriteTo(w); err != nil {
			writeErr = fmt.Errorf("output error: %v", err)
			break
		} else {
			atomic.AddInt64(&bytesWritten, n)
		}
		pool.Put(buf)
	}

	// This will force the goroutine to quit if an error occurred writing.
	close(quit)

	// Wait for goroutine to exit.
	wg.Wait()

	select {
	case readErr := <-errs:
		return readErr
	default:
		return writeErr
	}
}

// sectionReader implements a SectionReader on an underlying implementation of
// io.Reader (as opposed to io.SectionReader which uses io.ReaderAt).
type sectionReader struct {
	base   int64
	offset int64
	limit  int64
	io.Reader
}

// newStreamSectionReader uses an io.Reader to implement an io.Reader that
// seeks to offset and reads at most n bytes.
//
// This is useful if you want to use a NewSectionReader with stdin or other
// types of pipes (things that can't be seek'd or pread from).
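//
// For example, newStreamSectionReader(os.Stdin, 1024, 512) discards the
// first 1024 bytes read from stdin and then returns at most 512 bytes.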
func newStreamSectionReader(r io.Reader, offset int64, n int64) io.Reader {
	limit := offset + n
	if limit < 0 {
		limit = math.MaxInt64
	}
	return &sectionReader{offset, 0, limit, r}
}

// Read implements io.Reader.
func (s *sectionReader) Read(p []byte) (int, error) {
	if s.offset == 0 && s.base != 0 {
		if n, err := io.CopyN(io.Discard, s.Reader, s.base); err != nil {
			return 0, err
		} else if n != s.base {
			// Can't happen: CopyN returns an error on a short copy.
			return 0, fmt.Errorf("error skipping input bytes, short read")
		}
		s.offset = s.base
	}

	if s.offset >= s.limit {
		return 0, io.EOF
	}

	if max := s.limit - s.offset; int64(len(p)) > max {
		p = p[0:max]
	}

	n, err := s.Reader.Read(p)
	s.offset += int64(n)

	// Convert to io.EOF explicitly.
	if n == 0 && err == nil {
		return 0, io.EOF
	}
	return n, err
}

// inFile opens the input file and seeks to the right position.
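//
// For example, with inputBytes=512, skip=2 and count=4, reading starts
// 1024 bytes into the input and covers at most 2048 bytes.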
func inFile(name string, inputBytes int64, skip int64, count int64) (io.Reader, error) {
	maxRead := int64(math.MaxInt64)
	if count != math.MaxInt64 {
		maxRead = count * inputBytes
	}

	if name == "" {
		// os.Stdin is an io.ReaderAt, but you can't actually call
		// pread(2) on it, so use the copying section reader.
		return newStreamSectionReader(os.Stdin, inputBytes*skip, maxRead), nil
	}

	in, err := os.Open(name)
	if err != nil {
		return nil, fmt.Errorf("error opening input file %q: %v", name, err)
	}
	return io.NewSectionReader(in, inputBytes*skip, maxRead), nil
}

// outFile opens the output file and seeks to the right position.
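//
// For example, with outputBytes=512 and seek=3, the writer is positioned
// 1536 bytes past its current offset before any data is written.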
func outFile(name string, outputBytes int64, seek int64, flags int) (io.Writer, error) {
	var out io.WriteSeeker
	var err error
	if name == "" {
		out = os.Stdout
	} else {
		perm := os.O_CREATE | os.O_WRONLY | (flags & allowedFlags)
		if out, err = os.OpenFile(name, perm, 0o666); err != nil {
			return nil, fmt.Errorf("error opening output file %q: %v", name, err)
		}
	}
	if seek*outputBytes != 0 {
		if _, err := out.Seek(seek*outputBytes, io.SeekCurrent); err != nil {
			return nil, fmt.Errorf("error seeking output file: %v", err)
		}
	}
	return out, nil
}

func usage() {
	log.Fatal(`Usage: dd [if=file] [of=file] [conv=none|notrunc] [seek=#] [skip=#]
			     [count=#] [bs=#] [ibs=#] [obs=#] [status=none|xfer|progress] [oflag=none|sync|dsync]
		options may also be invoked Go-style as -opt value or -opt=value
		bs, if specified, overrides ibs and obs`)
}

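// convertArgs rewrites classic dd-style key=value arguments into flag syntax;
// for example, []string{"dd", "if=a", "count=3"} becomes
// []string{"dd", "-if", "a", "-count", "3"}.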
func convertArgs(osArgs []string) []string {
	// EVERYTHING in dd follows x=y. So blindly split and convert.
	var args []string
	for _, v := range osArgs {
		l := strings.SplitN(v, "=", 2)

		// We only fix the exact case for x=y.
		if len(l) == 2 {
			l[0] = "-" + l[0]
		}

		args = append(args, l...)
	}
	return args
}

func main() {
	// Rather than recreating all the apparatus of the flag package for the
	// if= style arguments (dup checking, conversion, etc.), we just convert
	// the arguments and then run flag.Parse. Gross, but hey, it works.
	os.Args = convertArgs(os.Args)
	flag.Parse()

	if len(flag.Args()) > 0 {
		usage()
	}

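	// The output file is truncated by default (os.O_TRUNC); conv=notrunc
	// clears that bit and oflag=sync adds os.O_SYNC.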
	// Convert conv argument to bit set.
	flags := os.O_TRUNC
	if *conv != "none" {
		for _, c := range strings.Split(*conv, ",") {
			if v, ok := convMap[c]; ok {
				flags &= ^v.clear
				flags |= v.set
			} else {
				log.Printf("unknown argument conv=%s", c)
				usage()
			}
		}
	}

	// Convert oflag argument to bit set.
	if *oFlag != "none" {
		for _, f := range strings.Split(*oFlag, ",") {
			if v, ok := flagMap[f]; ok {
				flags &= ^v.clear
				flags |= v.set
			} else {
				log.Printf("unknown argument oflag=%s", f)
				usage()
			}
		}
	}

	if *status != "none" && *status != "xfer" && *status != "progress" {
		usage()
	}
	progress := progress.Begin(*status, &bytesWritten)

	// bs = both 'ibs' and 'obs' (IEEE Std 1003.1 - 2013)
	if bs.IsSet {
		ibs = bs
		obs = bs
	}

	in, err := inFile(*inName, ibs.Value, *skip, *count)
	if err != nil {
		log.Fatal(err)
	}
	out, err := outFile(*outName, obs.Value, *seek, flags)
	if err != nil {
		log.Fatal(err)
	}
	if err := parallelChunkedCopy(in, out, ibs.Value, obs.Value, flags); err != nil {
		log.Fatal(err)
	}

	progress.End()
}