github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/config.go (about)

     1  package thrifter
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"github.com/batchcorp/thrift-iterator/binding/codegen"
     7  	"github.com/batchcorp/thrift-iterator/binding/reflection"
     8  	"github.com/batchcorp/thrift-iterator/general"
     9  	"github.com/batchcorp/thrift-iterator/protocol"
    10  	"github.com/batchcorp/thrift-iterator/protocol/binary"
    11  	"github.com/batchcorp/thrift-iterator/protocol/compact"
    12  	"github.com/batchcorp/thrift-iterator/raw"
    13  	"github.com/batchcorp/thrift-iterator/spi"
    14  	"github.com/v2pro/wombat/generic"
    15  	"io"
    16  	"reflect"
    17  	"sync"
    18  )
    19  
    20  type frozenConfig struct {
    21  	extension     spi.Extension
    22  	protocol      Protocol
    23  	genDecoders   sync.Map
    24  	genEncoders   sync.Map
    25  	extDecoders   sync.Map
    26  	extEncoders   sync.Map
    27  	staticCodegen bool
    28  }
    29  
    30  func (cfg Config) AddExtension(extension spi.Extension) Config {
    31  	cfg.Extensions = append(cfg.Extensions, extension)
    32  	return cfg
    33  }
    34  
    35  func (cfg Config) Froze() API {
    36  	extensions := append(cfg.Extensions, &general.Extension{})
    37  	extensions = append(extensions, &raw.Extension{})
    38  	api := &frozenConfig{
    39  		extension:     extensions,
    40  		protocol:      cfg.Protocol,
    41  		staticCodegen: cfg.StaticCodegen,
    42  	}
    43  	api.extDecoders = sync.Map{}
    44  	api.genDecoders = sync.Map{}
    45  	api.extEncoders = sync.Map{}
    46  	api.genEncoders = sync.Map{}
    47  	return api
    48  }
    49  
    50  func (cfg *frozenConfig) addGenDecoder(cacheKey reflect.Type, decoder spi.ValDecoder) {
    51  	cfg.genDecoders.Store(cacheKey, decoder)
    52  }
    53  
    54  func (cfg *frozenConfig) addExtDecoder(cacheKey string, decoder spi.ValDecoder) {
    55  	cfg.extDecoders.Store(cacheKey, decoder)
    56  }
    57  
    58  func (cfg *frozenConfig) addGenEncoder(cacheKey reflect.Type, encoder spi.ValEncoder) {
    59  	cfg.genEncoders.Store(cacheKey, encoder)
    60  }
    61  
    62  func (cfg *frozenConfig) addExtEncoder(cacheKey string, encoder spi.ValEncoder) {
    63  	cfg.extEncoders.Store(cacheKey, encoder)
    64  }
    65  
    66  func (cfg *frozenConfig) PrepareDecoder(valType reflect.Type) {
    67  	cacheKey := valType.String()
    68  	if cfg.GetDecoder(cacheKey) != nil {
    69  		return
    70  	}
    71  	decoder := cfg.extension.DecoderOf(valType)
    72  	cfg.addExtDecoder(cacheKey, decoder)
    73  	cfg.addGenDecoder(valType, decoder)
    74  }
    75  
    76  func (cfg *frozenConfig) GetDecoder(cacheKey string) spi.ValDecoder {
    77  	decoder, found := cfg.extDecoders.Load(cacheKey)
    78  	if found {
    79  		return decoder.(spi.ValDecoder)
    80  	}
    81  	return nil
    82  }
    83  
    84  func (cfg *frozenConfig) getGenDecoder(cacheKey reflect.Type) spi.ValDecoder {
    85  	decoder, found := cfg.genDecoders.Load(cacheKey)
    86  	if found {
    87  		return decoder.(spi.ValDecoder)
    88  	}
    89  	return nil
    90  }
    91  
    92  func (cfg *frozenConfig) PrepareEncoder(valType reflect.Type) {
    93  	cacheKey := valType.String()
    94  	if cfg.GetEncoder(cacheKey) != nil {
    95  		return
    96  	}
    97  	encoder := cfg.extension.EncoderOf(valType)
    98  	cfg.addExtEncoder(cacheKey, encoder)
    99  	cfg.addGenEncoder(valType, encoder)
   100  }
   101  
   102  func (cfg *frozenConfig) GetEncoder(cacheKey string) spi.ValEncoder {
   103  	encoder, found := cfg.extEncoders.Load(cacheKey)
   104  	if found {
   105  		return encoder.(spi.ValEncoder)
   106  	}
   107  	return nil
   108  }
   109  
   110  func (cfg *frozenConfig) getGenEncoder(cacheKey reflect.Type) spi.ValEncoder {
   111  	encoder, found := cfg.genEncoders.Load(cacheKey)
   112  	if found {
   113  		return encoder.(spi.ValEncoder)
   114  	}
   115  	return nil
   116  }
   117  
   118  func (cfg *frozenConfig) NewStream(writer io.Writer, buf []byte) spi.Stream {
   119  	switch cfg.protocol {
   120  	case ProtocolBinary:
   121  		return binary.NewStream(cfg, writer, buf)
   122  	case ProtocolCompact:
   123  		return compact.NewStream(cfg, writer, buf)
   124  	}
   125  	panic("unsupported protocol")
   126  }
   127  
   128  func (cfg *frozenConfig) NewIterator(reader io.Reader, buf []byte) spi.Iterator {
   129  	switch cfg.protocol {
   130  	case ProtocolBinary:
   131  		return binary.NewIterator(cfg, reader, buf)
   132  	case ProtocolCompact:
   133  		return compact.NewIterator(cfg, reader, buf)
   134  	}
   135  	panic("unsupported protocol")
   136  }
   137  
   138  func (cfg *frozenConfig) WillDecodeFromBuffer(samples ...interface{}) {
   139  	if !cfg.staticCodegen {
   140  		panic("this config is using dynamic codegen, can not do static codegen")
   141  	}
   142  	for _, sample := range samples {
   143  		cfg.staticDecoderOf(reflect.TypeOf(sample))
   144  	}
   145  }
   146  
   147  func (cfg *frozenConfig) WillDecodeFromReader(samples ...interface{}) {
   148  	if !cfg.staticCodegen {
   149  		panic("this config is using dynamic codegen, can not do static codegen")
   150  	}
   151  	for _, sample := range samples {
   152  		cfg.staticDecoderOf(reflect.TypeOf(sample))
   153  	}
   154  }
   155  
   156  func (cfg *frozenConfig) WillEncode(samples ...interface{}) {
   157  	if !cfg.staticCodegen {
   158  		panic("this config is using dynamic codegen, can not do static codegen")
   159  	}
   160  	for _, sample := range samples {
   161  		cfg.staticEncoderOf(reflect.TypeOf(sample))
   162  	}
   163  }
   164  
   165  func (cfg *frozenConfig) decoderOf(valType reflect.Type) spi.ValDecoder {
   166  	if cfg.staticCodegen {
   167  		return cfg.staticDecoderOf(valType)
   168  	}
   169  	return reflection.DecoderOf(cfg.extension, valType)
   170  }
   171  
   172  func (cfg *frozenConfig) staticDecoderOf(valType reflect.Type) spi.ValDecoder {
   173  	iteratorType := reflect.TypeOf((*binary.Iterator)(nil))
   174  	if cfg.protocol == ProtocolCompact {
   175  		iteratorType = reflect.TypeOf((*compact.Iterator)(nil))
   176  	}
   177  	funcObj := generic.Expand(codegen.Decode,
   178  		"EXT", &codegen.Extension{Extension: cfg.extension},
   179  		"ST", iteratorType,
   180  		"DT", valType)
   181  	f := funcObj.(func(interface{}, interface{}))
   182  	return &funcDecoder{f}
   183  }
   184  
   185  func (cfg *frozenConfig) encoderOf(valType reflect.Type) spi.ValEncoder {
   186  	if cfg.staticCodegen {
   187  		return cfg.staticEncoderOf(valType)
   188  	}
   189  	return reflection.EncoderOf(cfg.extension, valType)
   190  }
   191  
   192  func (cfg *frozenConfig) staticEncoderOf(valType reflect.Type) spi.ValEncoder {
   193  	streamType := reflect.TypeOf((*binary.Stream)(nil))
   194  	if cfg.protocol == ProtocolCompact {
   195  		streamType = reflect.TypeOf((*compact.Stream)(nil))
   196  	}
   197  	funcObj := generic.Expand(codegen.Encode,
   198  		"EXT", &codegen.Extension{Extension: cfg.extension},
   199  		"ST", valType,
   200  		"DT", streamType)
   201  	f := funcObj.(func(interface{}, interface{}))
   202  	return &funcEncoder{f}
   203  }
   204  
   205  type funcDecoder struct {
   206  	f func(dst interface{}, src interface{})
   207  }
   208  
   209  func (decoder *funcDecoder) Decode(val interface{}, iter spi.Iterator) {
   210  	decoder.f(val, iter)
   211  }
   212  
   213  type funcEncoder struct {
   214  	f func(dst interface{}, src interface{})
   215  }
   216  
   217  func (encoder *funcEncoder) Encode(val interface{}, stream spi.Stream) {
   218  	encoder.f(stream, val)
   219  }
   220  
   221  func (encoder *funcEncoder) ThriftType() protocol.TType {
   222  	panic("funcEncoder is not composable")
   223  }
   224  
   225  func (cfg *frozenConfig) Unmarshal(buf []byte, val interface{}) error {
   226  	valType := reflect.TypeOf(val)
   227  	decoder := cfg.getGenDecoder(valType)
   228  	if decoder == nil {
   229  		decoder = cfg.decoderOf(valType)
   230  		cfg.addGenDecoder(valType, decoder)
   231  	}
   232  	if buf == nil {
   233  		return errors.New("empty input")
   234  	}
   235  	iter := cfg.NewIterator(nil, buf)
   236  	decoder.Decode(val, iter)
   237  	if iter.Error() != nil {
   238  		return iter.Error()
   239  	}
   240  	return nil
   241  }
   242  
   243  func (cfg *frozenConfig) Marshal(val interface{}) ([]byte, error) {
   244  	valType := reflect.TypeOf(val)
   245  	encoder := cfg.getGenEncoder(valType)
   246  	if encoder == nil {
   247  		encoder = cfg.encoderOf(valType)
   248  		cfg.addGenEncoder(valType, encoder)
   249  	}
   250  	stream := cfg.NewStream(nil, nil)
   251  	encoder.Encode(val, stream)
   252  	if stream.Error() != nil {
   253  		return nil, stream.Error()
   254  	}
   255  	buf := stream.Buffer()
   256  	return buf, nil
   257  }
   258  
   259  func (cfg *frozenConfig) NewDecoder(reader io.Reader, buf []byte) *Decoder {
   260  	return &Decoder{
   261  		cfg:  cfg,
   262  		iter: cfg.NewIterator(reader, buf),
   263  	}
   264  }
   265  
   266  func (cfg *frozenConfig) NewEncoder(writer io.Writer) *Encoder {
   267  	return &Encoder{
   268  		cfg:    cfg,
   269  		stream: cfg.NewStream(writer, nil),
   270  	}
   271  }
   272  
   273  func (cfg *frozenConfig) ToJSON(buf []byte) (string, error) {
   274  	msg, err := UnmarshalMessage(buf)
   275  	if err != nil {
   276  		return "", err
   277  	}
   278  	jsonEncoded, err := json.MarshalIndent(msg, "", "  ")
   279  	if err != nil {
   280  		return "", err
   281  	}
   282  	return string(jsonEncoded), nil
   283  }
   284  
   285  func (cfg *frozenConfig) MarshalMessage(msg general.Message) ([]byte, error) {
   286  	return cfg.Marshal(msg)
   287  }
   288  
   289  func (cfg *frozenConfig) UnmarshalMessage(buf []byte) (general.Message, error) {
   290  	var msg general.Message
   291  	err := cfg.Unmarshal(buf, &msg)
   292  	return msg, err
   293  }