github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/codec/encode_fields.go (about)

     1  package codec
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"reflect"
     7  	"sort"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    11  
    12  	"github.com/hoveychen/protoreflect/desc"
    13  )
    14  
    15  func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
    16  	if fd.IsMap() {
    17  		mp := val.(map[interface{}]interface{})
    18  		entryType := fd.GetMessageType()
    19  		keyType := entryType.FindFieldByNumber(1)
    20  		valType := entryType.FindFieldByNumber(2)
    21  		var entryBuffer Buffer
    22  		if cb.deterministic {
    23  			keys := make([]interface{}, 0, len(mp))
    24  			for k := range mp {
    25  				keys = append(keys, k)
    26  			}
    27  			sort.Sort(sortable(keys))
    28  			for _, k := range keys {
    29  				v := mp[k]
    30  				entryBuffer.Reset()
    31  				if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
    32  					return err
    33  				}
    34  				if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
    35  					return err
    36  				}
    37  				if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    38  					return err
    39  				}
    40  				if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
    41  					return err
    42  				}
    43  			}
    44  		} else {
    45  			for k, v := range mp {
    46  				entryBuffer.Reset()
    47  				if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
    48  					return err
    49  				}
    50  				if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
    51  					return err
    52  				}
    53  				if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    54  					return err
    55  				}
    56  				if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
    57  					return err
    58  				}
    59  			}
    60  		}
    61  		return nil
    62  	} else if fd.IsRepeated() {
    63  		sl := val.([]interface{})
    64  		wt, err := getWireType(fd.GetType())
    65  		if err != nil {
    66  			return err
    67  		}
    68  		if isPacked(fd) && len(sl) > 1 &&
    69  			(wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) {
    70  			// packed repeated field
    71  			var packedBuffer Buffer
    72  			for _, v := range sl {
    73  				if err := packedBuffer.encodeFieldValue(fd, v); err != nil {
    74  					return err
    75  				}
    76  			}
    77  			if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    78  				return err
    79  			}
    80  			return cb.EncodeRawBytes(packedBuffer.Bytes())
    81  		} else {
    82  			// non-packed repeated field
    83  			for _, v := range sl {
    84  				if err := cb.encodeFieldElement(fd, v); err != nil {
    85  					return err
    86  				}
    87  			}
    88  			return nil
    89  		}
    90  	} else {
    91  		return cb.encodeFieldElement(fd, val)
    92  	}
    93  }
    94  
    95  func isPacked(fd *desc.FieldDescriptor) bool {
    96  	opts := fd.AsFieldDescriptorProto().GetOptions()
    97  	// if set, use that value
    98  	if opts != nil && opts.Packed != nil {
    99  		return opts.GetPacked()
   100  	}
   101  	// if unset: proto2 defaults to false, proto3 to true
   102  	return fd.GetFile().IsProto3()
   103  }
   104  
   105  // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64),
   106  // bools, or strings.
   107  type sortable []interface{}
   108  
   109  func (s sortable) Len() int {
   110  	return len(s)
   111  }
   112  
   113  func (s sortable) Less(i, j int) bool {
   114  	vi := s[i]
   115  	vj := s[j]
   116  	switch reflect.TypeOf(vi).Kind() {
   117  	case reflect.Int32:
   118  		return vi.(int32) < vj.(int32)
   119  	case reflect.Int64:
   120  		return vi.(int64) < vj.(int64)
   121  	case reflect.Uint32:
   122  		return vi.(uint32) < vj.(uint32)
   123  	case reflect.Uint64:
   124  		return vi.(uint64) < vj.(uint64)
   125  	case reflect.String:
   126  		return vi.(string) < vj.(string)
   127  	case reflect.Bool:
   128  		return !vi.(bool) && vj.(bool)
   129  	default:
   130  		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
   131  	}
   132  }
   133  
   134  func (s sortable) Swap(i, j int) {
   135  	s[i], s[j] = s[j], s[i]
   136  }
   137  
   138  func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error {
   139  	wt, err := getWireType(fd.GetType())
   140  	if err != nil {
   141  		return err
   142  	}
   143  	if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil {
   144  		return err
   145  	}
   146  	if err := b.encodeFieldValue(fd, val); err != nil {
   147  		return err
   148  	}
   149  	if wt == proto.WireStartGroup {
   150  		return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup)
   151  	}
   152  	return nil
   153  }
   154  
   155  func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
   156  	switch fd.GetType() {
   157  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
   158  		v := val.(bool)
   159  		if v {
   160  			return b.EncodeVarint(1)
   161  		}
   162  		return b.EncodeVarint(0)
   163  
   164  	case descriptor.FieldDescriptorProto_TYPE_ENUM,
   165  		descriptor.FieldDescriptorProto_TYPE_INT32:
   166  		v := val.(int32)
   167  		return b.EncodeVarint(uint64(v))
   168  
   169  	case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
   170  		v := val.(int32)
   171  		return b.EncodeFixed32(uint64(v))
   172  
   173  	case descriptor.FieldDescriptorProto_TYPE_SINT32:
   174  		v := val.(int32)
   175  		return b.EncodeVarint(EncodeZigZag32(v))
   176  
   177  	case descriptor.FieldDescriptorProto_TYPE_UINT32:
   178  		v := val.(uint32)
   179  		return b.EncodeVarint(uint64(v))
   180  
   181  	case descriptor.FieldDescriptorProto_TYPE_FIXED32:
   182  		v := val.(uint32)
   183  		return b.EncodeFixed32(uint64(v))
   184  
   185  	case descriptor.FieldDescriptorProto_TYPE_INT64:
   186  		v := val.(int64)
   187  		return b.EncodeVarint(uint64(v))
   188  
   189  	case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
   190  		v := val.(int64)
   191  		return b.EncodeFixed64(uint64(v))
   192  
   193  	case descriptor.FieldDescriptorProto_TYPE_SINT64:
   194  		v := val.(int64)
   195  		return b.EncodeVarint(EncodeZigZag64(v))
   196  
   197  	case descriptor.FieldDescriptorProto_TYPE_UINT64:
   198  		v := val.(uint64)
   199  		return b.EncodeVarint(v)
   200  
   201  	case descriptor.FieldDescriptorProto_TYPE_FIXED64:
   202  		v := val.(uint64)
   203  		return b.EncodeFixed64(v)
   204  
   205  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
   206  		v := val.(float64)
   207  		return b.EncodeFixed64(math.Float64bits(v))
   208  
   209  	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
   210  		v := val.(float32)
   211  		return b.EncodeFixed32(uint64(math.Float32bits(v)))
   212  
   213  	case descriptor.FieldDescriptorProto_TYPE_BYTES:
   214  		v := val.([]byte)
   215  		return b.EncodeRawBytes(v)
   216  
   217  	case descriptor.FieldDescriptorProto_TYPE_STRING:
   218  		v := val.(string)
   219  		return b.EncodeRawBytes(([]byte)(v))
   220  
   221  	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
   222  		return b.EncodeDelimitedMessage(val.(proto.Message))
   223  
   224  	case descriptor.FieldDescriptorProto_TYPE_GROUP:
   225  		// just append the nested message to this buffer
   226  		return b.EncodeMessage(val.(proto.Message))
   227  		// whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag
   228  
   229  	default:
   230  		return fmt.Errorf("unrecognized field type: %v", fd.GetType())
   231  	}
   232  }
   233  
   234  func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) {
   235  	switch t {
   236  	case descriptor.FieldDescriptorProto_TYPE_ENUM,
   237  		descriptor.FieldDescriptorProto_TYPE_BOOL,
   238  		descriptor.FieldDescriptorProto_TYPE_INT32,
   239  		descriptor.FieldDescriptorProto_TYPE_SINT32,
   240  		descriptor.FieldDescriptorProto_TYPE_UINT32,
   241  		descriptor.FieldDescriptorProto_TYPE_INT64,
   242  		descriptor.FieldDescriptorProto_TYPE_SINT64,
   243  		descriptor.FieldDescriptorProto_TYPE_UINT64:
   244  		return proto.WireVarint, nil
   245  
   246  	case descriptor.FieldDescriptorProto_TYPE_FIXED32,
   247  		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
   248  		descriptor.FieldDescriptorProto_TYPE_FLOAT:
   249  		return proto.WireFixed32, nil
   250  
   251  	case descriptor.FieldDescriptorProto_TYPE_FIXED64,
   252  		descriptor.FieldDescriptorProto_TYPE_SFIXED64,
   253  		descriptor.FieldDescriptorProto_TYPE_DOUBLE:
   254  		return proto.WireFixed64, nil
   255  
   256  	case descriptor.FieldDescriptorProto_TYPE_BYTES,
   257  		descriptor.FieldDescriptorProto_TYPE_STRING,
   258  		descriptor.FieldDescriptorProto_TYPE_MESSAGE:
   259  		return proto.WireBytes, nil
   260  
   261  	case descriptor.FieldDescriptorProto_TYPE_GROUP:
   262  		return proto.WireStartGroup, nil
   263  
   264  	default:
   265  		return 0, ErrBadWireType
   266  	}
   267  }