github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/config.go (about) 1 package thrifter 2 3 import ( 4 "encoding/json" 5 "errors" 6 "github.com/batchcorp/thrift-iterator/binding/codegen" 7 "github.com/batchcorp/thrift-iterator/binding/reflection" 8 "github.com/batchcorp/thrift-iterator/general" 9 "github.com/batchcorp/thrift-iterator/protocol" 10 "github.com/batchcorp/thrift-iterator/protocol/binary" 11 "github.com/batchcorp/thrift-iterator/protocol/compact" 12 "github.com/batchcorp/thrift-iterator/raw" 13 "github.com/batchcorp/thrift-iterator/spi" 14 "github.com/v2pro/wombat/generic" 15 "io" 16 "reflect" 17 "sync" 18 ) 19 20 type frozenConfig struct { 21 extension spi.Extension 22 protocol Protocol 23 genDecoders sync.Map 24 genEncoders sync.Map 25 extDecoders sync.Map 26 extEncoders sync.Map 27 staticCodegen bool 28 } 29 30 func (cfg Config) AddExtension(extension spi.Extension) Config { 31 cfg.Extensions = append(cfg.Extensions, extension) 32 return cfg 33 } 34 35 func (cfg Config) Froze() API { 36 extensions := append(cfg.Extensions, &general.Extension{}) 37 extensions = append(extensions, &raw.Extension{}) 38 api := &frozenConfig{ 39 extension: extensions, 40 protocol: cfg.Protocol, 41 staticCodegen: cfg.StaticCodegen, 42 } 43 api.extDecoders = sync.Map{} 44 api.genDecoders = sync.Map{} 45 api.extEncoders = sync.Map{} 46 api.genEncoders = sync.Map{} 47 return api 48 } 49 50 func (cfg *frozenConfig) addGenDecoder(cacheKey reflect.Type, decoder spi.ValDecoder) { 51 cfg.genDecoders.Store(cacheKey, decoder) 52 } 53 54 func (cfg *frozenConfig) addExtDecoder(cacheKey string, decoder spi.ValDecoder) { 55 cfg.extDecoders.Store(cacheKey, decoder) 56 } 57 58 func (cfg *frozenConfig) addGenEncoder(cacheKey reflect.Type, encoder spi.ValEncoder) { 59 cfg.genEncoders.Store(cacheKey, encoder) 60 } 61 62 func (cfg *frozenConfig) addExtEncoder(cacheKey string, encoder spi.ValEncoder) { 63 cfg.extEncoders.Store(cacheKey, encoder) 64 } 65 66 func (cfg *frozenConfig) PrepareDecoder(valType reflect.Type) { 67 cacheKey := valType.String() 68 if cfg.GetDecoder(cacheKey) != nil { 69 return 70 } 71 decoder := cfg.extension.DecoderOf(valType) 72 cfg.addExtDecoder(cacheKey, decoder) 73 cfg.addGenDecoder(valType, decoder) 74 } 75 76 func (cfg *frozenConfig) GetDecoder(cacheKey string) spi.ValDecoder { 77 decoder, found := cfg.extDecoders.Load(cacheKey) 78 if found { 79 return decoder.(spi.ValDecoder) 80 } 81 return nil 82 } 83 84 func (cfg *frozenConfig) getGenDecoder(cacheKey reflect.Type) spi.ValDecoder { 85 decoder, found := cfg.genDecoders.Load(cacheKey) 86 if found { 87 return decoder.(spi.ValDecoder) 88 } 89 return nil 90 } 91 92 func (cfg *frozenConfig) PrepareEncoder(valType reflect.Type) { 93 cacheKey := valType.String() 94 if cfg.GetEncoder(cacheKey) != nil { 95 return 96 } 97 encoder := cfg.extension.EncoderOf(valType) 98 cfg.addExtEncoder(cacheKey, encoder) 99 cfg.addGenEncoder(valType, encoder) 100 } 101 102 func (cfg *frozenConfig) GetEncoder(cacheKey string) spi.ValEncoder { 103 encoder, found := cfg.extEncoders.Load(cacheKey) 104 if found { 105 return encoder.(spi.ValEncoder) 106 } 107 return nil 108 } 109 110 func (cfg *frozenConfig) getGenEncoder(cacheKey reflect.Type) spi.ValEncoder { 111 encoder, found := cfg.genEncoders.Load(cacheKey) 112 if found { 113 return encoder.(spi.ValEncoder) 114 } 115 return nil 116 } 117 118 func (cfg *frozenConfig) NewStream(writer io.Writer, buf []byte) spi.Stream { 119 switch cfg.protocol { 120 case ProtocolBinary: 121 return binary.NewStream(cfg, writer, buf) 122 case ProtocolCompact: 123 return compact.NewStream(cfg, writer, buf) 124 } 125 panic("unsupported protocol") 126 } 127 128 func (cfg *frozenConfig) NewIterator(reader io.Reader, buf []byte) spi.Iterator { 129 switch cfg.protocol { 130 case ProtocolBinary: 131 return binary.NewIterator(cfg, reader, buf) 132 case ProtocolCompact: 133 return compact.NewIterator(cfg, reader, buf) 134 } 135 panic("unsupported protocol") 136 } 137 138 func (cfg *frozenConfig) WillDecodeFromBuffer(samples ...interface{}) { 139 if !cfg.staticCodegen { 140 panic("this config is using dynamic codegen, can not do static codegen") 141 } 142 for _, sample := range samples { 143 cfg.staticDecoderOf(reflect.TypeOf(sample)) 144 } 145 } 146 147 func (cfg *frozenConfig) WillDecodeFromReader(samples ...interface{}) { 148 if !cfg.staticCodegen { 149 panic("this config is using dynamic codegen, can not do static codegen") 150 } 151 for _, sample := range samples { 152 cfg.staticDecoderOf(reflect.TypeOf(sample)) 153 } 154 } 155 156 func (cfg *frozenConfig) WillEncode(samples ...interface{}) { 157 if !cfg.staticCodegen { 158 panic("this config is using dynamic codegen, can not do static codegen") 159 } 160 for _, sample := range samples { 161 cfg.staticEncoderOf(reflect.TypeOf(sample)) 162 } 163 } 164 165 func (cfg *frozenConfig) decoderOf(valType reflect.Type) spi.ValDecoder { 166 if cfg.staticCodegen { 167 return cfg.staticDecoderOf(valType) 168 } 169 return reflection.DecoderOf(cfg.extension, valType) 170 } 171 172 func (cfg *frozenConfig) staticDecoderOf(valType reflect.Type) spi.ValDecoder { 173 iteratorType := reflect.TypeOf((*binary.Iterator)(nil)) 174 if cfg.protocol == ProtocolCompact { 175 iteratorType = reflect.TypeOf((*compact.Iterator)(nil)) 176 } 177 funcObj := generic.Expand(codegen.Decode, 178 "EXT", &codegen.Extension{Extension: cfg.extension}, 179 "ST", iteratorType, 180 "DT", valType) 181 f := funcObj.(func(interface{}, interface{})) 182 return &funcDecoder{f} 183 } 184 185 func (cfg *frozenConfig) encoderOf(valType reflect.Type) spi.ValEncoder { 186 if cfg.staticCodegen { 187 return cfg.staticEncoderOf(valType) 188 } 189 return reflection.EncoderOf(cfg.extension, valType) 190 } 191 192 func (cfg *frozenConfig) staticEncoderOf(valType reflect.Type) spi.ValEncoder { 193 streamType := reflect.TypeOf((*binary.Stream)(nil)) 194 if cfg.protocol == ProtocolCompact { 195 streamType = reflect.TypeOf((*compact.Stream)(nil)) 196 } 197 funcObj := generic.Expand(codegen.Encode, 198 "EXT", &codegen.Extension{Extension: cfg.extension}, 199 "ST", valType, 200 "DT", streamType) 201 f := funcObj.(func(interface{}, interface{})) 202 return &funcEncoder{f} 203 } 204 205 type funcDecoder struct { 206 f func(dst interface{}, src interface{}) 207 } 208 209 func (decoder *funcDecoder) Decode(val interface{}, iter spi.Iterator) { 210 decoder.f(val, iter) 211 } 212 213 type funcEncoder struct { 214 f func(dst interface{}, src interface{}) 215 } 216 217 func (encoder *funcEncoder) Encode(val interface{}, stream spi.Stream) { 218 encoder.f(stream, val) 219 } 220 221 func (encoder *funcEncoder) ThriftType() protocol.TType { 222 panic("funcEncoder is not composable") 223 } 224 225 func (cfg *frozenConfig) Unmarshal(buf []byte, val interface{}) error { 226 valType := reflect.TypeOf(val) 227 decoder := cfg.getGenDecoder(valType) 228 if decoder == nil { 229 decoder = cfg.decoderOf(valType) 230 cfg.addGenDecoder(valType, decoder) 231 } 232 if buf == nil { 233 return errors.New("empty input") 234 } 235 iter := cfg.NewIterator(nil, buf) 236 decoder.Decode(val, iter) 237 if iter.Error() != nil { 238 return iter.Error() 239 } 240 return nil 241 } 242 243 func (cfg *frozenConfig) Marshal(val interface{}) ([]byte, error) { 244 valType := reflect.TypeOf(val) 245 encoder := cfg.getGenEncoder(valType) 246 if encoder == nil { 247 encoder = cfg.encoderOf(valType) 248 cfg.addGenEncoder(valType, encoder) 249 } 250 stream := cfg.NewStream(nil, nil) 251 encoder.Encode(val, stream) 252 if stream.Error() != nil { 253 return nil, stream.Error() 254 } 255 buf := stream.Buffer() 256 return buf, nil 257 } 258 259 func (cfg *frozenConfig) NewDecoder(reader io.Reader, buf []byte) *Decoder { 260 return &Decoder{ 261 cfg: cfg, 262 iter: cfg.NewIterator(reader, buf), 263 } 264 } 265 266 func (cfg *frozenConfig) NewEncoder(writer io.Writer) *Encoder { 267 return &Encoder{ 268 cfg: cfg, 269 stream: cfg.NewStream(writer, nil), 270 } 271 } 272 273 func (cfg *frozenConfig) ToJSON(buf []byte) (string, error) { 274 msg, err := UnmarshalMessage(buf) 275 if err != nil { 276 return "", err 277 } 278 jsonEncoded, err := json.MarshalIndent(msg, "", " ") 279 if err != nil { 280 return "", err 281 } 282 return string(jsonEncoded), nil 283 } 284 285 func (cfg *frozenConfig) MarshalMessage(msg general.Message) ([]byte, error) { 286 return cfg.Marshal(msg) 287 } 288 289 func (cfg *frozenConfig) UnmarshalMessage(buf []byte) (general.Message, error) { 290 var msg general.Message 291 err := cfg.Unmarshal(buf, &msg) 292 return msg, err 293 }