github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/jsoni/reflect_marshaler.go (about)

     1  package jsoni
     2  
     3  import (
     4  	"context"
     5  	"encoding"
     6  	"encoding/json"
     7  	"unsafe"
     8  
     9  	"github.com/modern-go/reflect2"
    10  )
    11  
    12  var (
    13  	marshalerType       = PtrElem((*json.Marshaler)(nil))
    14  	unmarshalerType     = PtrElem((*json.Unmarshaler)(nil))
    15  	textMarshalerType   = PtrElem((*encoding.TextMarshaler)(nil))
    16  	textUnmarshalerType = PtrElem((*encoding.TextUnmarshaler)(nil))
    17  
    18  	marshalerContextType   = PtrElem((*MarshalerContext)(nil))
    19  	unmarshalerContextType = PtrElem((*UnmarshalerContext)(nil))
    20  )
    21  
    22  func createDecoderOfMarshaler(_ *ctx, typ reflect2.Type) ValDecoder {
    23  	ptrType := reflect2.PtrTo(typ)
    24  	if ptrType.Implements(unmarshalerType) {
    25  		return &referenceDecoder{decoder: &unmarshalerDecoder{valType: ptrType}}
    26  	}
    27  	if ptrType.Implements(unmarshalerContextType) {
    28  		return &referenceDecoder{decoder: &unmarshalerContextDecoder{valType: ptrType}}
    29  	}
    30  	if ptrType.Implements(textUnmarshalerType) {
    31  		return &referenceDecoder{decoder: &textUnmarshalerDecoder{valType: ptrType}}
    32  	}
    33  	return nil
    34  }
    35  
    36  func createEncoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValEncoder {
    37  	if typ == marshalerType {
    38  		return &directMarshalerEncoder{checkIsEmpty: createCheckIsEmpty(ctx, typ)}
    39  	}
    40  	if typ.Implements(marshalerType) {
    41  		return &marshalerEncoder{valueType: typ, checkIsEmpty: createCheckIsEmpty(ctx, typ)}
    42  	}
    43  	if typ == marshalerContextType {
    44  		return &directMarshalerContextEncoder{checkIsEmpty: createCheckIsEmpty(ctx, typ)}
    45  	}
    46  	if typ.Implements(marshalerContextType) {
    47  		return &marshalerContextEncoder{valueType: typ, checkIsEmpty: createCheckIsEmpty(ctx, typ)}
    48  	}
    49  
    50  	ptrType := reflect2.PtrTo(typ)
    51  	if ptrType.Implements(marshalerType) {
    52  		encoder := &marshalerEncoder{valueType: ptrType, checkIsEmpty: createCheckIsEmpty(ctx, ptrType)}
    53  		return &referenceEncoder{encoder: encoder}
    54  	}
    55  	if ptrType.Implements(marshalerContextType) {
    56  		encoder := &marshalerContextEncoder{valueType: ptrType, checkIsEmpty: createCheckIsEmpty(ctx, ptrType)}
    57  		return &referenceEncoder{encoder: encoder}
    58  	}
    59  
    60  	if typ == textMarshalerType {
    61  		return &directTextMarshalerEncoder{checkIsEmpty: createCheckIsEmpty(ctx, typ), stringEncoder: ctx.EncoderOf(reflect2.TypeOf(""))}
    62  	}
    63  	if typ.Implements(textMarshalerType) {
    64  		return &textMarshalerEncoder{valType: typ, stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), checkIsEmpty: createCheckIsEmpty(ctx, typ)}
    65  	}
    66  	// if prefix is empty, the type is the root type
    67  	if ctx.prefix != "" && ptrType.Implements(textMarshalerType) {
    68  		encoder := &textMarshalerEncoder{valType: ptrType, stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")), checkIsEmpty: createCheckIsEmpty(ctx, ptrType)}
    69  		return &referenceEncoder{encoder: encoder}
    70  	}
    71  	return nil
    72  }
    73  
    74  type marshalerEncoder struct {
    75  	checkIsEmpty checkIsEmpty
    76  	valueType    reflect2.Type
    77  }
    78  
    79  func (e *marshalerEncoder) Encode(_ context.Context, ptr unsafe.Pointer, stream *Stream) {
    80  	obj := e.valueType.UnsafeIndirect(ptr)
    81  	if e.valueType.IsNullable() && reflect2.IsNil(obj) {
    82  		stream.WriteNil()
    83  		return
    84  	}
    85  
    86  	bytes, err := obj.(json.Marshaler).MarshalJSON()
    87  	if err != nil {
    88  		stream.Error = err
    89  		return
    90  	}
    91  	// html escape was already done by jsoniter but the extra '\n' should be trimmed
    92  	if l := len(bytes); l > 0 && bytes[l-1] == '\n' {
    93  		bytes = bytes[:l-1]
    94  	}
    95  	stream.Write(bytes)
    96  }
    97  
    98  func (e *marshalerEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
    99  	return e.checkIsEmpty.IsEmpty(ctx, ptr, checkZero)
   100  }
   101  
   102  type marshalerContextEncoder struct {
   103  	checkIsEmpty checkIsEmpty
   104  	valueType    reflect2.Type
   105  }
   106  
   107  func (e *marshalerContextEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   108  	obj := e.valueType.UnsafeIndirect(ptr)
   109  	if e.valueType.IsNullable() && reflect2.IsNil(obj) {
   110  		stream.WriteNil()
   111  		return
   112  	}
   113  
   114  	bytes, err := obj.(MarshalerContext).MarshalJSONContext(ctx)
   115  	if err != nil {
   116  		stream.Error = err
   117  		return
   118  	}
   119  	// html escape was already done by jsoniter but the extra '\n' should be trimed
   120  	if l := len(bytes); l > 0 && bytes[l-1] == '\n' {
   121  		bytes = bytes[:l-1]
   122  	}
   123  	stream.Write(bytes)
   124  }
   125  
   126  func (e *marshalerContextEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
   127  	return e.checkIsEmpty.IsEmpty(ctx, ptr, checkZero)
   128  }
   129  
   130  type directMarshalerEncoder struct {
   131  	checkIsEmpty checkIsEmpty
   132  }
   133  
   134  func (e *directMarshalerEncoder) Encode(_ context.Context, ptr unsafe.Pointer, stream *Stream) {
   135  	marshaler := *(*json.Marshaler)(ptr)
   136  	if marshaler == nil {
   137  		stream.WriteNil()
   138  		return
   139  	}
   140  	if bytes, err := marshaler.MarshalJSON(); err != nil {
   141  		stream.Error = err
   142  	} else {
   143  		stream.Write(bytes)
   144  	}
   145  }
   146  
   147  func (e *directMarshalerEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
   148  	return e.checkIsEmpty.IsEmpty(ctx, ptr, checkZero)
   149  }
   150  
   151  type directMarshalerContextEncoder struct {
   152  	checkIsEmpty checkIsEmpty
   153  }
   154  
   155  func (e *directMarshalerContextEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   156  	marshaler := *(*MarshalerContext)(ptr)
   157  	if marshaler == nil {
   158  		stream.WriteNil()
   159  		return
   160  	}
   161  	if bytes, err := marshaler.MarshalJSONContext(ctx); err != nil {
   162  		stream.Error = err
   163  	} else {
   164  		stream.Write(bytes)
   165  	}
   166  }
   167  
   168  func (e *directMarshalerContextEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
   169  	return e.checkIsEmpty.IsEmpty(ctx, ptr, checkZero)
   170  }
   171  
   172  type textMarshalerEncoder struct {
   173  	valType       reflect2.Type
   174  	stringEncoder ValEncoder
   175  	checkIsEmpty  checkIsEmpty
   176  }
   177  
   178  func (e *textMarshalerEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   179  	obj := e.valType.UnsafeIndirect(ptr)
   180  	if e.valType.IsNullable() && reflect2.IsNil(obj) {
   181  		stream.WriteNil()
   182  		return
   183  	}
   184  	marshaler := (obj).(encoding.TextMarshaler)
   185  	if bytes, err := marshaler.MarshalText(); err != nil {
   186  		stream.Error = err
   187  	} else {
   188  		str := string(bytes)
   189  		e.stringEncoder.Encode(ctx, unsafe.Pointer(&str), stream)
   190  	}
   191  }
   192  
   193  func (e *textMarshalerEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
   194  	return e.checkIsEmpty.IsEmpty(ctx, ptr, checkZero)
   195  }
   196  
   197  type directTextMarshalerEncoder struct {
   198  	stringEncoder ValEncoder
   199  	checkIsEmpty  checkIsEmpty
   200  }
   201  
   202  func (e *directTextMarshalerEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   203  	marshaler := *(*encoding.TextMarshaler)(ptr)
   204  	if marshaler == nil {
   205  		stream.WriteNil()
   206  		return
   207  	}
   208  	if bytes, err := marshaler.MarshalText(); err != nil {
   209  		stream.Error = err
   210  	} else {
   211  		str := string(bytes)
   212  		e.stringEncoder.Encode(ctx, unsafe.Pointer(&str), stream)
   213  	}
   214  }
   215  
   216  func (e *directTextMarshalerEncoder) IsEmpty(ctx context.Context, p unsafe.Pointer, checkZero bool) bool {
   217  	return e.checkIsEmpty.IsEmpty(ctx, p, checkZero)
   218  }
   219  
   220  type unmarshalerDecoder struct{ valType reflect2.Type }
   221  
   222  func (d *unmarshalerDecoder) Decode(_ context.Context, ptr unsafe.Pointer, iter *Iterator) {
   223  	obj := d.valType.UnsafeIndirect(ptr)
   224  	iter.nextToken()
   225  	iter.unreadByte() // skip spaces
   226  	bytes := iter.SkipAndReturnBytes()
   227  	if err := obj.(json.Unmarshaler).UnmarshalJSON(bytes); err != nil {
   228  		iter.ReportError("unmarshalerDecoder", err.Error())
   229  	}
   230  }
   231  
   232  type unmarshalerContextDecoder struct{ valType reflect2.Type }
   233  
   234  func (d *unmarshalerContextDecoder) Decode(ctx context.Context, ptr unsafe.Pointer, iter *Iterator) {
   235  	obj := d.valType.UnsafeIndirect(ptr)
   236  	iter.nextToken()
   237  	iter.unreadByte() // skip spaces
   238  	bytes := iter.SkipAndReturnBytes()
   239  	if err := obj.(UnmarshalerContext).UnmarshalJSONContext(ctx, bytes); err != nil {
   240  		iter.ReportError("unmarshalerDecoder", err.Error())
   241  	}
   242  }
   243  
   244  type textUnmarshalerDecoder struct {
   245  	valType reflect2.Type
   246  }
   247  
   248  func (d *textUnmarshalerDecoder) Decode(ctx context.Context, ptr unsafe.Pointer, iter *Iterator) {
   249  	valType := d.valType
   250  	obj := valType.UnsafeIndirect(ptr)
   251  	if reflect2.IsNil(obj) {
   252  		ptrType := valType.(*reflect2.UnsafePtrType)
   253  		elemType := ptrType.Elem()
   254  		elem := elemType.UnsafeNew()
   255  		ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem))
   256  		obj = valType.UnsafeIndirect(ptr)
   257  	}
   258  	unmarshaler := (obj).(encoding.TextUnmarshaler)
   259  	str := iter.ReadString()
   260  	if err := unmarshaler.UnmarshalText([]byte(str)); err != nil {
   261  		iter.ReportError("textUnmarshalerDecoder", err.Error())
   262  	}
   263  }