github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/jsoni/extension_tests/extension_test.go (about)

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  	"strconv"
     7  	"testing"
     8  	"unsafe"
     9  
    10  	"github.com/bingoohuang/gg/pkg/jsoni"
    11  	"github.com/modern-go/reflect2"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
// TestObject1 is the fixture struct whose Field1 binding is customized by
// testExtension.UpdateStructDescriptor.
type TestObject1 struct {
	// Field1 is a Go string; testExtension encodes and decodes it as a JSON
	// integer under the key "field-1".
	Field1 string
}
    18  
// testExtension embeds jsoni.DummyExtension and adds its own
// UpdateStructDescriptor, which customizes the Field1 binding of TestObject1.
type testExtension struct {
	jsoni.DummyExtension
}
    22  
    23  func (extension *testExtension) UpdateStructDescriptor(structDescriptor *jsoni.StructDescriptor) {
    24  	if structDescriptor.Type.String() != "test.TestObject1" {
    25  		return
    26  	}
    27  	binding := structDescriptor.GetField("Field1")
    28  	binding.Encoder = &funcEncoder{fun: func(_ context.Context, ptr unsafe.Pointer, stream *jsoni.Stream) {
    29  		str := *((*string)(ptr))
    30  		val, _ := strconv.Atoi(str)
    31  		stream.WriteInt(val)
    32  	}}
    33  	binding.Decoder = &funcDecoder{func(_ context.Context, ptr unsafe.Pointer, iter *jsoni.Iterator) {
    34  		*((*string)(ptr)) = strconv.Itoa(iter.ReadInt())
    35  	}}
    36  	binding.ToNames = []string{"field-1"}
    37  	binding.FromNames = []string{"field-1"}
    38  }
    39  
    40  func Test_customize_field_by_extension(t *testing.T) {
    41  	should := require.New(t)
    42  	cfg := jsoni.Config{}.Froze()
    43  	cfg.RegisterExtension(&testExtension{})
    44  	obj := TestObject1{}
    45  	ctx := context.Background()
    46  	err := cfg.UnmarshalFromString(ctx, `{"field-1": 100}`, &obj)
    47  	should.Nil(err)
    48  	should.Equal("100", obj.Field1)
    49  	str, err := cfg.MarshalToString(ctx, obj)
    50  	should.Nil(err)
    51  	should.Equal(`{"field-1":100}`, str)
    52  }
    53  
    54  func Test_customize_map_key_encoder(t *testing.T) {
    55  	should := require.New(t)
    56  	cfg := jsoni.Config{}.Froze()
    57  	cfg.RegisterExtension(&testMapKeyExtension{})
    58  	m := map[int]int{1: 2}
    59  	ctx := context.Background()
    60  	output, err := cfg.MarshalToString(ctx, m)
    61  	should.NoError(err)
    62  	should.Equal(`{"2":2}`, output)
    63  	m = map[int]int{}
    64  	should.NoError(cfg.UnmarshalFromString(ctx, output, &m))
    65  	should.Equal(map[int]int{1: 2}, m)
    66  }
    67  
// testMapKeyExtension embeds jsoni.DummyExtension and adds its own
// CreateMapKeyEncoder and CreateMapKeyDecoder, which supply custom key codecs
// for int-keyed maps.
type testMapKeyExtension struct {
	jsoni.DummyExtension
}
    71  
    72  func (extension *testMapKeyExtension) CreateMapKeyEncoder(typ reflect2.Type) jsoni.ValEncoder {
    73  	if typ.Kind() == reflect.Int {
    74  		return &funcEncoder{
    75  			fun: func(_ context.Context, ptr unsafe.Pointer, stream *jsoni.Stream) {
    76  				stream.WriteRaw(`"`)
    77  				stream.WriteInt(*(*int)(ptr) + 1)
    78  				stream.WriteRaw(`"`)
    79  			},
    80  		}
    81  	}
    82  	return nil
    83  }
    84  
    85  func (extension *testMapKeyExtension) CreateMapKeyDecoder(typ reflect2.Type) jsoni.ValDecoder {
    86  	if typ.Kind() == reflect.Int {
    87  		return &funcDecoder{
    88  			fun: func(_ context.Context, ptr unsafe.Pointer, iter *jsoni.Iterator) {
    89  				i, err := strconv.Atoi(iter.ReadString())
    90  				if err != nil {
    91  					iter.ReportError("read map key", err.Error())
    92  					return
    93  				}
    94  				i--
    95  				*(*int)(ptr) = i
    96  			},
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
// funcDecoder wraps a jsoni.DecoderFunc in a struct whose Decode method calls
// it; this file returns *funcDecoder where a jsoni.ValDecoder is required.
type funcDecoder struct {
	// fun is the function invoked by Decode.
	fun jsoni.DecoderFunc
}
   105  
// Decode forwards ctx, ptr and iter unchanged to the wrapped decoder function.
func (decoder *funcDecoder) Decode(ctx context.Context, ptr unsafe.Pointer, iter *jsoni.Iterator) {
	decoder.fun(ctx, ptr, iter)
}
   109  
// funcEncoder wraps a jsoni.EncoderFunc and an optional emptiness check in a
// struct with Encode and IsEmpty methods; this file returns *funcEncoder where
// a jsoni.ValEncoder is required.
type funcEncoder struct {
	// fun is the function invoked by Encode.
	fun         jsoni.EncoderFunc
	// isEmptyFunc is optional; when it is nil, IsEmpty reports false.
	isEmptyFunc func(ptr unsafe.Pointer) bool
}
   114  
// Encode forwards ctx, ptr and stream unchanged to the wrapped encoder function.
func (encoder *funcEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *jsoni.Stream) {
	encoder.fun(ctx, ptr, stream)
}
   118  
   119  func (encoder *funcEncoder) IsEmpty(_ context.Context, ptr unsafe.Pointer, _ bool) bool {
   120  	if encoder.isEmptyFunc == nil {
   121  		return false
   122  	}
   123  	return encoder.isEmptyFunc(ptr)
   124  }