github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/jsoni/reflect_map.go (about)

     1  package jsoni
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"sort"
     9  	"unsafe"
    10  
    11  	"github.com/modern-go/reflect2"
    12  )
    13  
    14  type MapEntryEncoder interface {
    15  	EntryEncoder() (keyEncoder, elemEncoder ValEncoder)
    16  }
    17  
    18  type MapEntryDecoder interface {
    19  	EntryDecoder() (keyDecoder, elemDecoder ValDecoder)
    20  }
    21  
    22  func (e *sortKeysMapEncoder) EntryEncoder() (k, v ValEncoder) { return e.keyEncoder, e.elemEncoder }
    23  func (e *mapEncoder) EntryEncoder() (k, v ValEncoder)         { return e.keyEncoder, e.elemEncoder }
    24  func (d *mapDecoder) EntryDecoder() (k, v ValDecoder)         { return d.keyDecoder, d.elemDecoder }
    25  
    26  func decoderOfMap(ctx *ctx, typ reflect2.Type) ValDecoder {
    27  	mapType := typ.(*reflect2.UnsafeMapType)
    28  	return &mapDecoder{
    29  		mapType:     mapType,
    30  		keyType:     mapType.Key(),
    31  		elemType:    mapType.Elem(),
    32  		keyDecoder:  decoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()),
    33  		elemDecoder: decoderOfType(ctx.append("[mapElem]"), mapType.Elem()),
    34  	}
    35  }
    36  
    37  func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder {
    38  	mapType := typ.(*reflect2.UnsafeMapType)
    39  	keyEncoder := encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key())
    40  	elemEncoder := encoderOfType(ctx.append("[mapElem]"), mapType.Elem())
    41  	encoder := &mapEncoder{mapType: mapType, keyEncoder: keyEncoder, elemEncoder: elemEncoder, omitempty: ctx.omitEmptyMapKeys}
    42  	if ctx.sortMapKeys {
    43  		return &sortKeysMapEncoder{mapEncoder: encoder}
    44  	}
    45  	return encoder
    46  }
    47  
    48  func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
    49  	if decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ); decoder != nil {
    50  		return decoder
    51  	}
    52  	if decoder := ctx.extensions.CreateMapKeyDecoder(typ); decoder != nil {
    53  		return decoder
    54  	}
    55  
    56  	ptrType := reflect2.PtrTo(typ)
    57  	if ptrType.Implements(unmarshalerType) {
    58  		return &referenceDecoder{decoder: &unmarshalerDecoder{valType: ptrType}}
    59  	}
    60  	if ptrType.Implements(unmarshalerContextType) {
    61  		return &referenceDecoder{decoder: &unmarshalerContextDecoder{valType: ptrType}}
    62  	}
    63  	if typ.Implements(unmarshalerType) {
    64  		return &unmarshalerDecoder{valType: typ}
    65  	}
    66  	if typ.Implements(unmarshalerContextType) {
    67  		return &unmarshalerContextDecoder{valType: typ}
    68  	}
    69  	if ptrType.Implements(textUnmarshalerType) {
    70  		return &referenceDecoder{decoder: &textUnmarshalerDecoder{valType: ptrType}}
    71  	}
    72  	if typ.Implements(textUnmarshalerType) {
    73  		return &textUnmarshalerDecoder{valType: typ}
    74  	}
    75  
    76  	switch typ.Kind() {
    77  	case reflect.String:
    78  		return decoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))
    79  	case reflect.Bool,
    80  		reflect.Uint8, reflect.Int8, reflect.Uint16, reflect.Int16, reflect.Uint32, reflect.Int32,
    81  		reflect.Uint64, reflect.Int64, reflect.Uint, reflect.Int,
    82  		reflect.Float32, reflect.Float64, reflect.Uintptr:
    83  		typ = reflect2.DefaultTypeOfKind(typ.Kind())
    84  		return &numericMapKeyDecoder{decoder: decoderOfType(ctx, typ)}
    85  	default:
    86  		return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
    87  	}
    88  }
    89  
    90  func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
    91  	if encoder := ctx.encoderExtension.CreateMapKeyEncoder(typ); encoder != nil {
    92  		return encoder
    93  	}
    94  	if encoder := ctx.extensions.CreateMapKeyEncoder(typ); encoder != nil {
    95  		return encoder
    96  	}
    97  	if typ == textMarshalerType {
    98  		return &directTextMarshalerEncoder{stringEncoder: ctx.EncoderOf(reflect2.TypeOf(""))}
    99  	}
   100  	if typ.Implements(textMarshalerType) {
   101  		return &textMarshalerEncoder{valType: typ, stringEncoder: ctx.EncoderOf(reflect2.TypeOf(""))}
   102  	}
   103  
   104  	switch typ.Kind() {
   105  	case reflect.String:
   106  		return encoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))
   107  	case reflect.Bool,
   108  		reflect.Uint8, reflect.Int8, reflect.Uint16, reflect.Int16, reflect.Uint32, reflect.Int32,
   109  		reflect.Uint64, reflect.Int64, reflect.Uint, reflect.Int,
   110  		reflect.Float32, reflect.Float64, reflect.Uintptr:
   111  		typ = reflect2.DefaultTypeOfKind(typ.Kind())
   112  		return &numericMapKeyEncoder{encoder: encoderOfType(ctx, typ)}
   113  	default:
   114  		if typ.Kind() == reflect.Interface {
   115  			return &dynamicMapKeyEncoder{ctx: ctx, valType: typ}
   116  		}
   117  		return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
   118  	}
   119  }
   120  
   121  type mapDecoder struct {
   122  	mapType     *reflect2.UnsafeMapType
   123  	keyType     reflect2.Type
   124  	elemType    reflect2.Type
   125  	keyDecoder  ValDecoder
   126  	elemDecoder ValDecoder
   127  }
   128  
   129  func (d *mapDecoder) Decode(ctx context.Context, ptr unsafe.Pointer, iter *Iterator) {
   130  	mapType := d.mapType
   131  	c := iter.nextToken()
   132  	if c == 'n' {
   133  		iter.skip3Bytes('u', 'l', 'l')
   134  		*(*unsafe.Pointer)(ptr) = nil
   135  		mapType.UnsafeSet(ptr, mapType.UnsafeNew())
   136  		return
   137  	}
   138  	if mapType.UnsafeIsNil(ptr) {
   139  		mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0))
   140  	}
   141  	if c != '{' {
   142  		iter.ReportError("ReadMapCB", `expect { or n, but found `+string([]byte{c}))
   143  		return
   144  	}
   145  	c = iter.nextToken()
   146  	if c == '}' {
   147  		return
   148  	}
   149  	iter.unreadByte()
   150  	key := d.keyType.UnsafeNew()
   151  	d.keyDecoder.Decode(ctx, key, iter)
   152  	c = iter.nextToken()
   153  	if c != ':' {
   154  		iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
   155  		return
   156  	}
   157  	elem := d.elemType.UnsafeNew()
   158  	d.elemDecoder.Decode(ctx, elem, iter)
   159  	d.mapType.UnsafeSetIndex(ptr, key, elem)
   160  	for c = iter.nextToken(); c == ','; c = iter.nextToken() {
   161  		key := d.keyType.UnsafeNew()
   162  		d.keyDecoder.Decode(ctx, key, iter)
   163  		c = iter.nextToken()
   164  		if c != ':' {
   165  			iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c}))
   166  			return
   167  		}
   168  		elem := d.elemType.UnsafeNew()
   169  		d.elemDecoder.Decode(ctx, elem, iter)
   170  		d.mapType.UnsafeSetIndex(ptr, key, elem)
   171  	}
   172  	if c != '}' {
   173  		iter.ReportError("ReadMapCB", `expect }, but found `+string([]byte{c}))
   174  	}
   175  }
   176  
   177  type numericMapKeyDecoder struct {
   178  	decoder ValDecoder
   179  }
   180  
   181  func (d *numericMapKeyDecoder) Decode(ctx context.Context, ptr unsafe.Pointer, iter *Iterator) {
   182  	c := iter.nextToken()
   183  	if c != '"' {
   184  		iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c}))
   185  		return
   186  	}
   187  	d.decoder.Decode(ctx, ptr, iter)
   188  	c = iter.nextToken()
   189  	if c != '"' {
   190  		iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c}))
   191  		return
   192  	}
   193  }
   194  
   195  type numericMapKeyEncoder struct {
   196  	encoder ValEncoder
   197  }
   198  
   199  func (u *numericMapKeyEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   200  	stream.writeByte('"')
   201  	u.encoder.Encode(ctx, ptr, stream)
   202  	stream.writeByte('"')
   203  }
   204  
   205  func (u *numericMapKeyEncoder) IsEmpty(context.Context, unsafe.Pointer, bool) bool { return false }
   206  
   207  type dynamicMapKeyEncoder struct {
   208  	ctx     *ctx
   209  	valType reflect2.Type
   210  }
   211  
   212  func (e *dynamicMapKeyEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   213  	obj := e.valType.UnsafeIndirect(ptr)
   214  	encoderOfMapKey(e.ctx, reflect2.TypeOf(obj)).Encode(ctx, reflect2.PtrOf(obj), stream)
   215  }
   216  
   217  func (e *dynamicMapKeyEncoder) IsEmpty(ctx context.Context, ptr unsafe.Pointer, checkZero bool) bool {
   218  	obj := e.valType.UnsafeIndirect(ptr)
   219  	return encoderOfMapKey(e.ctx, reflect2.TypeOf(obj)).IsEmpty(ctx, reflect2.PtrOf(obj), checkZero)
   220  }
   221  
   222  type mapEncoder struct {
   223  	mapType     *reflect2.UnsafeMapType
   224  	keyEncoder  ValEncoder
   225  	elemEncoder ValEncoder
   226  	omitempty   bool
   227  }
   228  
   229  type sortKeysMapEncoder struct {
   230  	*mapEncoder
   231  }
   232  
   233  func (e *mapEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   234  	if *(*unsafe.Pointer)(ptr) == nil {
   235  		stream.WriteNil()
   236  		return
   237  	}
   238  	stream.WriteObjectStart()
   239  	iter := e.mapType.UnsafeIterate(ptr)
   240  	for i := 0; iter.HasNext(); {
   241  		if i > 0 {
   242  			stream.WriteMore()
   243  		}
   244  		key, elem := iter.UnsafeNext()
   245  		if e.omitempty && e.elemEncoder.IsEmpty(ctx, elem, true) {
   246  			continue
   247  		}
   248  
   249  		e.keyEncoder.Encode(ctx, key, stream)
   250  		if stream.indention > 0 {
   251  			stream.write2Bytes(':', ' ')
   252  		} else {
   253  			stream.writeByte(':')
   254  		}
   255  		e.elemEncoder.Encode(ctx, elem, stream)
   256  		i++
   257  	}
   258  	stream.WriteObjectEnd()
   259  }
   260  
   261  func (e *mapEncoder) IsEmpty(_ context.Context, ptr unsafe.Pointer, _ bool) bool {
   262  	iter := e.mapType.UnsafeIterate(ptr)
   263  	return !iter.HasNext()
   264  }
   265  
   266  func (e *sortKeysMapEncoder) Encode(ctx context.Context, ptr unsafe.Pointer, stream *Stream) {
   267  	if *(*unsafe.Pointer)(ptr) == nil {
   268  		stream.WriteNil()
   269  		return
   270  	}
   271  	stream.WriteObjectStart()
   272  	mapIter := e.mapType.UnsafeIterate(ptr)
   273  	subStream := stream.cfg.BorrowStream(nil)
   274  	subStream.Attachment = stream.Attachment
   275  	subIter := stream.cfg.BorrowIterator(nil)
   276  	keyValues := encodedKvs{}
   277  	for mapIter.HasNext() {
   278  		key, elem := mapIter.UnsafeNext()
   279  		if e.omitempty && e.elemEncoder.IsEmpty(ctx, elem, true) {
   280  			continue
   281  		}
   282  
   283  		subStreamIndex := subStream.Buffered()
   284  		e.keyEncoder.Encode(ctx, key, subStream)
   285  		if subStream.Error != nil && subStream.Error != io.EOF && stream.Error == nil {
   286  			stream.Error = subStream.Error
   287  		}
   288  		encodedKey := subStream.Buffer()[subStreamIndex:]
   289  		subIter.ResetBytes(encodedKey)
   290  		decodedKey := subIter.ReadString()
   291  		if stream.indention > 0 {
   292  			subStream.write2Bytes(byte(':'), byte(' '))
   293  		} else {
   294  			subStream.writeByte(':')
   295  		}
   296  		e.elemEncoder.Encode(ctx, elem, subStream)
   297  		keyValues = append(keyValues, encodedKv{
   298  			key: decodedKey,
   299  			val: subStream.Buffer()[subStreamIndex:],
   300  		})
   301  	}
   302  	sort.Sort(keyValues)
   303  	for i, keyValue := range keyValues {
   304  		if i > 0 {
   305  			stream.WriteMore()
   306  		}
   307  		_, _ = stream.Write(keyValue.val)
   308  	}
   309  	if subStream.Error != nil && stream.Error == nil {
   310  		stream.Error = subStream.Error
   311  	}
   312  	stream.WriteObjectEnd()
   313  	stream.cfg.ReturnStream(subStream)
   314  	stream.cfg.ReturnIterator(subIter)
   315  }
   316  
   317  func (e *sortKeysMapEncoder) IsEmpty(_ context.Context, ptr unsafe.Pointer, _ bool) bool {
   318  	iter := e.mapType.UnsafeIterate(ptr)
   319  	return !iter.HasNext()
   320  }
   321  
   322  type encodedKvs []encodedKv
   323  
   324  type encodedKv struct {
   325  	key string
   326  	val []byte
   327  }
   328  
   329  func (r encodedKvs) Len() int           { return len(r) }
   330  func (r encodedKvs) Swap(i, j int)      { r[i], r[j] = r[j], r[i] }
   331  func (r encodedKvs) Less(i, j int) bool { return r[i].key < r[j].key }