github.com/abemedia/go-don@v0.2.2-0.20240329015135-be88e32bb73b/encoding/encode_test.go (about)

     1  package encoding_test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"testing"
     7  
     8  	"github.com/abemedia/go-don/encoding"
     9  	"github.com/abemedia/go-don/pkg/httptest"
    10  	"github.com/valyala/fasthttp"
    11  )
    12  
    13  func TestRegisterEncoder(t *testing.T) {
    14  	t.Run("Marshaler", func(t *testing.T) {
    15  		testRegisterEncoder(t, func(v any) ([]byte, error) {
    16  			b := v.([]byte)
    17  			if len(b) == 0 {
    18  				return nil, io.EOF
    19  			}
    20  			return b, nil
    21  		}, "unmarshaler", "marshaler-alias")
    22  	})
    23  
    24  	t.Run("ContextMarshaler", func(t *testing.T) {
    25  		testRegisterEncoder(t, func(ctx context.Context, v any) ([]byte, error) {
    26  			b := v.([]byte)
    27  			if len(b) == 0 {
    28  				return nil, io.EOF
    29  			}
    30  			return b, nil
    31  		}, "context-marshaler", "context-marshaler-alias")
    32  	})
    33  
    34  	t.Run("ResponseEncoder", func(t *testing.T) {
    35  		testRegisterEncoder(t, func(ctx *fasthttp.RequestCtx, v any) error {
    36  			b := v.([]byte)
    37  			if len(b) == 0 {
    38  				return io.EOF
    39  			}
    40  			ctx.Response.SetBodyRaw(b)
    41  			return nil
    42  		}, "response-encoder", "response-encoder-alias")
    43  	})
    44  }
    45  
    46  func testRegisterEncoder[T encoding.EncoderConstraint](t *testing.T, dec T, contentType, alias string) {
    47  	t.Helper()
    48  
    49  	encoding.RegisterEncoder(dec, contentType, alias)
    50  
    51  	for _, v := range []string{contentType, alias} {
    52  		encode := encoding.GetEncoder(v)
    53  		if encode == nil {
    54  			t.Error("encoder not found")
    55  			continue
    56  		}
    57  
    58  		req := httptest.NewRequest("", "", v, nil)
    59  
    60  		if err := encode(req, []byte(v)); err != nil {
    61  			t.Error(err)
    62  		} else if string(req.Response.Body()) != v {
    63  			t.Error("should encode response")
    64  		}
    65  
    66  		if err := encode(req, []byte{}); err == nil {
    67  			t.Error("should return error")
    68  		}
    69  	}
    70  }
    71  
    72  func TestGetEncoderMultipleContentTypes(t *testing.T) {
    73  	encFn := func(ctx *fasthttp.RequestCtx, v any) error {
    74  		return nil
    75  	}
    76  
    77  	encoding.RegisterEncoder(encFn, "application/xml")
    78  
    79  	enc := encoding.GetEncoder("text/html,application/xhtml+xml,application/xml")
    80  	if enc == nil {
    81  		t.Fatal("encoder not found")
    82  	}
    83  
    84  	enc = encoding.GetEncoder("application/xhtml+xml")
    85  	if enc != nil {
    86  		t.Fatal("encoder should not be found")
    87  	}
    88  }