github.com/Prakhar-Agarwal-byte/moby@v0.0.0-20231027092010-a14e3e8ab87e/pkg/stdcopy/stdcopy_test.go (about)

     1  package stdcopy // import "github.com/Prakhar-Agarwal-byte/moby/pkg/stdcopy"
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"strings"
     8  	"testing"
     9  )
    10  
    11  func TestNewStdWriter(t *testing.T) {
    12  	writer := NewStdWriter(io.Discard, Stdout)
    13  	if writer == nil {
    14  		t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
    15  	}
    16  }
    17  
    18  func TestWriteWithUninitializedStdWriter(t *testing.T) {
    19  	writer := stdWriter{
    20  		Writer: nil,
    21  		prefix: byte(Stdout),
    22  	}
    23  	n, err := writer.Write([]byte("Something here"))
    24  	if n != 0 || err == nil {
    25  		t.Fatalf("Should fail when given an incomplete or uninitialized StdWriter")
    26  	}
    27  }
    28  
    29  func TestWriteWithNilBytes(t *testing.T) {
    30  	writer := NewStdWriter(io.Discard, Stdout)
    31  	n, err := writer.Write(nil)
    32  	if err != nil {
    33  		t.Fatalf("Shouldn't have fail when given no data")
    34  	}
    35  	if n > 0 {
    36  		t.Fatalf("Write should have written 0 byte, but has written %d", n)
    37  	}
    38  }
    39  
    40  func TestWrite(t *testing.T) {
    41  	writer := NewStdWriter(io.Discard, Stdout)
    42  	data := []byte("Test StdWrite.Write")
    43  	n, err := writer.Write(data)
    44  	if err != nil {
    45  		t.Fatalf("Error while writing with StdWrite")
    46  	}
    47  	if n != len(data) {
    48  		t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
    49  	}
    50  }
    51  
    52  type errWriter struct {
    53  	n   int
    54  	err error
    55  }
    56  
    57  func (f *errWriter) Write(buf []byte) (int, error) {
    58  	return f.n, f.err
    59  }
    60  
    61  func TestWriteWithWriterError(t *testing.T) {
    62  	expectedError := errors.New("expected")
    63  	expectedReturnedBytes := 10
    64  	writer := NewStdWriter(&errWriter{
    65  		n:   stdWriterPrefixLen + expectedReturnedBytes,
    66  		err: expectedError,
    67  	}, Stdout)
    68  	data := []byte("This won't get written, sigh")
    69  	n, err := writer.Write(data)
    70  	if err != expectedError {
    71  		t.Fatalf("Didn't get expected error.")
    72  	}
    73  	if n != expectedReturnedBytes {
    74  		t.Fatalf("Didn't get expected written bytes %d, got %d.",
    75  			expectedReturnedBytes, n)
    76  	}
    77  }
    78  
    79  func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
    80  	writer := NewStdWriter(&errWriter{n: -1}, Stdout)
    81  	data := []byte("This won't get written, sigh")
    82  	actual, _ := writer.Write(data)
    83  	if actual != 0 {
    84  		t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
    85  	}
    86  }
    87  
    88  func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
    89  	buffer = new(bytes.Buffer)
    90  	dstOut := NewStdWriter(buffer, Stdout)
    91  	_, err = dstOut.Write(stdOutBytes)
    92  	if err != nil {
    93  		return
    94  	}
    95  	dstErr := NewStdWriter(buffer, Stderr)
    96  	_, err = dstErr.Write(stdErrBytes)
    97  	return
    98  }
    99  
   100  func TestStdCopyWriteAndRead(t *testing.T) {
   101  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   102  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   103  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  	written, err := StdCopy(io.Discard, io.Discard, buffer)
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
   112  	if written != int64(expectedTotalWritten) {
   113  		t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
   114  	}
   115  }
   116  
   117  type customReader struct {
   118  	n            int
   119  	err          error
   120  	totalCalls   int
   121  	correctCalls int
   122  	src          *bytes.Buffer
   123  }
   124  
   125  func (f *customReader) Read(buf []byte) (int, error) {
   126  	f.totalCalls++
   127  	if f.totalCalls <= f.correctCalls {
   128  		return f.src.Read(buf)
   129  	}
   130  	return f.n, f.err
   131  }
   132  
   133  func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
   134  	expectedError := errors.New("error")
   135  	reader := &customReader{
   136  		err: expectedError,
   137  	}
   138  	written, err := StdCopy(io.Discard, io.Discard, reader)
   139  	if written != 0 {
   140  		t.Fatalf("Expected 0 bytes read, got %d", written)
   141  	}
   142  	if err != expectedError {
   143  		t.Fatalf("Didn't get expected error")
   144  	}
   145  }
   146  
   147  func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
   148  	expectedError := errors.New("error")
   149  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   150  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   151  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	reader := &customReader{
   156  		correctCalls: 1,
   157  		n:            stdWriterPrefixLen + 1,
   158  		err:          expectedError,
   159  		src:          buffer,
   160  	}
   161  	written, err := StdCopy(io.Discard, io.Discard, reader)
   162  	if written != 0 {
   163  		t.Fatalf("Expected 0 bytes read, got %d", written)
   164  	}
   165  	if err != expectedError {
   166  		t.Fatalf("Didn't get expected error")
   167  	}
   168  }
   169  
   170  func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
   171  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   172  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   173  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  	reader := &customReader{
   178  		correctCalls: 1,
   179  		n:            stdWriterPrefixLen + 1,
   180  		err:          io.EOF,
   181  		src:          buffer,
   182  	}
   183  	written, err := StdCopy(io.Discard, io.Discard, reader)
   184  	if written != startingBufLen {
   185  		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
   186  	}
   187  	if err != nil {
   188  		t.Fatal("Didn't get nil error")
   189  	}
   190  }
   191  
   192  func TestStdCopyWithInvalidInputHeader(t *testing.T) {
   193  	dstOut := NewStdWriter(io.Discard, Stdout)
   194  	dstErr := NewStdWriter(io.Discard, Stderr)
   195  	src := strings.NewReader("Invalid input")
   196  	_, err := StdCopy(dstOut, dstErr, src)
   197  	if err == nil {
   198  		t.Fatal("StdCopy with invalid input header should fail.")
   199  	}
   200  }
   201  
   202  func TestStdCopyWithCorruptedPrefix(t *testing.T) {
   203  	data := []byte{0x01, 0x02, 0x03}
   204  	src := bytes.NewReader(data)
   205  	written, err := StdCopy(nil, nil, src)
   206  	if err != nil {
   207  		t.Fatalf("StdCopy should not return an error with corrupted prefix.")
   208  	}
   209  	if written != 0 {
   210  		t.Fatalf("StdCopy should have written 0, but has written %d", written)
   211  	}
   212  }
   213  
   214  func TestStdCopyReturnsWriteErrors(t *testing.T) {
   215  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   216  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   217  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   218  	if err != nil {
   219  		t.Fatal(err)
   220  	}
   221  	expectedError := errors.New("expected")
   222  
   223  	dstOut := &errWriter{err: expectedError}
   224  
   225  	written, err := StdCopy(dstOut, io.Discard, buffer)
   226  	if written != 0 {
   227  		t.Fatalf("StdCopy should have written 0, but has written %d", written)
   228  	}
   229  	if err != expectedError {
   230  		t.Fatalf("Didn't get expected error, got %v", err)
   231  	}
   232  }
   233  
   234  func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
   235  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   236  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   237  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	dstOut := &errWriter{n: startingBufLen - 10}
   242  
   243  	written, err := StdCopy(dstOut, io.Discard, buffer)
   244  	if written != 0 {
   245  		t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
   246  	}
   247  	if err != io.ErrShortWrite {
   248  		t.Fatalf("Didn't get expected io.ErrShortWrite error")
   249  	}
   250  }
   251  
   252  // TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
   253  // error, when that error is muxed into the Systemerr stream.
   254  func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
   255  	// write in the basic messages, just so there's some fluff in there
   256  	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
   257  	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
   258  	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
   259  	if err != nil {
   260  		t.Fatal(err)
   261  	}
   262  	// add in an error message on the Systemerr stream
   263  	systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
   264  	systemWriter := NewStdWriter(buffer, Systemerr)
   265  	_, err = systemWriter.Write(systemErrBytes)
   266  	if err != nil {
   267  		t.Fatal(err)
   268  	}
   269  
   270  	// now copy and demux. we should expect an error containing the string we
   271  	// wrote out
   272  	_, err = StdCopy(io.Discard, io.Discard, buffer)
   273  	if err == nil {
   274  		t.Fatal("expected error, got none")
   275  	}
   276  	if !strings.Contains(err.Error(), string(systemErrBytes)) {
   277  		t.Fatal("expected error to contain message")
   278  	}
   279  }
   280  
   281  func BenchmarkWrite(b *testing.B) {
   282  	w := NewStdWriter(io.Discard, Stdout)
   283  	data := []byte("Test line for testing stdwriter performance\n")
   284  	data = bytes.Repeat(data, 100)
   285  	b.SetBytes(int64(len(data)))
   286  	b.ResetTimer()
   287  	for i := 0; i < b.N; i++ {
   288  		if _, err := w.Write(data); err != nil {
   289  			b.Fatal(err)
   290  		}
   291  	}
   292  }