github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/read.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  	"github.com/jhump/protoreflect/desc"
    26  
    27  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    28  	"github.com/cloudwego/kitex/pkg/generic/proto"
    29  )
    30  
    31  var emptyPbDsc = &desc.MessageDescriptor{}
    32  
    33  type readerOption struct {
    34  	// result will be encode to json, so map[interface{}]interface{} will not be valid
    35  	// need use map[string]interface{} instead
    36  	forJSON bool
    37  	// return exception as error
    38  	throwException bool
    39  	// read http response
    40  	http                bool
    41  	binaryWithBase64    bool
    42  	binaryWithByteSlice bool
    43  	// describe struct of current level
    44  	pbDsc proto.MessageDescriptor
    45  }
    46  
    47  type reader func(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error)
    48  
    49  type fieldSetter func(field *descriptor.FieldDescriptor, val interface{}) error
    50  
    51  func getMapFieldSetter(st map[string]interface{}) fieldSetter {
    52  	return func(field *descriptor.FieldDescriptor, val interface{}) error {
    53  		st[field.FieldName()] = val
    54  		return nil
    55  	}
    56  }
    57  
    58  func getPbFieldSetter(st proto.Message) fieldSetter {
    59  	return func(field *descriptor.FieldDescriptor, val interface{}) error {
    60  		return st.TrySetFieldByNumber(int(field.ID), val)
    61  	}
    62  }
    63  
    64  func nextReader(tt descriptor.Type, t *descriptor.TypeDescriptor, opt *readerOption) (reader, error) {
    65  	if err := assertType(tt, t.Type); err != nil {
    66  		return nil, err
    67  	}
    68  	switch tt {
    69  	case descriptor.BOOL:
    70  		return readBool, nil
    71  	case descriptor.BYTE:
    72  		return readByte, nil
    73  	case descriptor.I16:
    74  		return readInt16, nil
    75  	case descriptor.I32:
    76  		return readInt32, nil
    77  	case descriptor.I64:
    78  		return readInt64, nil
    79  	case descriptor.STRING:
    80  		if t.Name == "binary" {
    81  			if opt.binaryWithByteSlice {
    82  				return readBinary, nil
    83  			} else if opt.binaryWithBase64 {
    84  				return readBase64Binary, nil
    85  			}
    86  		}
    87  		return readString, nil
    88  	case descriptor.DOUBLE:
    89  		return readDouble, nil
    90  	case descriptor.LIST:
    91  		return readList, nil
    92  	case descriptor.SET:
    93  		return readList, nil
    94  	case descriptor.MAP:
    95  		return readMap, nil
    96  	case descriptor.STRUCT:
    97  		return readStruct, nil
    98  	case descriptor.VOID:
    99  		return readVoid, nil
   100  	case descriptor.JSON:
   101  		return readStruct, nil
   102  	default:
   103  		return nil, fmt.Errorf("unsupported type: %d", tt)
   104  	}
   105  }
   106  
   107  func skipStructReader(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   108  	structName, err := in.ReadStructBegin()
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	var v interface{}
   113  	for {
   114  		fieldName, fieldType, fieldID, err := in.ReadFieldBegin()
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  		if fieldType == thrift.STOP {
   119  			break
   120  		}
   121  		field, ok := t.Struct.FieldsByID[int32(fieldID)]
   122  		if !ok {
   123  			// just ignore the missing field, maybe server update its idls
   124  			if err := in.Skip(fieldType); err != nil {
   125  				return nil, err
   126  			}
   127  		} else {
   128  			_fieldType := descriptor.FromThriftTType(fieldType)
   129  			reader, err := nextReader(_fieldType, field.Type, opt)
   130  			if err != nil {
   131  				return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", structName, fieldName, fieldID, err)
   132  			}
   133  			if field.IsException && opt != nil && opt.throwException {
   134  				if v, err = reader(ctx, in, field.Type, opt); err != nil {
   135  					return nil, err
   136  				}
   137  				// return exception as error
   138  				return nil, fmt.Errorf("%#v", v)
   139  			}
   140  			if opt != nil && opt.http {
   141  				// use http response reader when http generic call
   142  				// only support struct response method, return error when use base type response
   143  				reader = readHTTPResponse
   144  			}
   145  			if v, err = reader(ctx, in, field.Type, opt); err != nil {
   146  				return nil, fmt.Errorf("reader of %s/%s/%d error %w", structName, fieldName, fieldID, err)
   147  			}
   148  		}
   149  		if err := in.ReadFieldEnd(); err != nil {
   150  			return nil, err
   151  		}
   152  	}
   153  
   154  	return v, in.ReadStructEnd()
   155  }
   156  
   157  func readVoid(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   158  	_, err := readStruct(ctx, in, t, opt)
   159  	return descriptor.Void{}, err
   160  }
   161  
   162  func readDouble(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   163  	return in.ReadDouble()
   164  }
   165  
   166  func readBool(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   167  	return in.ReadBool()
   168  }
   169  
   170  func readByte(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   171  	res, err := in.ReadByte()
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	if opt.pbDsc != nil {
   176  		return int32(res), nil
   177  	}
   178  	return res, nil
   179  }
   180  
   181  func readInt16(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   182  	res, err := in.ReadI16()
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	if opt.pbDsc != nil {
   187  		return int32(res), nil
   188  	}
   189  	return res, nil
   190  }
   191  
   192  func readInt32(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   193  	return in.ReadI32()
   194  }
   195  
   196  func readInt64(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   197  	return in.ReadI64()
   198  }
   199  
   200  func readString(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   201  	return in.ReadString()
   202  }
   203  
   204  func readBinary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   205  	bytes, err := in.ReadBinary()
   206  	if err != nil {
   207  		return "", err
   208  	}
   209  	return bytes, nil
   210  }
   211  
   212  func readBase64Binary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   213  	bytes, err := in.ReadBinary()
   214  	if err != nil {
   215  		return "", err
   216  	}
   217  	return base64.StdEncoding.EncodeToString(bytes), nil
   218  }
   219  
   220  func readList(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   221  	elemType, length, err := in.ReadListBegin()
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  	_elemType := descriptor.FromThriftTType(elemType)
   226  	reader, err := nextReader(_elemType, t.Elem, opt)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  	l := make([]interface{}, 0, length)
   231  	for i := 0; i < length; i++ {
   232  		item, err := reader(ctx, in, t.Elem, opt)
   233  		if err != nil {
   234  			return nil, err
   235  		}
   236  		l = append(l, item)
   237  	}
   238  	return l, in.ReadListEnd()
   239  }
   240  
   241  func readMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   242  	if opt != nil && opt.forJSON {
   243  		return readStringMap(ctx, in, t, opt)
   244  	}
   245  	return readInterfaceMap(ctx, in, t, opt)
   246  }
   247  
   248  func readInterfaceMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   249  	keyType, elemType, length, err := in.ReadMapBegin()
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	m := make(map[interface{}]interface{}, length)
   254  	if length == 0 {
   255  		return m, nil
   256  	}
   257  	_keyType := descriptor.FromThriftTType(keyType)
   258  	keyReader, err := nextReader(_keyType, t.Key, opt)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	_elemType := descriptor.FromThriftTType(elemType)
   263  	elemReader, err := nextReader(_elemType, t.Elem, opt)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	for i := 0; i < length; i++ {
   268  		nest := unnestPb(opt, 1)
   269  		key, err := keyReader(ctx, in, t.Key, opt)
   270  		if err != nil {
   271  			return nil, err
   272  		}
   273  		nest()
   274  		nest = unnestPb(opt, 2)
   275  		elem, err := elemReader(ctx, in, t.Elem, opt)
   276  		if err != nil {
   277  			return nil, err
   278  		}
   279  		nest()
   280  		m[key] = elem
   281  	}
   282  	return m, in.ReadMapEnd()
   283  }
   284  
   285  func readStringMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   286  	keyType, elemType, length, err := in.ReadMapBegin()
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  	m := make(map[string]interface{}, length)
   291  	if length == 0 {
   292  		return m, nil
   293  	}
   294  	_keyType := descriptor.FromThriftTType(keyType)
   295  	keyReader, err := nextReader(_keyType, t.Key, opt)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	_elemType := descriptor.FromThriftTType(elemType)
   300  	elemReader, err := nextReader(_elemType, t.Elem, opt)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  	for i := 0; i < length; i++ {
   305  		key, err := keyReader(ctx, in, t.Key, opt)
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  		elem, err := elemReader(ctx, in, t.Elem, opt)
   310  		if err != nil {
   311  			return nil, err
   312  		}
   313  		m[buildinTypeIntoString(key)] = elem
   314  	}
   315  	return m, in.ReadMapEnd()
   316  }
   317  
   318  func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   319  	var fs fieldSetter
   320  	var st interface{}
   321  	if opt == nil || opt.pbDsc == nil {
   322  		if opt == nil {
   323  			opt = &readerOption{}
   324  		}
   325  		holder := map[string]interface{}{}
   326  		fs = getMapFieldSetter(holder)
   327  		st = holder
   328  	} else {
   329  		holder := proto.NewMessage(opt.pbDsc)
   330  		fs = getPbFieldSetter(holder)
   331  		st = holder
   332  	}
   333  
   334  	var err error
   335  	// set default value
   336  	// void is nil struct
   337  	// default value with struct NOT SUPPORT pb.
   338  	if t.Struct != nil {
   339  		for _, field := range t.Struct.DefaultFields {
   340  			val := field.DefaultValue
   341  			if field.ValueMapping != nil {
   342  				if val, err = field.ValueMapping.Response(ctx, val, field); err != nil {
   343  					return nil, err
   344  				}
   345  			}
   346  			if err := fs(field, val); err != nil {
   347  				return nil, err
   348  			}
   349  		}
   350  	}
   351  	_, err = in.ReadStructBegin()
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  	readFields := map[int32]struct{}{}
   356  	for {
   357  		_, fieldType, fieldID, err := in.ReadFieldBegin()
   358  		if err != nil {
   359  			return nil, err
   360  		}
   361  		if fieldType == thrift.STOP {
   362  			if err := in.ReadFieldEnd(); err != nil {
   363  				return nil, err
   364  			}
   365  			// check required
   366  			// void is nil struct
   367  			if t.Struct != nil {
   368  				if err := t.Struct.CheckRequired(readFields); err != nil {
   369  					return nil, err
   370  				}
   371  			}
   372  			return st, in.ReadStructEnd()
   373  		}
   374  		field, ok := t.Struct.FieldsByID[int32(fieldID)]
   375  		if !ok {
   376  			// just ignore the missing field, maybe server update its idls
   377  			if err := in.Skip(fieldType); err != nil {
   378  				return nil, err
   379  			}
   380  		} else {
   381  			nest := unnestPb(opt, field.ID)
   382  			_fieldType := descriptor.FromThriftTType(fieldType)
   383  			reader, err := nextReader(_fieldType, field.Type, opt)
   384  			if err != nil {
   385  				return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err)
   386  			}
   387  			val, err := reader(ctx, in, field.Type, opt)
   388  			if err != nil {
   389  				return nil, fmt.Errorf("reader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err)
   390  			}
   391  			if field.ValueMapping != nil {
   392  				if val, err = field.ValueMapping.Response(ctx, val, field); err != nil {
   393  					return nil, err
   394  				}
   395  			}
   396  			nest()
   397  
   398  			if err := fs(field, val); err != nil {
   399  				return nil, err
   400  			}
   401  		}
   402  		if err := in.ReadFieldEnd(); err != nil {
   403  			return nil, err
   404  		}
   405  		readFields[int32(fieldID)] = struct{}{}
   406  	}
   407  }
   408  
   409  func readHTTPResponse(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) {
   410  	var resp *descriptor.HTTPResponse
   411  	if opt == nil || opt.pbDsc == nil {
   412  		if opt == nil {
   413  			opt = &readerOption{}
   414  		}
   415  		resp = descriptor.NewHTTPResponse()
   416  	} else {
   417  		resp = descriptor.NewHTTPPbResponse(proto.NewMessage(opt.pbDsc))
   418  	}
   419  
   420  	var err error
   421  	// set default value
   422  	// default value with struct NOT SUPPORT pb.
   423  	for _, field := range t.Struct.DefaultFields {
   424  		val := field.DefaultValue
   425  		if field.ValueMapping != nil {
   426  			if val, err = field.ValueMapping.Response(ctx, val, field); err != nil {
   427  				return nil, err
   428  			}
   429  		}
   430  		if err = field.HTTPMapping.Response(ctx, resp, field, val); err != nil {
   431  			return nil, err
   432  		}
   433  	}
   434  	_, err = in.ReadStructBegin()
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  	readFields := map[int32]struct{}{}
   439  	for {
   440  		_, fieldType, fieldID, err := in.ReadFieldBegin()
   441  		if err != nil {
   442  			return nil, err
   443  		}
   444  		if fieldType == thrift.STOP {
   445  			if err := in.ReadFieldEnd(); err != nil {
   446  				return nil, err
   447  			}
   448  			// check required
   449  			if err := t.Struct.CheckRequired(readFields); err != nil {
   450  				return nil, err
   451  			}
   452  			return resp, in.ReadStructEnd()
   453  		}
   454  		field, ok := t.Struct.FieldsByID[int32(fieldID)]
   455  		if !ok {
   456  			// just ignore the missing field, maybe server update its idls
   457  			if err := in.Skip(fieldType); err != nil {
   458  				return nil, err
   459  			}
   460  		} else {
   461  			// Replace pb descriptor with field type
   462  			nest := unnestPb(opt, field.ID)
   463  
   464  			// check required
   465  			_fieldType := descriptor.FromThriftTType(fieldType)
   466  			reader, err := nextReader(_fieldType, field.Type, opt)
   467  			if err != nil {
   468  				return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err)
   469  			}
   470  			val, err := reader(ctx, in, field.Type, opt)
   471  			if err != nil {
   472  				return nil, fmt.Errorf("reader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err)
   473  			}
   474  			if field.ValueMapping != nil {
   475  				if val, err = field.ValueMapping.Response(ctx, val, field); err != nil {
   476  					return nil, err
   477  				}
   478  			}
   479  			nest()
   480  			if err = field.HTTPMapping.Response(ctx, resp, field, val); err != nil {
   481  				return nil, err
   482  			}
   483  		}
   484  		if err := in.ReadFieldEnd(); err != nil {
   485  			return nil, err
   486  		}
   487  		readFields[int32(fieldID)] = struct{}{}
   488  	}
   489  }
   490  
   491  func unnestPb(opt *readerOption, fieldId int32) func() {
   492  	pbDsc := opt.pbDsc
   493  	if pbDsc != nil {
   494  		fd := opt.pbDsc.FindFieldByNumber(fieldId)
   495  		if fd != nil && fd.GetMessageType() != nil {
   496  			opt.pbDsc = fd.GetMessageType()
   497  		} else {
   498  			opt.pbDsc = emptyPbDsc
   499  		}
   500  	}
   501  	return func() {
   502  		opt.pbDsc = pbDsc
   503  	}
   504  }