github.com/zerosnake0/jzon@v0.0.8/encoder_config.go (about)

     1  package jzon
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"sync"
     7  	"sync/atomic"
     8  )
     9  
    10  var (
    11  	// DefaultEncoderConfig is compatible with standard lib
    12  	DefaultEncoderConfig = NewEncoderConfig(nil)
    13  )
    14  
    15  // EncoderOption can be used to customize the encoder config
    16  type EncoderOption struct {
    17  	ValEncoders map[reflect.Type]ValEncoder
    18  
    19  	EscapeHTML      bool
    20  	Tag             string
    21  	OnlyTaggedField bool
    22  }
    23  
    24  type encoderCache map[rtype]ValEncoder
    25  
    26  func (cache encoderCache) has(rtype rtype) bool {
    27  	_, ok := cache[rtype]
    28  	return ok
    29  }
    30  
    31  // make sure that the pointer encoders has already been rebuilt
    32  // before calling, so it's safe to use it's internal encoder
    33  func (cache encoderCache) preferPtrEncoder(typ reflect.Type) ValEncoder {
    34  	ptrType := reflect.PtrTo(typ)
    35  	ptrEncoder := cache[rtypeOfType(ptrType)]
    36  	if pe, ok := ptrEncoder.(*pointerEncoder); ok {
    37  		return pe.encoder
    38  	}
    39  	// the element has a special pointer encoder
    40  	return &directEncoder{ptrEncoder}
    41  }
    42  
    43  // EncoderConfig is a frozen config for encoding
    44  type EncoderConfig struct {
    45  	cacheMu sync.Mutex
    46  	// the encoder cache, or root encoder cache
    47  	encoderCache atomic.Value
    48  	// the internal cache
    49  	internalCache encoderCache
    50  
    51  	tag             string
    52  	onlyTaggedField bool
    53  
    54  	// can override during runtime
    55  	escapeHTML bool
    56  }
    57  
    58  // NewEncoderConfig returns a new encoder config
    59  // If the input option is nil, the default option will be applied
    60  func NewEncoderConfig(opt *EncoderOption) *EncoderConfig {
    61  	encCfg := EncoderConfig{
    62  		tag:        "json",
    63  		escapeHTML: true,
    64  	}
    65  	cache := encoderCache{}
    66  	internalCache := encoderCache{}
    67  	if opt != nil {
    68  		for typ, valEnc := range opt.ValEncoders {
    69  			rtype := rtypeOfType(typ)
    70  			cache[rtype] = valEnc
    71  			internalCache[rtype] = valEnc
    72  		}
    73  		encCfg.escapeHTML = opt.EscapeHTML
    74  		if opt.Tag != "" {
    75  			encCfg.tag = opt.Tag
    76  		}
    77  		encCfg.onlyTaggedField = opt.OnlyTaggedField
    78  	}
    79  	encCfg.encoderCache.Store(cache)
    80  	encCfg.internalCache = internalCache
    81  	return &encCfg
    82  }
    83  
    84  // Marshal behave like json.Marshal
    85  func (encCfg *EncoderConfig) Marshal(obj interface{}) ([]byte, error) {
    86  	s := encCfg.NewStreamer()
    87  	defer s.Release()
    88  	s.Value(obj)
    89  	if s.Error != nil {
    90  		return nil, s.Error
    91  	}
    92  	// we make a new slice with explicit size,
    93  	//   1. the internal buffer may be much longer than the output one,
    94  	//      it can be used for longer output
    95  	//   2. avoid calling bytes buffer pool (sync.Pool)
    96  	b := make([]byte, len(s.buffer))
    97  	copy(b, s.buffer)
    98  	return b, nil
    99  }
   100  
   101  // NewEncoder returns a new encoder that writes to w.
   102  func (encCfg *EncoderConfig) NewEncoder(w io.Writer) *Encoder {
   103  	s := encCfg.NewStreamer()
   104  	s.Reset(w)
   105  	return &Encoder{
   106  		s: s,
   107  	}
   108  }
   109  
   110  func (encCfg *EncoderConfig) getEncoderFromCache(rtype rtype) ValEncoder {
   111  	return encCfg.encoderCache.Load().(encoderCache)[rtype]
   112  }
   113  
   114  func (encCfg *EncoderConfig) createEncoder(rtype rtype, typ reflect.Type) ValEncoder {
   115  	encCfg.cacheMu.Lock()
   116  	defer encCfg.cacheMu.Unlock()
   117  	cache := encCfg.encoderCache.Load().(encoderCache)
   118  	// double check
   119  	if ve := cache[rtype]; ve != nil {
   120  		return ve
   121  	}
   122  	newCache := encoderCache{}
   123  	for k, v := range cache {
   124  		newCache[k] = v
   125  	}
   126  	var q typeQueue
   127  	q.push(typ)
   128  	encCfg.createEncoderInternal(newCache, encCfg.internalCache, q)
   129  	encCfg.encoderCache.Store(newCache)
   130  	return newCache[rtype]
   131  }
   132  
// createEncoderInternal builds encoders for the types in typesToCreate and for
// every type reachable from them, recording the results in cache and
// internalCache.
//
// Parameters:
//   - cache: the root-level cache being populated. For types where ifaceIndir
//     reports false, the root entry is wrapped in directEncoder.
//   - internalCache: the cache parent encoders read when wiring in their
//     children. It also serves as the "visited" set for the walk.
//   - typesToCreate: the worklist of types still to visit. Child types
//     discovered along the way are pushed onto it.
//
// The work happens in two phases:
//  1. Types are popped off the worklist. Types whose encoder needs no children
//     get their final encoder immediately: global encoders, marshalers, kind
//     encoders, interfaces, empty arrays/structs, and unsupported types.
//     Container types (pointers, arrays, maps, slices, structs) instead get
//     the placeholder encoder created by their builder, their child types are
//     pushed, and the builder is remembered in rebuildMap.
//  2. rebuildMap is walked three times, in a fixed order, to fill in the
//     placeholders (see the comment on each pass).
func (encCfg *EncoderConfig) createEncoderInternal(cache, internalCache encoderCache, typesToCreate typeQueue) {
	// builders whose encoders still need their children wired in (phase 2)
	rebuildMap := map[rtype]interface{}{}
	for typ := typesToCreate.pop(); typ != nil; typ = typesToCreate.pop() {
		rType := rtypeOfType(typ)
		if internalCache.has(rType) { // check if visited
			continue
		}

		// check global encoders
		if v, ok := globalValEncoders[rType]; ok {
			internalCache[rType] = v
			cache[rType] = v
			continue
		}

		kind := typ.Kind()

		// check json.Marshaler interface
		if typ.Implements(jsonMarshalerType) {
			if ifaceIndir(rType) {
				v := &jsonMarshalerEncoder{
					isEmpty: isEmptyFunctions[kind],
					rtype:   rType,
				}
				internalCache[rType] = v
				cache[rType] = v
				continue
			}
			if typ.Kind() == reflect.Ptr {
				elemType := typ.Elem()
				if elemType.Implements(jsonMarshalerType) {
					// treat as a pointer encoder: the element implements the
					// interface too, so its own encoder (built from the pushed
					// elemType) will handle the marshaling
					typesToCreate.push(elemType)
					w := newPointerEncoder(elemType)
					internalCache[rType] = w.encoder
					rebuildMap[rType] = w
				} else {
					// only the pointer type implements the interface
					v := pointerJSONMarshalerEncoder(rType)
					internalCache[rType] = v
					cache[rType] = &directEncoder{v}
				}
				continue
			}
			// non-pointer type for which ifaceIndir reports false
			v := &directJSONMarshalerEncoder{
				isEmpty: isEmptyFunctions[kind],
				rtype:   rType,
			}
			internalCache[rType] = v
			cache[rType] = &directEncoder{v}
			continue
		}

		// check encoding.TextMarshaler interface
		// (same decision structure as the json.Marshaler case above)
		if typ.Implements(textMarshalerType) {
			if ifaceIndir(rType) {
				v := &textMarshalerEncoder{
					isEmpty: isEmptyFunctions[kind],
					rtype:   rType,
				}
				internalCache[rType] = v
				cache[rType] = v
				continue
			}
			if typ.Kind() == reflect.Ptr {
				elemType := typ.Elem()
				if elemType.Implements(textMarshalerType) {
					// treat as a pointer encoder
					typesToCreate.push(elemType)
					w := newPointerEncoder(elemType)
					internalCache[rType] = w.encoder
					rebuildMap[rType] = w
				} else {
					v := pointerTextMarshalerEncoder(rType)
					internalCache[rType] = v
					cache[rType] = &directEncoder{v}
				}
				continue
			}
			v := &directTextMarshalerEncoder{
				isEmpty: isEmptyFunctions[kind],
				rtype:   rType,
			}
			internalCache[rType] = v
			cache[rType] = &directEncoder{v}
			continue
		}

		// Types of a kind listed in encoderKindMap reuse whatever encoder is
		// already cached for that kind's canonical rtype (e.g. a custom one),
		// falling back to the stock kind encoder.
		if kindRType := encoderKindMap[kind]; kindRType != 0 {
			// TODO: shall we make this an option?
			// TODO: so that only the native type is affected?
			// check if the native type has a custom encoder
			if v, ok := internalCache[kindRType]; ok {
				internalCache[rType] = v
				cache[rType] = v
				continue
			}

			if v := kindEncoders[kind]; v != nil {
				internalCache[rType] = v
				cache[rType] = v
				continue
			}
		}

		switch kind {
		case reflect.Ptr:
			elemType := typ.Elem()
			typesToCreate.push(elemType)
			w := newPointerEncoder(elemType)
			internalCache[rType] = w.encoder
			rebuildMap[rType] = w
		case reflect.Array:
			elemType := typ.Elem()
			// elements are encoded via their pointer type (see preferPtrEncoder)
			typesToCreate.push(reflect.PtrTo(elemType))
			if typ.Len() == 0 {
				v := (*emptyArrayEncoder)(nil)
				internalCache[rType] = v
				cache[rType] = v
			} else {
				w := newArrayEncoder(typ)
				internalCache[rType] = w.encoder
				rebuildMap[rType] = w
			}
		case reflect.Interface:
			var v ValEncoder
			if typ.NumMethod() == 0 {
				v = (*efaceEncoder)(nil)
			} else {
				v = (*ifaceEncoder)(nil)
			}
			internalCache[rType] = v
			cache[rType] = v
		case reflect.Map:
			w := newMapEncoder(typ)
			if w == nil {
				// newMapEncoder could not handle this map type
				v := notSupportedEncoder(typ.String())
				internalCache[rType] = v
				cache[rType] = v
			} else {
				typesToCreate.push(typ.Elem())
				// pointer decoder is a reverse of direct encoder
				internalCache[rType] = w.encoder
				rebuildMap[rType] = w
			}
		case reflect.Slice:
			w := newSliceEncoder(typ)
			// elements are encoded via their pointer type (see preferPtrEncoder)
			typesToCreate.push(reflect.PtrTo(typ.Elem()))
			internalCache[rType] = w.encoder
			rebuildMap[rType] = w
		case reflect.Struct:
			w := encCfg.newStructEncoder(typ)
			if w == nil {
				// no fields to marshal
				v := (*emptyStructEncoder)(nil)
				internalCache[rType] = v
				cache[rType] = v
			} else {
				for i := range w.fields {
					fi := &w.fields[i]
					typesToCreate.push(fi.ptrType)
				}
				internalCache[rType] = w.encoder
				rebuildMap[rType] = w
			}
		default:
			v := notSupportedEncoder(typ.String())
			internalCache[rType] = v
			cache[rType] = v
		}
	}
	// rebuild base64 encoders
	// Pass 1: a slice whose element kind is uint8 and whose element uses both
	// the plain pointer encoder and the stock uint8 encoder is replaced by the
	// base64 encoder. This must run before pass 2, which copies internalCache
	// entries into pointer encoders. Handled slices are removed from
	// rebuildMap so pass 3 does not overwrite them.
	for rType, builder := range rebuildMap {
		switch x := builder.(type) {
		case *sliceEncoderBuilder:
			if x.elemType.Kind() != reflect.Uint8 {
				continue
			}
			elemPtrType := reflect.PtrTo(x.elemType)
			elemPtrEncoder := internalCache[rtypeOfType(elemPtrType)]
			if _, ok := elemPtrEncoder.(*pointerEncoder); !ok {
				// the element has a special pointer encoder
				continue
			}
			// the pointer encoder's element has not been wired in yet (that
			// happens in pass 2), so look the element encoder up by its own rtype
			elemEncoder := internalCache[rtypeOfType(x.elemType)]
			if elemEncoder != (*uint8Encoder)(nil) {
				// the element has a special value encoder
				continue
			}
			v := (*base64Encoder)(nil)
			internalCache[rType] = v
			cache[rType] = v
			delete(rebuildMap, rType)
		}
	}
	// rebuild ptr encoders
	// Pass 2: point each pointer encoder at its element's encoder. This must
	// run before pass 3, because preferPtrEncoder unwraps *pointerEncoder and
	// returns its inner encoder. At the root, the element's encoder is used
	// for the pointer type itself.
	for rType, builder := range rebuildMap {
		switch x := builder.(type) {
		case *pointerEncoderBuilder:
			v := internalCache[x.elemRType]
			x.encoder.encoder = v
			cache[rType] = v
			delete(rebuildMap, rType)
		}
	}
	// rebuild other encoders
	// Pass 3: wire element/field encoders into the remaining containers and
	// publish their root-level encoders.
	for rType, builder := range rebuildMap {
		switch x := builder.(type) {
		case *arrayEncoderBuilder:
			x.encoder.encoder = internalCache.preferPtrEncoder(x.elemType)
			if ifaceIndir(rType) {
				cache[rType] = x.encoder
			} else {
				// (see reflect.ArrayOf)
				// when the array is stored in interface directly, it means:
				// 1. the length of array is 1
				// 2. the element of the array is also directly saved
				cache[rType] = &directEncoder{x.encoder}
			}
		case *mapEncoderBuilder:
			// TODO: key encoder
			x.encoder.elemEncoder = internalCache[x.elemRType]
			cache[rType] = &directEncoder{x.encoder}
		case *sliceEncoderBuilder:
			x.encoder.elemEncoder = internalCache.preferPtrEncoder(x.elemType)
			cache[rType] = x.encoder
		case *structEncoderBuilder:
			x.encoder.fields.init(len(x.fields))
			for i := range x.fields {
				fi := &x.fields[i]
				v := internalCache.preferPtrEncoder(fi.ptrType.Elem())
				x.encoder.fields.add(fi, v)
			}
			if ifaceIndir(rType) {
				cache[rType] = x.encoder
			} else {
				cache[rType] = &directEncoder{x.encoder}
			}
		}
	}
}