github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/compress/rw_test.go (about)

     1  package compress_test
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"math/rand"
    10  	"os"
    11  	"os/exec"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/Schaudge/grailbase/compress"
    16  	"github.com/grailbio/testutil/assert"
    17  	"github.com/klauspost/compress/zstd"
    18  )
    19  
    20  func testReader(t *testing.T, plaintext string, comp func(t *testing.T, in []byte) []byte) {
    21  	compressed := comp(t, []byte(plaintext))
    22  	cr := bytes.NewReader(compressed)
    23  	r, n := compress.NewReader(cr)
    24  	assert.True(t, n)
    25  	assert.NotNil(t, r)
    26  	got := bytes.Buffer{}
    27  	_, err := io.Copy(&got, r)
    28  	assert.NoError(t, err)
    29  	assert.NoError(t, r.Close())
    30  	assert.EQ(t, got.String(), plaintext)
    31  }
    32  
    33  // Generate a random ASCII text.
    34  func randomText(buf *strings.Builder, r *rand.Rand, n int) {
    35  	for i := 0; i < n; i++ {
    36  		buf.WriteByte(byte(r.Intn(96) + 32))
    37  	}
    38  }
    39  
    40  func gzipCompress(t *testing.T, in []byte) []byte {
    41  	buf := bytes.Buffer{}
    42  	w := gzip.NewWriter(&buf)
    43  	_, err := io.Copy(w, bytes.NewReader(in))
    44  	assert.NoError(t, err)
    45  	assert.NoError(t, w.Close())
    46  	return buf.Bytes()
    47  }
    48  
    49  func bzip2Compress(t *testing.T, in []byte) []byte {
    50  	temp, err := ioutil.TempFile("", "test")
    51  	assert.NoError(t, err)
    52  	_, err = temp.Write(in)
    53  	assert.NoError(t, err)
    54  	assert.NoError(t, temp.Close())
    55  	cmd := exec.Command("bzip2", temp.Name())
    56  	assert.NoError(t, cmd.Run())
    57  
    58  	compressed, err := ioutil.ReadFile(temp.Name() + ".bz2")
    59  	assert.NoError(t, err)
    60  	assert.NoError(t, os.Remove(temp.Name()+".bz2"))
    61  	return compressed
    62  }
    63  
    64  func zstdCompress(t *testing.T, in []byte) []byte {
    65  	buf := bytes.Buffer{}
    66  	// WithZeroFrames ensures that a zero-length input (like in TestReaderSmall) yields
    67  	// a non-empty output with a header that compress.NewReader can sniff.
    68  	w, err := zstd.NewWriter(&buf, zstd.WithZeroFrames(true))
    69  	assert.NoError(t, err)
    70  	_, err = io.Copy(w, bytes.NewReader(in))
    71  	assert.NoError(t, err)
    72  	assert.NoError(t, w.Close())
    73  	return buf.Bytes()
    74  }
    75  
    76  type compressor struct {
    77  	fn  func(t *testing.T, in []byte) []byte
    78  	ext string
    79  }
    80  
    81  var compressors = []compressor{
    82  	{zstdCompress, "zst"},
    83  	{gzipCompress, "gz"},
    84  	{bzip2Compress, "bz2"},
    85  }
    86  
    87  func TestReaderSmall(t *testing.T) {
    88  	for _, c := range compressors {
    89  		t.Run(c.ext, func(t *testing.T) {
    90  			testReader(t, "", c.fn)
    91  			testReader(t, "hello", c.fn)
    92  		})
    93  		n := 1
    94  		for i := 1; i < 25; i++ {
    95  			t.Run(fmt.Sprint("format=", c.ext, ",n=", n), func(t *testing.T) {
    96  				r := rand.New(rand.NewSource(int64(i)))
    97  				n = (n + 1) * 3 / 2
    98  				buf := strings.Builder{}
    99  				randomText(&buf, r, n)
   100  				testReader(t, buf.String(), c.fn)
   101  			})
   102  		}
   103  	}
   104  }
   105  
   106  func TestGzipReaderUncompressed(t *testing.T) {
   107  	data := make([]byte, 128<<10+1)
   108  	got := bytes.Buffer{}
   109  
   110  	runTest := func(t *testing.T, n int) {
   111  		for i := range data[:n] {
   112  			// gzip/bzip2 header contains at least one char > 128, so the plaintext should
   113  			// never be conflated with a gzip header.
   114  			data[i] = byte(n + i%128)
   115  		}
   116  		cr := bytes.NewReader(data[:n])
   117  		r, compressed := compress.NewReader(cr)
   118  		assert.False(t, compressed)
   119  		got.Reset()
   120  		nRead, err := io.Copy(&got, r)
   121  		assert.NoError(t, err)
   122  		assert.EQ(t, int(nRead), n)
   123  		assert.NoError(t, r.Close())
   124  		assert.EQ(t, got.Bytes(), data[:n])
   125  	}
   126  
   127  	dataSize := 1
   128  	for dataSize <= len(data) {
   129  		n := dataSize
   130  		t.Run(fmt.Sprint(n), func(t *testing.T) { runTest(t, n) })
   131  		t.Run(fmt.Sprint(n-1), func(t *testing.T) { runTest(t, n-1) })
   132  		t.Run(fmt.Sprint(n+1), func(t *testing.T) { runTest(t, n+1) })
   133  		dataSize *= 2
   134  	}
   135  }
   136  
   137  func TestReaderWriterPath(t *testing.T) {
   138  	for _, c := range compressors {
   139  		t.Run(c.ext, func(t *testing.T) {
   140  			if c.ext == "bz2" { // bz2 compression not yet supported
   141  				t.Skip("bz2")
   142  			}
   143  			buf := bytes.Buffer{}
   144  			w, compressed := compress.NewWriterPath(&buf, "foo."+c.ext)
   145  			assert.True(t, compressed)
   146  			_, err := io.WriteString(w, "hello")
   147  			assert.NoError(t, w.Close())
   148  			assert.NoError(t, err)
   149  
   150  			r, compressed := compress.NewReaderPath(&buf, "foo."+c.ext)
   151  			assert.True(t, compressed)
   152  			data, err := ioutil.ReadAll(r)
   153  			assert.NoError(t, err)
   154  			assert.EQ(t, string(data), "hello")
   155  			assert.NoError(t, r.Close())
   156  		})
   157  	}
   158  }
   159  
   160  // NewReaderPath and NewWriterPath for non-compressed extensions.
   161  func TestReaderWriterPathNop(t *testing.T) {
   162  	buf := bytes.Buffer{}
   163  	w, compressed := compress.NewWriterPath(&buf, "foo.txt")
   164  	assert.False(t, compressed)
   165  	_, err := io.WriteString(w, "hello")
   166  	assert.NoError(t, w.Close())
   167  	assert.NoError(t, err)
   168  
   169  	r, compressed := compress.NewReaderPath(&buf, "foo.txt")
   170  	assert.False(t, compressed)
   171  	data, err := ioutil.ReadAll(r)
   172  	assert.NoError(t, err)
   173  	assert.EQ(t, string(data), "hello")
   174  	assert.NoError(t, r.Close())
   175  }