github.com/zerosnake0/jzon@v0.0.9-0.20230801092939-1b135cb83f7f/encoder_config.go (about)

     1  package jzon
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"sync"
     7  	"sync/atomic"
     8  )
     9  
    10  var (
    11  	// DefaultEncoderConfig is compatible with standard lib
    12  	DefaultEncoderConfig = NewEncoderConfig(nil)
    13  )
    14  
// EncoderOption can be used to customize the encoder config.
//
// It is read once by NewEncoderConfig: the ValEncoders map is copied and the
// IfaceEncoders entries are copied by value, so later changes to the option
// do not affect a config that was already built.
//
// NOTE: when a non-nil option is passed, every field is taken as given. In
// particular a zero EscapeHTML turns HTML escaping off, whereas a nil option
// keeps it on.
type EncoderOption struct {
	// ValEncoders maps a type to the encoder to use for it. The entries seed
	// the config's caches, so they are found before any other lookup.
	ValEncoders   map[reflect.Type]ValEncoder
	// IfaceEncoders are consulted in order; a type that implements
	// IfaceEncoders[i].Type uses IfaceEncoders[i].Encoder. They are checked
	// before the built-in jsonMarshalerType / textMarshalerType encoders.
	IfaceEncoders []IfaceValEncoderConfig

	// EscapeHTML enables HTML escaping (copied verbatim into the config).
	EscapeHTML      bool
	// Tag is the struct tag name to read; empty means "json".
	Tag             string
	// OnlyTaggedField is passed through to the struct encoder builder;
	// presumably it restricts encoding to fields carrying Tag — TODO confirm.
	OnlyTaggedField bool
}
    24  
    25  type encoderCache map[rtype]ValEncoder
    26  
    27  func (cache encoderCache) has(rtype rtype) bool {
    28  	_, ok := cache[rtype]
    29  	return ok
    30  }
    31  
    32  // make sure that the pointer encoders has already been rebuilt
    33  // before calling, so it's safe to use it's internal encoder
    34  func (cache encoderCache) preferPtrEncoder(typ reflect.Type) ValEncoder {
    35  	ptrType := reflect.PtrTo(typ)
    36  	ptrEncoder := cache[rtypeOfType(ptrType)]
    37  	if pe, ok := ptrEncoder.(*pointerEncoder); ok {
    38  		return pe.encoder
    39  	}
    40  	// the element has a special pointer encoder
    41  	return &directEncoder{ptrEncoder}
    42  }
    43  
    44  // EncoderConfig is a frozen config for encoding
    45  type EncoderConfig struct {
    46  	cacheMu sync.Mutex
    47  	// the encoder cache, or root encoder cache
    48  	encoderCache atomic.Value
    49  	// the internal cache
    50  	internalCache encoderCache
    51  	// iface encoders
    52  	ifaceEncoderMap map[rtype]ValEncoder
    53  	ifaceEncoder    []IfaceValEncoderConfig
    54  
    55  	tag             string
    56  	onlyTaggedField bool
    57  
    58  	// can override during runtime
    59  	escapeHTML bool
    60  }
    61  
    62  func (encCfg *EncoderConfig) addIfaceEncoder(cfg IfaceValEncoderConfig) {
    63  	rt := rtypeOfType(cfg.Type)
    64  	if encCfg.ifaceEncoderMap[rt] != nil {
    65  		return
    66  	}
    67  
    68  	// get the pointer type
    69  	ptrRt := rtypeOfType(reflect.New(cfg.Type).Type())
    70  	encCfg.ifaceEncoderMap[rt] = &dynamicIfaceValEncoder{
    71  		rtype:   ptrRt,
    72  		encoder: cfg.Encoder,
    73  	}
    74  	encCfg.ifaceEncoder = append(encCfg.ifaceEncoder, cfg)
    75  }
    76  
    77  // NewEncoderConfig returns a new encoder config
    78  // If the input option is nil, the default option will be applied
    79  func NewEncoderConfig(opt *EncoderOption) *EncoderConfig {
    80  	encCfg := EncoderConfig{
    81  		tag:             "json",
    82  		escapeHTML:      true,
    83  		ifaceEncoderMap: map[rtype]ValEncoder{},
    84  	}
    85  	cache := encoderCache{}
    86  	internalCache := encoderCache{}
    87  	if opt != nil {
    88  		for typ, valEnc := range opt.ValEncoders {
    89  			rtype := rtypeOfType(typ)
    90  			cache[rtype] = valEnc
    91  			internalCache[rtype] = valEnc
    92  		}
    93  		encCfg.escapeHTML = opt.EscapeHTML
    94  		if opt.Tag != "" {
    95  			encCfg.tag = opt.Tag
    96  		}
    97  		encCfg.onlyTaggedField = opt.OnlyTaggedField
    98  
    99  		// iface
   100  		encCfg.ifaceEncoder = make([]IfaceValEncoderConfig, 0, len(opt.IfaceEncoders))
   101  		for _, enc := range opt.IfaceEncoders {
   102  			encCfg.addIfaceEncoder(enc)
   103  		}
   104  	}
   105  	encCfg.addIfaceEncoder(IfaceValEncoderConfig{
   106  		Type:    jsonMarshalerType,
   107  		Encoder: jsonMarshalerValEncoder{},
   108  	})
   109  	encCfg.addIfaceEncoder(IfaceValEncoderConfig{
   110  		Type:    textMarshalerType,
   111  		Encoder: textMarshalerValEncoder{},
   112  	})
   113  
   114  	encCfg.encoderCache.Store(cache)
   115  	encCfg.internalCache = internalCache
   116  	return &encCfg
   117  }
   118  
   119  // Marshal behave like json.Marshal
   120  func (encCfg *EncoderConfig) Marshal(obj interface{}) ([]byte, error) {
   121  	s := encCfg.NewStreamer()
   122  	defer s.Release()
   123  	s.Value(obj)
   124  	if s.Error != nil {
   125  		return nil, s.Error
   126  	}
   127  	// we make a new slice with explicit size,
   128  	//   1. the internal buffer may be much longer than the output one,
   129  	//      it can be used for longer output
   130  	//   2. avoid calling bytes buffer pool (sync.Pool)
   131  	b := make([]byte, len(s.buffer))
   132  	copy(b, s.buffer)
   133  	return b, nil
   134  }
   135  
   136  // NewEncoder returns a new encoder that writes to w.
   137  func (encCfg *EncoderConfig) NewEncoder(w io.Writer) *Encoder {
   138  	s := encCfg.NewStreamer()
   139  	s.Reset(w)
   140  	return &Encoder{
   141  		s: s,
   142  	}
   143  }
   144  
   145  func (encCfg *EncoderConfig) getEncoderFromCache(rtype rtype) ValEncoder {
   146  	return encCfg.encoderCache.Load().(encoderCache)[rtype]
   147  }
   148  
   149  func (encCfg *EncoderConfig) createEncoder(rtype rtype, typ reflect.Type) ValEncoder {
   150  	encCfg.cacheMu.Lock()
   151  	defer encCfg.cacheMu.Unlock()
   152  	cache := encCfg.encoderCache.Load().(encoderCache)
   153  	// double check
   154  	if ve := cache[rtype]; ve != nil {
   155  		return ve
   156  	}
   157  	newCache := encoderCache{}
   158  	for k, v := range cache {
   159  		newCache[k] = v
   160  	}
   161  	var q typeQueue
   162  	q.push(typ)
   163  	encCfg.createEncoderInternal(newCache, encCfg.internalCache, q)
   164  	encCfg.encoderCache.Store(newCache)
   165  	return newCache[rtype]
   166  }
   167  
   168  func (encCfg *EncoderConfig) createEncoderInternal(cache, internalCache encoderCache, typesToCreate typeQueue) {
   169  	rebuildMap := map[rtype]interface{}{}
   170  OuterLoop:
   171  	for typ := typesToCreate.pop(); typ != nil; typ = typesToCreate.pop() {
   172  		rType := rtypeOfType(typ)
   173  		if internalCache.has(rType) { // check if visited
   174  			continue
   175  		}
   176  
   177  		// check local encoders
   178  		if v, ok := encCfg.ifaceEncoderMap[rType]; ok {
   179  			internalCache[rType] = v
   180  			cache[rType] = v
   181  			continue
   182  		}
   183  
   184  		// check global encoders
   185  		if v, ok := globalValEncoders[rType]; ok {
   186  			internalCache[rType] = v
   187  			cache[rType] = v
   188  			continue
   189  		}
   190  
   191  		kind := typ.Kind()
   192  
   193  		for _, ienc := range encCfg.ifaceEncoder {
   194  			if typ.Implements(ienc.Type) {
   195  				if ifaceIndir(rType) {
   196  					v := &ifaceValEncoder{
   197  						isEmpty: isEmptyFunctions[kind],
   198  						encoder: ienc.Encoder,
   199  						rtype:   rType,
   200  					}
   201  					internalCache[rType] = v
   202  					cache[rType] = v
   203  					continue OuterLoop
   204  				}
   205  				if typ.Kind() == reflect.Ptr {
   206  					elemType := typ.Elem()
   207  					if elemType.Implements(ienc.Type) {
   208  						typesToCreate.push(elemType)
   209  						w := newPointerEncoder(elemType)
   210  						internalCache[rType] = w.encoder
   211  						rebuildMap[rType] = w
   212  					} else {
   213  						v := &pointerIfaceValEncoder{
   214  							encoder: ienc.Encoder,
   215  							rtype:   rType,
   216  						}
   217  						internalCache[rType] = v
   218  						cache[rType] = &directEncoder{v}
   219  					}
   220  					continue OuterLoop
   221  				}
   222  				v := &directIfaceValEncoder{
   223  					isEmpty: isEmptyFunctions[kind],
   224  					encoder: ienc.Encoder,
   225  					rtype:   rType,
   226  				}
   227  				internalCache[rType] = v
   228  				cache[rType] = &directEncoder{v}
   229  				continue OuterLoop
   230  			}
   231  		}
   232  
   233  		//// check json.Marshaler interface
   234  		//if typ.Implements(jsonMarshalerType) {
   235  		//	if ifaceIndir(rType) {
   236  		//		v := &jsonMarshalerEncoder{
   237  		//			isEmpty: isEmptyFunctions[kind],
   238  		//			rtype:   rType,
   239  		//		}
   240  		//		internalCache[rType] = v
   241  		//		cache[rType] = v
   242  		//		continue
   243  		//	}
   244  		//	if typ.Kind() == reflect.Ptr {
   245  		//		elemType := typ.Elem()
   246  		//		if elemType.Implements(jsonMarshalerType) {
   247  		//			// treat as a pointer encoder
   248  		//			typesToCreate.push(elemType)
   249  		//			w := newPointerEncoder(elemType)
   250  		//			internalCache[rType] = w.encoder
   251  		//			rebuildMap[rType] = w
   252  		//		} else {
   253  		//			v := pointerJSONMarshalerEncoder(rType)
   254  		//			internalCache[rType] = v
   255  		//			cache[rType] = &directEncoder{v}
   256  		//		}
   257  		//		continue
   258  		//	}
   259  		//	v := &directJSONMarshalerEncoder{
   260  		//		isEmpty: isEmptyFunctions[kind],
   261  		//		rtype:   rType,
   262  		//	}
   263  		//	internalCache[rType] = v
   264  		//	cache[rType] = &directEncoder{v}
   265  		//	continue
   266  		//}
   267  		//
   268  		//// check encoding.TextMarshaler interface
   269  		//if typ.Implements(textMarshalerType) {
   270  		//	if ifaceIndir(rType) {
   271  		//		v := &textMarshalerEncoder{
   272  		//			isEmpty: isEmptyFunctions[kind],
   273  		//			rtype:   rType,
   274  		//		}
   275  		//		internalCache[rType] = v
   276  		//		cache[rType] = v
   277  		//		continue
   278  		//	}
   279  		//	if typ.Kind() == reflect.Ptr {
   280  		//		elemType := typ.Elem()
   281  		//		if elemType.Implements(textMarshalerType) {
   282  		//			// treat as a pointer encoder
   283  		//			typesToCreate.push(elemType)
   284  		//			w := newPointerEncoder(elemType)
   285  		//			internalCache[rType] = w.encoder
   286  		//			rebuildMap[rType] = w
   287  		//		} else {
   288  		//			v := pointerTextMarshalerEncoder(rType)
   289  		//			internalCache[rType] = v
   290  		//			cache[rType] = &directEncoder{v}
   291  		//		}
   292  		//		continue
   293  		//	}
   294  		//	v := &directTextMarshalerEncoder{
   295  		//		isEmpty: isEmptyFunctions[kind],
   296  		//		rtype:   rType,
   297  		//	}
   298  		//	internalCache[rType] = v
   299  		//	cache[rType] = &directEncoder{v}
   300  		//	continue
   301  		//}
   302  
   303  		if kindRType := encoderKindMap[kind]; kindRType != 0 {
   304  			// TODO: shall we make this an option?
   305  			// TODO: so that only the native type is affected?
   306  			// check if the native type has a custom encoder
   307  			if v, ok := internalCache[kindRType]; ok {
   308  				internalCache[rType] = v
   309  				cache[rType] = v
   310  				continue
   311  			}
   312  
   313  			if v := kindEncoders[kind]; v != nil {
   314  				internalCache[rType] = v
   315  				cache[rType] = v
   316  				continue
   317  			}
   318  		}
   319  
   320  		switch kind {
   321  		case reflect.Ptr:
   322  			elemType := typ.Elem()
   323  			typesToCreate.push(elemType)
   324  			w := newPointerEncoder(elemType)
   325  			internalCache[rType] = w.encoder
   326  			rebuildMap[rType] = w
   327  		case reflect.Array:
   328  			elemType := typ.Elem()
   329  			typesToCreate.push(reflect.PtrTo(elemType))
   330  			if typ.Len() == 0 {
   331  				v := (*emptyArrayEncoder)(nil)
   332  				internalCache[rType] = v
   333  				cache[rType] = v
   334  			} else {
   335  				w := newArrayEncoder(typ)
   336  				internalCache[rType] = w.encoder
   337  				rebuildMap[rType] = w
   338  			}
   339  		case reflect.Interface:
   340  			var v ValEncoder
   341  			if typ.NumMethod() == 0 {
   342  				v = (*efaceEncoder)(nil)
   343  			} else {
   344  				v = (*ifaceEncoder)(nil)
   345  			}
   346  			internalCache[rType] = v
   347  			cache[rType] = v
   348  		case reflect.Map:
   349  			w := newMapEncoder(typ)
   350  			if w == nil {
   351  				v := notSupportedEncoder(typ.String())
   352  				internalCache[rType] = v
   353  				cache[rType] = v
   354  			} else {
   355  				typesToCreate.push(typ.Elem())
   356  				// pointer decoder is a reverse of direct encoder
   357  				internalCache[rType] = w.encoder
   358  				rebuildMap[rType] = w
   359  			}
   360  		case reflect.Slice:
   361  			w := newSliceEncoder(typ)
   362  			typesToCreate.push(reflect.PtrTo(typ.Elem()))
   363  			internalCache[rType] = w.encoder
   364  			rebuildMap[rType] = w
   365  		case reflect.Struct:
   366  			w := encCfg.newStructEncoder(typ)
   367  			if w == nil {
   368  				// no fields to marshal
   369  				v := (*emptyStructEncoder)(nil)
   370  				internalCache[rType] = v
   371  				cache[rType] = v
   372  			} else {
   373  				for i := range w.fields {
   374  					fi := &w.fields[i]
   375  					typesToCreate.push(fi.ptrType)
   376  				}
   377  				internalCache[rType] = w.encoder
   378  				rebuildMap[rType] = w
   379  			}
   380  		default:
   381  			v := notSupportedEncoder(typ.String())
   382  			internalCache[rType] = v
   383  			cache[rType] = v
   384  		}
   385  	}
   386  	// rebuild base64 encoders
   387  	for rType, builder := range rebuildMap {
   388  		switch x := builder.(type) {
   389  		case *sliceEncoderBuilder:
   390  			if x.elemType.Kind() != reflect.Uint8 {
   391  				continue
   392  			}
   393  			elemPtrType := reflect.PtrTo(x.elemType)
   394  			elemPtrEncoder := internalCache[rtypeOfType(elemPtrType)]
   395  			if _, ok := elemPtrEncoder.(*pointerEncoder); !ok {
   396  				// the element has a special pointer encoder
   397  				continue
   398  			}
   399  			// the pointer decoder has not been rebuilt yet
   400  			// we need to use the explicit element rtype
   401  			elemEncoder := internalCache[rtypeOfType(x.elemType)]
   402  			if elemEncoder != (*uint8Encoder)(nil) {
   403  				// the element has a special value encoder
   404  				continue
   405  			}
   406  			v := (*base64Encoder)(nil)
   407  			internalCache[rType] = v
   408  			cache[rType] = v
   409  			delete(rebuildMap, rType)
   410  		}
   411  	}
   412  	// rebuild ptr encoders
   413  	for rType, builder := range rebuildMap {
   414  		switch x := builder.(type) {
   415  		case *pointerEncoderBuilder:
   416  			v := internalCache[x.elemRType]
   417  			x.encoder.encoder = v
   418  			cache[rType] = v
   419  			delete(rebuildMap, rType)
   420  		}
   421  	}
   422  	// rebuild other encoders
   423  	for rType, builder := range rebuildMap {
   424  		switch x := builder.(type) {
   425  		case *arrayEncoderBuilder:
   426  			x.encoder.encoder = internalCache.preferPtrEncoder(x.elemType)
   427  			if ifaceIndir(rType) {
   428  				cache[rType] = x.encoder
   429  			} else {
   430  				// (see reflect.ArrayOf)
   431  				// when the array is stored in interface directly, it means:
   432  				// 1. the length of array is 1
   433  				// 2. the element of the array is also directly saved
   434  				cache[rType] = &directEncoder{x.encoder}
   435  			}
   436  		case *mapEncoderBuilder:
   437  			// TODO: key encoder
   438  			x.encoder.elemEncoder = internalCache[x.elemRType]
   439  			cache[rType] = &directEncoder{x.encoder}
   440  		case *sliceEncoderBuilder:
   441  			x.encoder.elemEncoder = internalCache.preferPtrEncoder(x.elemType)
   442  			cache[rType] = x.encoder
   443  		case *structEncoderBuilder:
   444  			x.encoder.fields.init(len(x.fields))
   445  			for i := range x.fields {
   446  				fi := &x.fields[i]
   447  				v := internalCache.preferPtrEncoder(fi.ptrType.Elem())
   448  				x.encoder.fields.add(fi, v)
   449  			}
   450  			if ifaceIndir(rType) {
   451  				cache[rType] = x.encoder
   452  			} else {
   453  				cache[rType] = &directEncoder{x.encoder}
   454  			}
   455  		}
   456  	}
   457  }