go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/iotools/countingwriter_test.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package iotools
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"io"
    21  	"testing"
    22  
    23  	. "github.com/smartystreets/goconvey/convey"
    24  )
    25  
    26  // testWriter is an io.Writer and io.ByteWriter implementation that always
    27  // writes the full amount and returns the configured error.
    28  type testWriter struct {
    29  	buf             bytes.Buffer
    30  	writeByteCalled bool
    31  	err             error
    32  }
    33  
    34  func (w *testWriter) Write(buf []byte) (int, error) {
    35  	amt, _ := w.buf.Write(buf)
    36  	return amt, w.err
    37  }
    38  
    39  func (w *testWriter) WriteByte(b byte) error {
    40  	w.writeByteCalled = true
    41  
    42  	if err := w.err; err != nil {
    43  		return err
    44  	}
    45  	return w.buf.WriteByte(b)
    46  }
    47  
    48  type notAByteWriter struct {
    49  	inner io.Writer
    50  }
    51  
    52  func (w *notAByteWriter) Write(buf []byte) (int, error) {
    53  	return w.inner.Write(buf)
    54  }
    55  
    56  func TestCountingWriter(t *testing.T) {
    57  	t.Parallel()
    58  
    59  	Convey(`A CountingWriter backed by a test writer`, t, func() {
    60  		tw := testWriter{}
    61  		cw := CountingWriter{Writer: &tw}
    62  
    63  		Convey(`When writing 10 bytes of data, registers a count of 10.`, func() {
    64  			data := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    65  
    66  			amount, err := cw.Write(data)
    67  			So(err, ShouldBeNil)
    68  			So(amount, ShouldEqual, 10)
    69  
    70  			So(tw.buf.Bytes(), ShouldResemble, data)
    71  			So(cw.Count, ShouldEqual, 10)
    72  		})
    73  
    74  		Convey(`When using 32 sequential WriteByte, uses underlying WriteByte and registers a count of 32.`, func() {
    75  			written := bytes.Buffer{}
    76  
    77  			for i := 0; i < 32; i++ {
    78  				So(cw.WriteByte(byte(i)), ShouldBeNil)
    79  				So(cw.Count, ShouldEqual, i+1)
    80  
    81  				// Record for bulk comparison.
    82  				written.WriteByte(byte(i))
    83  			}
    84  
    85  			So(tw.buf.Bytes(), ShouldResemble, written.Bytes())
    86  			So(cw.Count, ShouldEqual, 32)
    87  			So(tw.writeByteCalled, ShouldBeTrue)
    88  		})
    89  
    90  		Convey(`When an error is returned in Write, the error is propagated.`, func() {
    91  			tw.err = errors.New("test error")
    92  			data := []byte{0, 1, 2, 3}
    93  
    94  			amount, err := cw.Write(data)
    95  			So(amount, ShouldEqual, len(data))
    96  			So(err, ShouldEqual, tw.err)
    97  			So(tw.buf.Bytes(), ShouldResemble, data)
    98  			So(cw.Count, ShouldEqual, len(data))
    99  		})
   100  
   101  		Convey(`When an error is returned in WriteByte, the error is propagated.`, func() {
   102  			tw.err = errors.New("test error")
   103  
   104  			err := cw.WriteByte(0x55)
   105  			So(err, ShouldEqual, tw.err)
   106  			So(tw.buf.Bytes(), ShouldHaveLength, 0)
   107  			So(cw.Count, ShouldEqual, 0)
   108  			So(tw.writeByteCalled, ShouldBeTrue)
   109  		})
   110  
   111  		Convey(`When WriteByte is disabled`, func() {
   112  			cw.Writer = &notAByteWriter{&tw}
   113  
   114  			Convey(`WriteByte calls the underlying Write and propagates test error.`, func() {
   115  				tw.err = errors.New("test error")
   116  
   117  				err := cw.WriteByte(0x55)
   118  				So(err, ShouldEqual, tw.err)
   119  				So(tw.buf.Bytes(), ShouldResemble, []byte{0x55})
   120  				So(cw.Count, ShouldEqual, 1)
   121  				So(tw.writeByteCalled, ShouldBeFalse)
   122  			})
   123  		})
   124  	})
   125  }