github.com/segmentio/encoding@v0.4.0/proto/proto.go (about) 1 package proto 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync/atomic" 7 "unsafe" 8 ) 9 10 func Size(v interface{}) int { 11 t, p := inspect(v) 12 c := cachedCodecOf(t) 13 return c.size(p, inline|toplevel) 14 } 15 16 func Marshal(v interface{}) ([]byte, error) { 17 t, p := inspect(v) 18 c := cachedCodecOf(t) 19 b := make([]byte, c.size(p, inline|toplevel)) 20 _, err := c.encode(b, p, inline|toplevel) 21 if err != nil { 22 return nil, fmt.Errorf("proto.Marshal(%T): %w", v, err) 23 } 24 return b, nil 25 } 26 27 func MarshalTo(b []byte, v interface{}) (int, error) { 28 t, p := inspect(v) 29 c := cachedCodecOf(t) 30 n, err := c.encode(b, p, inline|toplevel) 31 if err != nil { 32 err = fmt.Errorf("proto.MarshalTo: %w", err) 33 } 34 return n, err 35 } 36 37 func Unmarshal(b []byte, v interface{}) error { 38 if len(b) == 0 { 39 // An empty input is a valid protobuf message with all fields set to the 40 // zero-value. 41 reflect.ValueOf(v).Elem().Set(reflect.Zero(reflect.TypeOf(v).Elem())) 42 return nil 43 } 44 45 t, p := inspect(v) 46 t = t.Elem() // Unmarshal must be passed a pointer 47 c := cachedCodecOf(t) 48 49 n, err := c.decode(b, p, toplevel) 50 if err != nil { 51 return err 52 } 53 if n < len(b) { 54 return fmt.Errorf("proto.Unmarshal(%T): read=%d < buffer=%d", v, n, len(b)) 55 } 56 return nil 57 } 58 59 type flags uintptr 60 61 const ( 62 noflags flags = 0 63 inline flags = 1 << 0 64 wantzero flags = 1 << 1 65 // Shared with structField.flags in struct.go: 66 // zigzag flags = 1 << 2 67 toplevel flags = 1 << 3 68 ) 69 70 func (f flags) has(x flags) bool { 71 return (f & x) != 0 72 } 73 74 func (f flags) with(x flags) flags { 75 return f | x 76 } 77 78 func (f flags) without(x flags) flags { 79 return f & ^x 80 } 81 82 func (f flags) uint64(i int64) uint64 { 83 if f.has(zigzag) { 84 return encodeZigZag64(i) 85 } else { 86 return uint64(i) 87 } 88 } 89 90 func (f flags) int64(u uint64) int64 { 91 if f.has(zigzag) { 92 return decodeZigZag64(u) 93 } else { 94 return int64(u) 95 } 96 } 97 98 type iface struct { 99 typ unsafe.Pointer 100 ptr unsafe.Pointer 101 } 102 103 func inspect(v interface{}) (reflect.Type, unsafe.Pointer) { 104 return reflect.TypeOf(v), pointer(v) 105 } 106 107 func pointer(v interface{}) unsafe.Pointer { 108 return (*iface)(unsafe.Pointer(&v)).ptr 109 } 110 111 func inlined(t reflect.Type) bool { 112 switch t.Kind() { 113 case reflect.Ptr: 114 return true 115 case reflect.Map: 116 return true 117 case reflect.Struct: 118 return t.NumField() == 1 && inlined(t.Field(0).Type) 119 default: 120 return false 121 } 122 } 123 124 type fieldNumber uint 125 126 type wireType uint 127 128 const ( 129 varint wireType = 0 130 fixed64 wireType = 1 131 varlen wireType = 2 132 fixed32 wireType = 5 133 ) 134 135 func (wt wireType) String() string { 136 switch wt { 137 case varint: 138 return "varint" 139 case varlen: 140 return "varlen" 141 case fixed32: 142 return "fixed32" 143 case fixed64: 144 return "fixed64" 145 default: 146 return "unknown" 147 } 148 } 149 150 type codec struct { 151 wire wireType 152 size sizeFunc 153 encode encodeFunc 154 decode decodeFunc 155 } 156 157 var codecCache atomic.Value // map[unsafe.Pointer]*codec 158 159 func loadCachedCodec(t reflect.Type) (*codec, map[unsafe.Pointer]*codec) { 160 cache, _ := codecCache.Load().(map[unsafe.Pointer]*codec) 161 return cache[pointer(t)], cache 162 } 163 164 func storeCachedCodec(newCache map[unsafe.Pointer]*codec) { 165 codecCache.Store(newCache) 166 } 167 168 func cachedCodecOf(t reflect.Type) *codec { 169 c, oldCache := loadCachedCodec(t) 170 if c != nil { 171 return c 172 } 173 174 var p reflect.Type 175 isPtr := t.Kind() == reflect.Ptr 176 if isPtr { 177 p = t 178 t = t.Elem() 179 } else { 180 p = reflect.PtrTo(t) 181 } 182 183 seen := make(map[reflect.Type]*codec) 184 c1 := codecOf(t, seen) 185 c2 := codecOf(p, seen) 186 187 newCache := make(map[unsafe.Pointer]*codec, len(oldCache)+2) 188 for p, c := range oldCache { 189 newCache[p] = c 190 } 191 192 newCache[pointer(t)] = c1 193 newCache[pointer(p)] = c2 194 storeCachedCodec(newCache) 195 196 if isPtr { 197 return c2 198 } else { 199 return c1 200 } 201 } 202 203 func codecOf(t reflect.Type, seen map[reflect.Type]*codec) *codec { 204 if c := seen[t]; c != nil { 205 return c 206 } 207 208 switch { 209 case implements(t, messageType): 210 return messageCodecOf(t) 211 case implements(t, customMessageType) && !implements(t, protoMessageType): 212 return customCodecOf(t) 213 } 214 215 switch t.Kind() { 216 case reflect.Bool: 217 return &boolCodec 218 case reflect.Int: 219 return &intCodec 220 case reflect.Int32: 221 return &int32Codec 222 case reflect.Int64: 223 return &int64Codec 224 case reflect.Uint: 225 return &uintCodec 226 case reflect.Uint32: 227 return &uint32Codec 228 case reflect.Uint64: 229 return &uint64Codec 230 case reflect.Float32: 231 return &float32Codec 232 case reflect.Float64: 233 return &float64Codec 234 case reflect.String: 235 return &stringCodec 236 case reflect.Array: 237 elem := t.Elem() 238 switch elem.Kind() { 239 case reflect.Uint8: 240 return byteArrayCodecOf(t, seen) 241 } 242 case reflect.Slice: 243 elem := t.Elem() 244 switch elem.Kind() { 245 case reflect.Uint8: 246 return &bytesCodec 247 } 248 case reflect.Struct: 249 return structCodecOf(t, seen) 250 case reflect.Ptr: 251 return pointerCodecOf(t, seen) 252 } 253 254 panic("unsupported type: " + t.String()) 255 } 256 257 // backward compatibility with gogoproto custom types. 258 type customMessage interface { 259 Size() int 260 MarshalTo([]byte) (int, error) 261 Unmarshal([]byte) error 262 } 263 264 type protoMessage interface { 265 ProtoMessage() 266 } 267 268 var ( 269 messageType = reflect.TypeOf((*Message)(nil)).Elem() 270 customMessageType = reflect.TypeOf((*customMessage)(nil)).Elem() 271 protoMessageType = reflect.TypeOf((*protoMessage)(nil)).Elem() 272 ) 273 274 func implements(t, iface reflect.Type) bool { 275 return t.Implements(iface) || reflect.PtrTo(t).Implements(iface) 276 }