trpc.group/trpc-go/trpc-go@v1.0.3/restful/compressor_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful_test
    15  
    16  import (
    17  	"bytes"
    18  	"errors"
    19  	"io"
    20  	"reflect"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"trpc.group/trpc-go/trpc-go/restful"
    26  )
    27  
    28  type anonymousCompressor struct {
    29  	restful.Compressor
    30  }
    31  
    32  func (anonymousCompressor) Name() string { return "" }
    33  
    34  type mockCompressor struct {
    35  	restful.Compressor
    36  }
    37  
    38  func (mockCompressor) Name() string            { return "mock" }
    39  func (mockCompressor) ContentEncoding() string { return "mock" }
    40  
    41  type reader struct {
    42  	io.Reader
    43  }
    44  
    45  func (reader) Read([]byte) (int, error) {
    46  	return 0, errors.New("mock error")
    47  }
    48  
    49  func TestRegisterCompressor(t *testing.T) {
    50  	for _, test := range []struct {
    51  		compressor  restful.Compressor
    52  		expectPanic bool
    53  		desc        string
    54  	}{
    55  		{
    56  			compressor:  nil,
    57  			expectPanic: true,
    58  			desc:        "register nil compressor test",
    59  		},
    60  		{
    61  			compressor:  anonymousCompressor{},
    62  			expectPanic: true,
    63  			desc:        "register anonymous compressor test",
    64  		},
    65  		{
    66  			compressor:  mockCompressor{},
    67  			expectPanic: false,
    68  			desc:        "register mock compressor test",
    69  		},
    70  	} {
    71  		register := func() { restful.RegisterCompressor(test.compressor) }
    72  		if test.expectPanic {
    73  			require.Panics(t, register, test.desc)
    74  		} else {
    75  			require.NotPanics(t, register, test.desc)
    76  		}
    77  		var c restful.Compressor
    78  		if !test.expectPanic {
    79  			c = restful.GetCompressor(test.compressor.Name())
    80  			require.True(t, reflect.DeepEqual(c, test.compressor), test.desc)
    81  		}
    82  	}
    83  }
    84  
    85  func TestGZIPCompressor(t *testing.T) {
    86  	g := &restful.GZIPCompressor{}
    87  
    88  	require.Equal(t, "gzip", g.Name())
    89  	require.Equal(t, "gzip", g.ContentEncoding())
    90  
    91  	input := []byte("foobar foo bar baz")
    92  	buf := new(bytes.Buffer)
    93  	w, err := g.Compress(buf)
    94  	require.Nil(t, err)
    95  	_, err = w.Write(input)
    96  	require.Nil(t, err)
    97  	err = w.Close()
    98  	require.Nil(t, err)
    99  	wrong := reader{}
   100  	_, err = g.Decompress(wrong)
   101  	require.NotNil(t, err)
   102  	r, err := g.Decompress(buf)
   103  	require.Nil(t, err)
   104  	out, err := io.ReadAll(r)
   105  	require.Nil(t, err)
   106  	require.Equal(t, input, out)
   107  }