github.com/cloudwego/kitex@v0.9.0/pkg/protocol/bthrift/unknown.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package bthrift
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"reflect"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  	"github.com/cloudwego/thriftgo/generator/golang/extension/unknown"
    26  )
    27  
// UnknownField is used to describe an unknown field.
//
// The fields are filled in by readUnknownField and consumed by
// unknownFieldLength and writeUnknownField as follows:
//   - Name and ID are the field name and the field id. For the elements of a
//     set or list and for the keys and values of a map, readUnknownField
//     leaves Name empty and stores the element (or map entry) index in ID.
//   - Type is the thrift.TType of the value, stored as an int.
//   - KeyType is the key type of a map value; it is only set for maps.
//   - ValType is the element type of a set or list, or the value type of a
//     map.
//   - Value holds the decoded value: bool, int8, int16, int32, int64, float64
//     or string for scalar types, and []UnknownField for set, list, map and
//     struct. A map is flattened into alternating key and value entries
//     (key0, value0, key1, value1, ...); a struct holds its fields in the
//     order they were read.
type UnknownField struct {
	Name    string
	ID      int16
	Type    int
	KeyType int
	ValType int
	Value   interface{}
}
    37  
    38  // GetUnknownFields deserialize unknownFields stored in v to a list of *UnknownFields.
    39  func GetUnknownFields(v interface{}) (fields []UnknownField, err error) {
    40  	var buf []byte
    41  	rv := reflect.ValueOf(v)
    42  	if rv.Kind() == reflect.Ptr && !rv.IsNil() {
    43  		rv = rv.Elem()
    44  	}
    45  	if rv.Kind() != reflect.Struct {
    46  		return nil, fmt.Errorf("%T is not a struct type", v)
    47  	}
    48  	if unknownField := rv.FieldByName("_unknownFields"); !unknownField.IsValid() {
    49  		return nil, fmt.Errorf("%T has no field named '_unknownFields'", v)
    50  	} else {
    51  		buf = unknownField.Bytes()
    52  	}
    53  	return ConvertUnknownFields(buf)
    54  }
    55  
    56  // ConvertUnknownFields converts buf to deserialized unknown fields.
    57  func ConvertUnknownFields(buf unknown.Fields) (fields []UnknownField, err error) {
    58  	if len(buf) == 0 {
    59  		return nil, errors.New("_unknownFields is empty")
    60  	}
    61  	var offset int
    62  	var l int
    63  	var name string
    64  	var fieldTypeId thrift.TType
    65  	var fieldId int16
    66  	var f UnknownField
    67  	for {
    68  		if offset == len(buf) {
    69  			return
    70  		}
    71  		name, fieldTypeId, fieldId, l, err = Binary.ReadFieldBegin(buf[offset:])
    72  		offset += l
    73  		if err != nil {
    74  			return nil, fmt.Errorf("read field %d begin error: %v", fieldId, err)
    75  		}
    76  		l, err = readUnknownField(&f, buf[offset:], name, fieldTypeId, fieldId)
    77  		offset += l
    78  		if err != nil {
    79  			return nil, fmt.Errorf("read unknown field %d error: %v", fieldId, err)
    80  		}
    81  		fields = append(fields, f)
    82  	}
    83  }
    84  
    85  func readUnknownField(f *UnknownField, buf []byte, name string, fieldType thrift.TType, id int16) (length int, err error) {
    86  	var size int
    87  	var l int
    88  	f.Name = name
    89  	f.ID = id
    90  	f.Type = int(fieldType)
    91  	switch fieldType {
    92  	case thrift.BOOL:
    93  		f.Value, l, err = Binary.ReadBool(buf[length:])
    94  		length += l
    95  	case thrift.BYTE:
    96  		f.Value, l, err = Binary.ReadByte(buf[length:])
    97  		length += l
    98  	case thrift.I16:
    99  		f.Value, l, err = Binary.ReadI16(buf[length:])
   100  		length += l
   101  	case thrift.I32:
   102  		f.Value, l, err = Binary.ReadI32(buf[length:])
   103  		length += l
   104  	case thrift.I64:
   105  		f.Value, l, err = Binary.ReadI64(buf[length:])
   106  		length += l
   107  	case thrift.DOUBLE:
   108  		f.Value, l, err = Binary.ReadDouble(buf[length:])
   109  		length += l
   110  	case thrift.STRING:
   111  		f.Value, l, err = Binary.ReadString(buf[length:])
   112  		length += l
   113  	case thrift.SET:
   114  		var ttype thrift.TType
   115  		ttype, size, l, err = Binary.ReadSetBegin(buf[length:])
   116  		length += l
   117  		if err != nil {
   118  			return length, fmt.Errorf("read set begin error: %w", err)
   119  		}
   120  		f.ValType = int(ttype)
   121  		set := make([]UnknownField, size)
   122  		for i := 0; i < size; i++ {
   123  			l, err2 := readUnknownField(&set[i], buf[length:], "", thrift.TType(f.ValType), int16(i))
   124  			length += l
   125  			if err2 != nil {
   126  				return length, fmt.Errorf("read set elem error: %w", err2)
   127  			}
   128  		}
   129  		l, err = Binary.ReadSetEnd(buf[length:])
   130  		length += l
   131  		if err != nil {
   132  			return length, fmt.Errorf("read set end error: %w", err)
   133  		}
   134  		f.Value = set
   135  	case thrift.LIST:
   136  		var ttype thrift.TType
   137  		ttype, size, l, err = Binary.ReadListBegin(buf[length:])
   138  		length += l
   139  		if err != nil {
   140  			return length, fmt.Errorf("read list begin error: %w", err)
   141  		}
   142  		f.ValType = int(ttype)
   143  		list := make([]UnknownField, size)
   144  		for i := 0; i < size; i++ {
   145  			l, err2 := readUnknownField(&list[i], buf[length:], "", thrift.TType(f.ValType), int16(i))
   146  			length += l
   147  			if err2 != nil {
   148  				return length, fmt.Errorf("read list elem error: %w", err2)
   149  			}
   150  		}
   151  		l, err = Binary.ReadListEnd(buf[length:])
   152  		length += l
   153  		if err != nil {
   154  			return length, fmt.Errorf("read list end error: %w", err)
   155  		}
   156  		f.Value = list
   157  	case thrift.MAP:
   158  		var kttype, vttype thrift.TType
   159  		kttype, vttype, size, l, err = Binary.ReadMapBegin(buf[length:])
   160  		length += l
   161  		if err != nil {
   162  			return length, fmt.Errorf("read map begin error: %w", err)
   163  		}
   164  		f.KeyType = int(kttype)
   165  		f.ValType = int(vttype)
   166  		flatMap := make([]UnknownField, size*2)
   167  		for i := 0; i < size; i++ {
   168  			l, err2 := readUnknownField(&flatMap[2*i], buf[length:], "", thrift.TType(f.KeyType), int16(i))
   169  			length += l
   170  			if err2 != nil {
   171  				return length, fmt.Errorf("read map key error: %w", err2)
   172  			}
   173  			l, err2 = readUnknownField(&flatMap[2*i+1], buf[length:], "", thrift.TType(f.ValType), int16(i))
   174  			length += l
   175  			if err2 != nil {
   176  				return length, fmt.Errorf("read map value error: %w", err2)
   177  			}
   178  		}
   179  		l, err = Binary.ReadMapEnd(buf[length:])
   180  		length += l
   181  		if err != nil {
   182  			return length, fmt.Errorf("read map end error: %w", err)
   183  		}
   184  		f.Value = flatMap
   185  	case thrift.STRUCT:
   186  		_, l, err = Binary.ReadStructBegin(buf[length:])
   187  		length += l
   188  		if err != nil {
   189  			return length, fmt.Errorf("read struct begin error: %w", err)
   190  		}
   191  		var field UnknownField
   192  		var fields []UnknownField
   193  		for {
   194  			name, fieldTypeID, fieldID, l, err := Binary.ReadFieldBegin(buf[length:])
   195  			length += l
   196  			if err != nil {
   197  				return length, fmt.Errorf("read field begin error: %w", err)
   198  			}
   199  			if fieldTypeID == thrift.STOP {
   200  				break
   201  			}
   202  			l, err = readUnknownField(&field, buf[length:], name, fieldTypeID, fieldID)
   203  			length += l
   204  			if err != nil {
   205  				return length, fmt.Errorf("read struct field error: %w", err)
   206  			}
   207  			l, err = Binary.ReadFieldEnd(buf[length:])
   208  			length += l
   209  			if err != nil {
   210  				return length, fmt.Errorf("read field end error: %w", err)
   211  			}
   212  			fields = append(fields, field)
   213  		}
   214  		l, err = Binary.ReadStructEnd(buf[length:])
   215  		length += l
   216  		if err != nil {
   217  			return length, fmt.Errorf("read struct end error: %w", err)
   218  		}
   219  		f.Value = fields
   220  	default:
   221  		return length, fmt.Errorf("unknown data type %d", f.Type)
   222  	}
   223  	if err != nil {
   224  		return length, err
   225  	}
   226  	return
   227  }
   228  
   229  // UnknownFieldsLength returns the length of fs.
   230  func UnknownFieldsLength(fs []UnknownField) (int, error) {
   231  	l := 0
   232  	for _, f := range fs {
   233  		l += Binary.FieldBeginLength(f.Name, thrift.TType(f.Type), f.ID)
   234  		ll, err := unknownFieldLength(&f)
   235  		l += ll
   236  		if err != nil {
   237  			return l, err
   238  		}
   239  		l += Binary.FieldEndLength()
   240  	}
   241  	return l, nil
   242  }
   243  
   244  func unknownFieldLength(f *UnknownField) (length int, err error) {
   245  	// use constants to avoid some type assert
   246  	switch f.Type {
   247  	case unknown.TBool:
   248  		length += Binary.BoolLength(false)
   249  	case unknown.TByte:
   250  		length += Binary.ByteLength(0)
   251  	case unknown.TDouble:
   252  		length += Binary.DoubleLength(0)
   253  	case unknown.TI16:
   254  		length += Binary.I16Length(0)
   255  	case unknown.TI32:
   256  		length += Binary.I32Length(0)
   257  	case unknown.TI64:
   258  		length += Binary.I64Length(0)
   259  	case unknown.TString:
   260  		length += Binary.StringLength(f.Value.(string))
   261  	case unknown.TSet:
   262  		vs := f.Value.([]UnknownField)
   263  		length += Binary.SetBeginLength(thrift.TType(f.ValType), len(vs))
   264  		for _, v := range vs {
   265  			l, err := unknownFieldLength(&v)
   266  			length += l
   267  			if err != nil {
   268  				return length, err
   269  			}
   270  		}
   271  		length += Binary.SetEndLength()
   272  	case unknown.TList:
   273  		vs := f.Value.([]UnknownField)
   274  		length += Binary.ListBeginLength(thrift.TType(f.ValType), len(vs))
   275  		for _, v := range vs {
   276  			l, err := unknownFieldLength(&v)
   277  			length += l
   278  			if err != nil {
   279  				return length, err
   280  			}
   281  		}
   282  		length += Binary.ListEndLength()
   283  	case unknown.TMap:
   284  		kvs := f.Value.([]UnknownField)
   285  		length += Binary.MapBeginLength(thrift.TType(f.KeyType), thrift.TType(f.ValType), len(kvs)/2)
   286  		for i := 0; i < len(kvs); i += 2 {
   287  			l, err := unknownFieldLength(&kvs[i])
   288  			length += l
   289  			if err != nil {
   290  				return length, err
   291  			}
   292  			l, err = unknownFieldLength(&kvs[i+1])
   293  			length += l
   294  			if err != nil {
   295  				return length, err
   296  			}
   297  		}
   298  		length += Binary.MapEndLength()
   299  	case unknown.TStruct:
   300  		fs := f.Value.([]UnknownField)
   301  		length += Binary.StructBeginLength(f.Name)
   302  		l, err := UnknownFieldsLength(fs)
   303  		length += l
   304  		if err != nil {
   305  			return length, err
   306  		}
   307  		length += Binary.FieldStopLength()
   308  		length += Binary.StructEndLength()
   309  	default:
   310  		return length, fmt.Errorf("unknown data type %d", f.Type)
   311  	}
   312  	return
   313  }
   314  
   315  // WriteUnknownFields writes fs into buf, and return written offset of the buf.
   316  func WriteUnknownFields(buf []byte, fs []UnknownField) (offset int, err error) {
   317  	for _, f := range fs {
   318  		offset += Binary.WriteFieldBegin(buf[offset:], f.Name, thrift.TType(f.Type), f.ID)
   319  		l, err := writeUnknownField(buf[offset:], &f)
   320  		offset += l
   321  		if err != nil {
   322  			return offset, err
   323  		}
   324  		offset += Binary.WriteFieldEnd(buf[offset:])
   325  	}
   326  	return offset, nil
   327  }
   328  
   329  func writeUnknownField(buf []byte, f *UnknownField) (offset int, err error) {
   330  	switch f.Type {
   331  	case unknown.TBool:
   332  		offset += Binary.WriteBool(buf, f.Value.(bool))
   333  	case unknown.TByte:
   334  		offset += Binary.WriteByte(buf, f.Value.(int8))
   335  	case unknown.TDouble:
   336  		offset += Binary.WriteDouble(buf, f.Value.(float64))
   337  	case unknown.TI16:
   338  		offset += Binary.WriteI16(buf, f.Value.(int16))
   339  	case unknown.TI32:
   340  		offset += Binary.WriteI32(buf, f.Value.(int32))
   341  	case unknown.TI64:
   342  		offset += Binary.WriteI64(buf, f.Value.(int64))
   343  	case unknown.TString:
   344  		offset += Binary.WriteString(buf, f.Value.(string))
   345  	case unknown.TSet:
   346  		vs := f.Value.([]UnknownField)
   347  		offset += Binary.WriteSetBegin(buf, thrift.TType(f.ValType), len(vs))
   348  		for _, v := range vs {
   349  			l, err := writeUnknownField(buf[offset:], &v)
   350  			offset += l
   351  			if err != nil {
   352  				return offset, err
   353  			}
   354  		}
   355  		offset += Binary.WriteSetEnd(buf[offset:])
   356  	case unknown.TList:
   357  		vs := f.Value.([]UnknownField)
   358  		offset += Binary.WriteListBegin(buf, thrift.TType(f.ValType), len(vs))
   359  		for _, v := range vs {
   360  			l, err := writeUnknownField(buf[offset:], &v)
   361  			offset += l
   362  			if err != nil {
   363  				return offset, err
   364  			}
   365  		}
   366  		offset += Binary.WriteListEnd(buf[offset:])
   367  	case unknown.TMap:
   368  		kvs := f.Value.([]UnknownField)
   369  		offset += Binary.WriteMapBegin(buf, thrift.TType(f.KeyType), thrift.TType(f.ValType), len(kvs)/2)
   370  		for i := 0; i < len(kvs); i += 2 {
   371  			l, err := writeUnknownField(buf[offset:], &kvs[i])
   372  			offset += l
   373  			if err != nil {
   374  				return offset, err
   375  			}
   376  			l, err = writeUnknownField(buf[offset:], &kvs[i+1])
   377  			offset += l
   378  			if err != nil {
   379  				return offset, err
   380  			}
   381  		}
   382  		offset += Binary.WriteMapEnd(buf[offset:])
   383  	case unknown.TStruct:
   384  		fs := f.Value.([]UnknownField)
   385  		offset += Binary.WriteStructBegin(buf, f.Name)
   386  		l, err := WriteUnknownFields(buf[offset:], fs)
   387  		offset += l
   388  		if err != nil {
   389  			return offset, err
   390  		}
   391  		offset += Binary.WriteFieldStop(buf[offset:])
   392  		offset += Binary.WriteStructEnd(buf[offset:])
   393  	default:
   394  		return offset, fmt.Errorf("unknown data type %d", f.Type)
   395  	}
   396  	return
   397  }