github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/core/dd/dd_test.go (about)

     1  // Copyright 2017 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"fmt"
    11  	"io"
    12  	"os"
    13  	"path/filepath"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/mvdan/u-root-coreutils/pkg/testutil"
    19  )
    20  
    21  func TestNewChunkedBuffer(t *testing.T) {
    22  	tests := []struct {
    23  		name         string
    24  		BufferSize   int64
    25  		outChunkSize int64
    26  		flags        int
    27  	}{
    28  		{
    29  			name:         "Empty buffer with length zero",
    30  			BufferSize:   0,
    31  			outChunkSize: 2,
    32  			flags:        0,
    33  		},
    34  		{
    35  			name:         "Normal buffer",
    36  			BufferSize:   16,
    37  			outChunkSize: 2,
    38  			flags:        0,
    39  		},
    40  		{
    41  			name:         "non-zero flag",
    42  			BufferSize:   16,
    43  			outChunkSize: 2,
    44  			flags:        3,
    45  		},
    46  	}
    47  
    48  	for _, tt := range tests {
    49  		t.Run(tt.name, func(t *testing.T) {
    50  			newIntermediateBufferInterface := newChunkedBuffer(tt.BufferSize, tt.outChunkSize, tt.flags)
    51  			newIntermediateBuffer := newIntermediateBufferInterface.(*chunkedBuffer)
    52  
    53  			if (int64(len(newIntermediateBuffer.data)) != tt.BufferSize) || (newIntermediateBuffer.outChunk != tt.outChunkSize) ||
    54  				(newIntermediateBuffer.flags != tt.flags) {
    55  				t.Errorf("Test failed - got: {%v, %v, %v} want {%v, %v, %v}",
    56  					len(newIntermediateBuffer.data), newIntermediateBuffer.outChunk, newIntermediateBuffer.flags,
    57  					tt.BufferSize, tt.outChunkSize, tt.flags)
    58  			}
    59  		})
    60  	}
    61  }
    62  
    63  func TestReadFrom(t *testing.T) {
    64  	tests := []struct {
    65  		name        string
    66  		inputBuffer []byte
    67  		wantError   bool
    68  	}{
    69  		{
    70  			name:        "Read From",
    71  			inputBuffer: []byte("ABC"),
    72  		},
    73  		{
    74  			name:        "Empty Buffer",
    75  			inputBuffer: []byte{},
    76  		},
    77  	}
    78  
    79  	for _, tt := range tests {
    80  		t.Run(tt.name, func(t *testing.T) {
    81  			cBuffer := &chunkedBuffer{
    82  				outChunk: 1,
    83  				length:   0,
    84  				data:     make([]byte, len(tt.inputBuffer)),
    85  				flags:    0,
    86  			}
    87  			readFromBuffer := bytes.NewReader(tt.inputBuffer)
    88  			cBuffer.ReadFrom(readFromBuffer)
    89  
    90  			if !reflect.DeepEqual(cBuffer.data, tt.inputBuffer) {
    91  				t.Errorf("ReadFrom failed. Got: %v - want: %v", cBuffer.data, tt.inputBuffer)
    92  			}
    93  		})
    94  	}
    95  }
    96  
    97  func TestWriteTo(t *testing.T) {
    98  	tests := []struct {
    99  		name          string
   100  		initialBuffer []byte
   101  		wantError     bool
   102  	}{
   103  		{
   104  			name:          "WriteTo",
   105  			initialBuffer: []byte("ABC"),
   106  		},
   107  		{
   108  			name:          "Empty Buffer",
   109  			initialBuffer: []byte{},
   110  		},
   111  		{
   112  			name:          "Bigger Buffer",
   113  			initialBuffer: []byte("This is madness. We need to turn this into happiness."),
   114  		},
   115  	}
   116  
   117  	for _, tt := range tests {
   118  		t.Run(tt.name, func(t *testing.T) {
   119  			cBuffer := &chunkedBuffer{
   120  				outChunk: 16,
   121  				length:   int64(len(tt.initialBuffer)),
   122  				data:     tt.initialBuffer,
   123  				flags:    0,
   124  			}
   125  
   126  			p := make([]byte, 0)
   127  			b := bytes.NewBuffer(p)
   128  			buffer := bufio.NewWriter(b)
   129  
   130  			n, err := cBuffer.WriteTo(buffer)
   131  			if err != nil || n != int64(len(tt.initialBuffer)) {
   132  				t.Errorf("Unable to write to buffer: %v. Wrote %d bytes.", err, n)
   133  			}
   134  
   135  			err = buffer.Flush()
   136  			if err != nil {
   137  				t.Errorf("Unable to flush buffer: %v", err)
   138  			}
   139  
   140  			if !reflect.DeepEqual(b.Bytes(), tt.initialBuffer) {
   141  				t.Errorf("WriteTo failed. Got: %v - want: %v", b.Bytes(), tt.initialBuffer)
   142  			}
   143  		})
   144  	}
   145  }
   146  
   147  func TestParallelChunkedCopy(t *testing.T) {
   148  	tests := []struct {
   149  		name        string
   150  		inputBuffer []byte
   151  		outBufSize  int
   152  		wantError   bool
   153  	}{
   154  		{
   155  			name:        "Copy 8 bytes",
   156  			inputBuffer: []byte("ABCDEFGH"),
   157  			outBufSize:  16,
   158  		},
   159  		{
   160  			name:        "No bytes to copy",
   161  			inputBuffer: []byte{},
   162  			outBufSize:  16,
   163  			wantError:   true,
   164  		},
   165  	}
   166  
   167  	for _, tt := range tests {
   168  		t.Run(tt.name, func(t *testing.T) {
   169  			// We need to set up an output buffer
   170  			outBuf := make([]byte, 0)
   171  
   172  			// Make it a Writer
   173  			b := bytes.NewBuffer(outBuf)
   174  			writeBuf := bufio.NewWriter(b)
   175  
   176  			// Now we need a readbuffer
   177  			readBuf := bytes.NewReader(tt.inputBuffer)
   178  
   179  			err := parallelChunkedCopy(readBuf, writeBuf, int64(len(tt.inputBuffer)), 8, 0)
   180  
   181  			if err != nil && !tt.wantError {
   182  				t.Errorf("parallelChunkedCopy failed with %v", err)
   183  			}
   184  
   185  			err = writeBuf.Flush()
   186  			if err != nil {
   187  				t.Errorf("Unable to flush writeBuffer: %v", err)
   188  			}
   189  
   190  			if !reflect.DeepEqual(b.Bytes(), tt.inputBuffer) {
   191  				t.Errorf("ParallelChunkedCopies are not equal. Got: %v - want: %v", b.Bytes(), tt.inputBuffer)
   192  			}
   193  		})
   194  	}
   195  }
   196  
   197  func TestRead(t *testing.T) {
   198  	tests := []struct {
   199  		name      string
   200  		offset    int64
   201  		maxRead   int64
   202  		expected  []byte
   203  		wantError bool
   204  	}{
   205  		{
   206  			name:     "read one byte from offset 0",
   207  			offset:   0,
   208  			maxRead:  10,
   209  			expected: []byte("A"),
   210  		},
   211  		{
   212  			name:     "read one byte from offset 3",
   213  			offset:   3,
   214  			maxRead:  10,
   215  			expected: []byte("D"),
   216  		},
   217  		{
   218  			name:      "read out of bounds",
   219  			offset:    11,
   220  			maxRead:   10,
   221  			expected:  []byte{},
   222  			wantError: true,
   223  		},
   224  		{
   225  			name:      "Read EOF",
   226  			offset:    0,
   227  			maxRead:   0,
   228  			expected:  []byte{},
   229  			wantError: true,
   230  		},
   231  	}
   232  
   233  	p, cleanup := setupDatafile(t, "datafile")
   234  	defer cleanup()
   235  
   236  	for _, tt := range tests {
   237  		t.Run(tt.name, func(t *testing.T) {
   238  			buffer := make([]byte, len(tt.expected))
   239  
   240  			file, err := os.Open(p)
   241  			if err != nil {
   242  				t.Errorf("Unable to open mock file: %v", err)
   243  			}
   244  
   245  			defer file.Close()
   246  
   247  			reader := &sectionReader{tt.offset, 0, tt.maxRead, file}
   248  			_, err = reader.Read(buffer)
   249  			if err != nil && !tt.wantError {
   250  				t.Errorf("Unable to read from sectionReader: %v", err)
   251  			}
   252  
   253  			if !reflect.DeepEqual(buffer, tt.expected) {
   254  				t.Errorf("Got: %v - Want: %v", buffer, tt.expected)
   255  			}
   256  		})
   257  	}
   258  }
   259  
   260  func TestInFile(t *testing.T) {
   261  	tests := []struct {
   262  		name        string
   263  		filename    string
   264  		outputBytes int64
   265  		seek        int64
   266  		count       int64
   267  		wantErr     bool
   268  	}{
   269  		{
   270  			name:        "Seek to first byte",
   271  			filename:    "datafile",
   272  			outputBytes: 1,
   273  			seek:        0,
   274  			count:       1,
   275  			wantErr:     false,
   276  		},
   277  		{
   278  			name:        "Seek to second byte",
   279  			filename:    "datafile",
   280  			outputBytes: 1,
   281  			seek:        1,
   282  			count:       1,
   283  			wantErr:     false,
   284  		},
   285  		{
   286  			name:        "no filename",
   287  			filename:    "",
   288  			outputBytes: 1,
   289  			seek:        0,
   290  			count:       1,
   291  			wantErr:     false,
   292  		},
   293  		{
   294  			name:        "unknown file",
   295  			filename:    "/something/something",
   296  			outputBytes: 1,
   297  			seek:        0,
   298  			count:       1,
   299  			wantErr:     true,
   300  		},
   301  		{
   302  			name:        "no filename and seek to nowhere",
   303  			filename:    "",
   304  			outputBytes: 8,
   305  			seek:        8,
   306  			count:       1,
   307  			wantErr:     true,
   308  		},
   309  	}
   310  
   311  	p, cleanup := setupDatafile(t, "datafile")
   312  	defer cleanup()
   313  
   314  	for _, tt := range tests {
   315  		t.Run(tt.name, func(t *testing.T) {
   316  			_, err := inFile(p, tt.outputBytes, tt.seek, tt.count)
   317  			if err != nil && !tt.wantErr {
   318  				t.Errorf("outFile failed with %v", err)
   319  			}
   320  		})
   321  	}
   322  }
   323  
   324  func setupDatafile(t *testing.T, name string) (string, func()) {
   325  	t.Helper()
   326  
   327  	testDir := t.TempDir()
   328  	dataFilePath := filepath.Join(testDir, name)
   329  
   330  	if err := os.WriteFile(dataFilePath, []byte("ABCDEFG"), 0o644); err != nil {
   331  		t.Errorf("unable to mockup file: %v", err)
   332  	}
   333  
   334  	return dataFilePath, func() { os.Remove(dataFilePath) }
   335  }
   336  
   337  func TestOutFile(t *testing.T) {
   338  	tests := []struct {
   339  		name        string
   340  		filename    string
   341  		outputBytes int64
   342  		seek        int64
   343  		flags       int
   344  		wantErr     bool
   345  	}{
   346  		{
   347  			name:        "Seek to first byte",
   348  			filename:    "datafile",
   349  			outputBytes: 1,
   350  			seek:        0,
   351  			flags:       0,
   352  			wantErr:     false,
   353  		},
   354  		{
   355  			name:        "Seek to second byte",
   356  			filename:    "datafile",
   357  			outputBytes: 1,
   358  			seek:        1,
   359  			flags:       0,
   360  			wantErr:     false,
   361  		},
   362  		{
   363  			name:        "no filename",
   364  			filename:    "",
   365  			outputBytes: 1,
   366  			seek:        0,
   367  			flags:       0,
   368  			wantErr:     false,
   369  		},
   370  		{
   371  			name:        "unknown file",
   372  			filename:    "/something/something",
   373  			outputBytes: 1,
   374  			seek:        0,
   375  			flags:       0,
   376  			wantErr:     true,
   377  		},
   378  		{
   379  			name:        "no filename and seek to nowhere",
   380  			filename:    "",
   381  			outputBytes: 8,
   382  			seek:        8,
   383  			flags:       0,
   384  			wantErr:     true,
   385  		},
   386  	}
   387  
   388  	p, cleanup := setupDatafile(t, "datafile")
   389  	defer cleanup()
   390  
   391  	for _, tt := range tests {
   392  		t.Run(tt.name, func(t *testing.T) {
   393  			_, err := outFile(p, tt.outputBytes, tt.seek, tt.flags)
   394  			if err != nil && !tt.wantErr {
   395  				t.Errorf("outFile failed with %v", err)
   396  			}
   397  		})
   398  	}
   399  }
   400  
   401  func TestConvertArgs(t *testing.T) {
   402  	tests := []struct {
   403  		name         string
   404  		args         []string
   405  		expectedArgs []string
   406  	}{
   407  		{
   408  			name:         "Empty Args",
   409  			args:         []string{""},
   410  			expectedArgs: []string{""},
   411  		},
   412  		{
   413  			name:         "One Arg",
   414  			args:         []string{"if=somefile"},
   415  			expectedArgs: []string{"-if", "somefile"},
   416  		},
   417  		{
   418  			name:         "Two Args",
   419  			args:         []string{"if=somefile", "conv=none"},
   420  			expectedArgs: []string{"-if", "somefile", "-conv", "none"},
   421  		},
   422  	}
   423  
   424  	for _, tt := range tests {
   425  		t.Run(tt.name, func(t *testing.T) {
   426  			gotArgs := convertArgs(tt.args)
   427  
   428  			if !reflect.DeepEqual(gotArgs, tt.expectedArgs) {
   429  				t.Errorf("Args not equal. Got %v, want %v", gotArgs, tt.expectedArgs)
   430  			}
   431  		})
   432  	}
   433  }
   434  
   435  // TestDd implements a table-driven test.
   436  func TestDd(t *testing.T) {
   437  	tests := []struct {
   438  		name    string
   439  		flags   []string
   440  		stdin   string
   441  		stdout  []byte
   442  		count   int64
   443  		compare func(io.Reader, []byte, int64) error
   444  	}{
   445  		{
   446  			name:    "Simple copying from input to output",
   447  			flags:   []string{},
   448  			stdin:   "1: defaults",
   449  			stdout:  []byte("1: defaults"),
   450  			compare: stdoutEqual,
   451  		},
   452  		{
   453  			name:    "Copy from input to output on a non-aligned block size",
   454  			flags:   []string{"bs=8c"},
   455  			stdin:   "2: bs=8c 11b", // len=12 is not multiple of 8
   456  			stdout:  []byte("2: bs=8c 11b"),
   457  			compare: stdoutEqual,
   458  		},
   459  		{
   460  			name:    "Copy from input to output on an aligned block size",
   461  			flags:   []string{"bs=8"},
   462  			stdin:   "hello world.....", // len=16 is a multiple of 8
   463  			stdout:  []byte("hello world....."),
   464  			compare: stdoutEqual,
   465  		},
   466  		{
   467  			name:    "Create a 64KiB zeroed file in 1KiB blocks",
   468  			flags:   []string{"if=/dev/zero", "bs=1K", "count=64"},
   469  			stdin:   "",
   470  			stdout:  []byte("\x00"),
   471  			count:   64 * 1024,
   472  			compare: byteCount,
   473  		},
   474  		{
   475  			name:    "Create a 64KiB zeroed file in 1 byte blocks",
   476  			flags:   []string{"if=/dev/zero", "bs=1", "count=65536"},
   477  			stdin:   "",
   478  			stdout:  []byte("\x00"),
   479  			count:   64 * 1024,
   480  			compare: byteCount,
   481  		},
   482  		{
   483  			name:    "Create a 64KiB zeroed file in one 64KiB block",
   484  			flags:   []string{"if=/dev/zero", "bs=64K", "count=1"},
   485  			stdin:   "",
   486  			stdout:  []byte("\x00"),
   487  			count:   64 * 1024,
   488  			compare: byteCount,
   489  		},
   490  		{
   491  			name:    "Use skip and count",
   492  			flags:   []string{"skip=6", "bs=1", "count=5"},
   493  			stdin:   "hello world.....",
   494  			stdout:  []byte("world"),
   495  			compare: stdoutEqual,
   496  		},
   497  		{
   498  			name:    "Count clamps to end of stream",
   499  			flags:   []string{"bs=2", "skip=3", "count=100000"},
   500  			stdin:   "hello world.....",
   501  			stdout:  []byte("world....."),
   502  			compare: stdoutEqual,
   503  		},
   504  		{
   505  			name:    "512 MiB zeroed file in 1024 1KiB blocks",
   506  			flags:   []string{"bs=524288", "count=1024", "if=/dev/zero"},
   507  			stdin:   "",
   508  			stdout:  []byte("\x00"),
   509  			count:   1024 * 1024 * 512,
   510  			compare: byteCount,
   511  		},
   512  	}
   513  
   514  	for _, tt := range tests {
   515  		t.Run(tt.name, func(t *testing.T) {
   516  			cmd := testutil.Command(t, tt.flags...)
   517  			cmd.Stdin = strings.NewReader(tt.stdin)
   518  			out, err := cmd.StdoutPipe()
   519  			if err != nil {
   520  				t.Fatal(err)
   521  			}
   522  			if err := cmd.Start(); err != nil {
   523  				t.Error(err)
   524  			}
   525  			err = tt.compare(out, tt.stdout, tt.count)
   526  			if err != nil {
   527  				t.Errorf("Test compare function returned: %v", err)
   528  			}
   529  			if err := cmd.Wait(); err != nil {
   530  				t.Errorf("Test %v exited with error: %v", tt.flags, err)
   531  			}
   532  		})
   533  	}
   534  }
   535  
   536  // stdoutEqual creates a bufio Reader from io.Reader, then compares a byte at a time input []byte.
   537  // The third argument (int64) is ignored and only exists to make the function signature compatible
   538  // with func byteCount.
   539  // Returns an error if mismatch is found with offset.
   540  func stdoutEqual(i io.Reader, o []byte, _ int64) error {
   541  	var count int64
   542  	b := bufio.NewReader(i)
   543  
   544  	for {
   545  		z, err := b.ReadByte()
   546  		if err != nil {
   547  			break
   548  		}
   549  		if o[count] != z {
   550  			return fmt.Errorf("Found mismatch at offset %d, wanted %s, found %s", count, string(o[count]), string(z))
   551  		}
   552  		count++
   553  	}
   554  	return nil
   555  }
   556  
   557  // byteCount creates a bufio Reader from io.Reader, then counts the number of sequential bytes
   558  // that match the first byte in the input []byte. If the count matches input n int64, nil error
   559  // is returned. Otherwise an error is returned for a non-matching byte or if the count doesn't
   560  // match.
   561  func byteCount(i io.Reader, o []byte, n int64) error {
   562  	var count int64
   563  	buf := make([]byte, 4096)
   564  
   565  	for {
   566  		read, err := i.Read(buf)
   567  		if err != nil || read == 0 {
   568  			break
   569  		}
   570  		for z := 0; z < read; z++ {
   571  			if buf[z] == o[0] {
   572  				count++
   573  			} else {
   574  				return fmt.Errorf("Found non-matching byte: %v != %v, at offset: %d",
   575  					buf[z], o[0], count)
   576  			}
   577  		}
   578  
   579  		if count > n {
   580  			break
   581  		}
   582  	}
   583  
   584  	if count == n {
   585  		return nil
   586  	}
   587  	return fmt.Errorf("Found %d count of %#v bytes, wanted to find %d count", count, o[0], n)
   588  }
   589  
   590  // TestFiles uses `if` and `of` arguments.
   591  func TestFiles(t *testing.T) {
   592  	tests := []struct {
   593  		name     string
   594  		flags    []string
   595  		inFile   []byte
   596  		outFile  []byte
   597  		expected []byte
   598  	}{
   599  		{
   600  			name:     "Simple copying from input to output",
   601  			flags:    []string{},
   602  			inFile:   []byte("1: defaults"),
   603  			expected: []byte("1: defaults"),
   604  		},
   605  		{
   606  			name:     "Copy from input to output on a non-aligned block size",
   607  			flags:    []string{"bs=8c"},
   608  			inFile:   []byte("2: bs=8c 11b"), // len=12 is not multiple of 8
   609  			expected: []byte("2: bs=8c 11b"),
   610  		},
   611  		{
   612  			name:     "Copy from input to output on an aligned block size",
   613  			flags:    []string{"bs=8"},
   614  			inFile:   []byte("hello world....."), // len=16 is a multiple of 8
   615  			expected: []byte("hello world....."),
   616  		},
   617  		{
   618  			name:     "Use skip and count",
   619  			flags:    []string{"skip=6", "bs=1", "count=5"},
   620  			inFile:   []byte("hello world....."),
   621  			expected: []byte("world"),
   622  		},
   623  		{
   624  			name:     "truncate",
   625  			flags:    []string{"bs=1"},
   626  			inFile:   []byte("1234"),
   627  			outFile:  []byte("abcde"),
   628  			expected: []byte("1234"),
   629  		},
   630  		{
   631  			name:     "no truncate",
   632  			flags:    []string{"bs=1", "conv=notrunc"},
   633  			inFile:   []byte("1234"),
   634  			outFile:  []byte("abcde"),
   635  			expected: []byte("1234e"),
   636  		},
   637  		{
   638  			// Fully testing the file is synchronous would require something more.
   639  			name:     "sync",
   640  			flags:    []string{"oflag=sync"},
   641  			inFile:   []byte("x: defaults"),
   642  			expected: []byte("x: defaults"),
   643  		},
   644  		{
   645  			// Fully testing the file is synchronous would require something more.
   646  			name:     "dsync",
   647  			flags:    []string{"oflag=dsync"},
   648  			inFile:   []byte("y: defaults"),
   649  			expected: []byte("y: defaults"),
   650  		},
   651  	}
   652  
   653  	for _, tt := range tests {
   654  		t.Run(tt.name, func(t *testing.T) {
   655  			// Write in and out file to temporary dir.
   656  			tmpDir := t.TempDir()
   657  			inFile := filepath.Join(tmpDir, "inFile")
   658  			outFile := filepath.Join(tmpDir, "outFile")
   659  			if err := os.WriteFile(inFile, tt.inFile, 0o666); err != nil {
   660  				t.Error(err)
   661  			}
   662  			if err := os.WriteFile(outFile, tt.outFile, 0o666); err != nil {
   663  				t.Error(err)
   664  			}
   665  
   666  			args := append(tt.flags, "if="+inFile, "of="+outFile)
   667  			if err := testutil.Command(t, args...).Run(); err != nil {
   668  				t.Error(err)
   669  			}
   670  			got, err := os.ReadFile(filepath.Join(tmpDir, "outFile"))
   671  			if err != nil {
   672  				t.Error(err)
   673  			}
   674  			if !reflect.DeepEqual(tt.expected, got) {
   675  				t.Errorf("expected %q, got %q", tt.expected, got)
   676  			}
   677  		})
   678  	}
   679  }
   680  
   681  // BenchmarkDd benchmarks the dd command. Each "op" unit is a 1MiB block.
   682  func BenchmarkDd(b *testing.B) {
   683  	const bytesPerOp = 1024 * 1024
   684  	b.SetBytes(bytesPerOp)
   685  
   686  	args := []string{
   687  		"if=/dev/zero",
   688  		"of=/dev/null",
   689  		fmt.Sprintf("count=%d", b.N),
   690  		fmt.Sprintf("bs=%d", bytesPerOp),
   691  	}
   692  	b.ResetTimer()
   693  	if err := testutil.Command(b, args...).Run(); err != nil {
   694  		b.Fatal(err)
   695  	}
   696  }
   697  
   698  func TestMain(m *testing.M) {
   699  	testutil.Run(m, main)
   700  }