// Source: github.com/gogo/protobuf@v1.3.2/proto/extensions_gogo.go

// Protocol Buffers for Go with Gadgets
//
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// http://github.com/gogo/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 29 package proto 30 31 import ( 32 "bytes" 33 "errors" 34 "fmt" 35 "io" 36 "reflect" 37 "sort" 38 "strings" 39 "sync" 40 ) 41 42 type extensionsBytes interface { 43 Message 44 ExtensionRangeArray() []ExtensionRange 45 GetExtensions() *[]byte 46 } 47 48 type slowExtensionAdapter struct { 49 extensionsBytes 50 } 51 52 func (s slowExtensionAdapter) extensionsWrite() map[int32]Extension { 53 panic("Please report a bug to github.com/gogo/protobuf if you see this message: Writing extensions is not supported for extensions stored in a byte slice field.") 54 } 55 56 func (s slowExtensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 57 b := s.GetExtensions() 58 m, err := BytesToExtensionsMap(*b) 59 if err != nil { 60 panic(err) 61 } 62 return m, notLocker{} 63 } 64 65 func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool { 66 if reflect.ValueOf(pb).IsNil() { 67 return ifnotset 68 } 69 value, err := GetExtension(pb, extension) 70 if err != nil { 71 return ifnotset 72 } 73 if value == nil { 74 return ifnotset 75 } 76 if value.(*bool) == nil { 77 return ifnotset 78 } 79 return *(value.(*bool)) 80 } 81 82 func (this *Extension) Equal(that *Extension) bool { 83 if err := this.Encode(); err != nil { 84 return false 85 } 86 if err := that.Encode(); err != nil { 87 return false 88 } 89 return bytes.Equal(this.enc, that.enc) 90 } 91 92 func (this *Extension) Compare(that *Extension) int { 93 if err := this.Encode(); err != nil { 94 return 1 95 } 96 if err := that.Encode(); err != nil { 97 return -1 98 } 99 return bytes.Compare(this.enc, that.enc) 100 } 101 102 func SizeOfInternalExtension(m extendableProto) (n int) { 103 info := getMarshalInfo(reflect.TypeOf(m)) 104 return info.sizeV1Extensions(m.extensionsWrite()) 105 } 106 107 type sortableMapElem struct { 108 field int32 109 ext Extension 110 } 111 112 func newSortableExtensionsFromMap(m map[int32]Extension) sortableExtensions { 113 s := make(sortableExtensions, 0, len(m)) 114 for k, 
v := range m { 115 s = append(s, &sortableMapElem{field: k, ext: v}) 116 } 117 return s 118 } 119 120 type sortableExtensions []*sortableMapElem 121 122 func (this sortableExtensions) Len() int { return len(this) } 123 124 func (this sortableExtensions) Swap(i, j int) { this[i], this[j] = this[j], this[i] } 125 126 func (this sortableExtensions) Less(i, j int) bool { return this[i].field < this[j].field } 127 128 func (this sortableExtensions) String() string { 129 sort.Sort(this) 130 ss := make([]string, len(this)) 131 for i := range this { 132 ss[i] = fmt.Sprintf("%d: %v", this[i].field, this[i].ext) 133 } 134 return "map[" + strings.Join(ss, ",") + "]" 135 } 136 137 func StringFromInternalExtension(m extendableProto) string { 138 return StringFromExtensionsMap(m.extensionsWrite()) 139 } 140 141 func StringFromExtensionsMap(m map[int32]Extension) string { 142 return newSortableExtensionsFromMap(m).String() 143 } 144 145 func StringFromExtensionsBytes(ext []byte) string { 146 m, err := BytesToExtensionsMap(ext) 147 if err != nil { 148 panic(err) 149 } 150 return StringFromExtensionsMap(m) 151 } 152 153 func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error) { 154 return EncodeExtensionMap(m.extensionsWrite(), data) 155 } 156 157 func EncodeInternalExtensionBackwards(m extendableProto, data []byte) (n int, err error) { 158 return EncodeExtensionMapBackwards(m.extensionsWrite(), data) 159 } 160 161 func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) { 162 o := 0 163 for _, e := range m { 164 if err := e.Encode(); err != nil { 165 return 0, err 166 } 167 n := copy(data[o:], e.enc) 168 if n != len(e.enc) { 169 return 0, io.ErrShortBuffer 170 } 171 o += n 172 } 173 return o, nil 174 } 175 176 func EncodeExtensionMapBackwards(m map[int32]Extension, data []byte) (n int, err error) { 177 o := 0 178 end := len(data) 179 for _, e := range m { 180 if err := e.Encode(); err != nil { 181 return 0, err 182 } 183 n := 
copy(data[end-len(e.enc):], e.enc) 184 if n != len(e.enc) { 185 return 0, io.ErrShortBuffer 186 } 187 end -= n 188 o += n 189 } 190 return o, nil 191 } 192 193 func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) { 194 e := m[id] 195 if err := e.Encode(); err != nil { 196 return nil, err 197 } 198 return e.enc, nil 199 } 200 201 func size(buf []byte, wire int) (int, error) { 202 switch wire { 203 case WireVarint: 204 _, n := DecodeVarint(buf) 205 return n, nil 206 case WireFixed64: 207 return 8, nil 208 case WireBytes: 209 v, n := DecodeVarint(buf) 210 return int(v) + n, nil 211 case WireFixed32: 212 return 4, nil 213 case WireStartGroup: 214 offset := 0 215 for { 216 u, n := DecodeVarint(buf[offset:]) 217 fwire := int(u & 0x7) 218 offset += n 219 if fwire == WireEndGroup { 220 return offset, nil 221 } 222 s, err := size(buf[offset:], wire) 223 if err != nil { 224 return 0, err 225 } 226 offset += s 227 } 228 } 229 return 0, fmt.Errorf("proto: can't get size for unknown wire type %d", wire) 230 } 231 232 func BytesToExtensionsMap(buf []byte) (map[int32]Extension, error) { 233 m := make(map[int32]Extension) 234 i := 0 235 for i < len(buf) { 236 tag, n := DecodeVarint(buf[i:]) 237 if n <= 0 { 238 return nil, fmt.Errorf("unable to decode varint") 239 } 240 fieldNum := int32(tag >> 3) 241 wireType := int(tag & 0x7) 242 l, err := size(buf[i+n:], wireType) 243 if err != nil { 244 return nil, err 245 } 246 end := i + int(l) + n 247 m[int32(fieldNum)] = Extension{enc: buf[i:end]} 248 i = end 249 } 250 return m, nil 251 } 252 253 func NewExtension(e []byte) Extension { 254 ee := Extension{enc: make([]byte, len(e))} 255 copy(ee.enc, e) 256 return ee 257 } 258 259 func AppendExtension(e Message, tag int32, buf []byte) { 260 if ee, eok := e.(extensionsBytes); eok { 261 ext := ee.GetExtensions() 262 *ext = append(*ext, buf...) 
263 return 264 } 265 if ee, eok := e.(extendableProto); eok { 266 m := ee.extensionsWrite() 267 ext := m[int32(tag)] // may be missing 268 ext.enc = append(ext.enc, buf...) 269 m[int32(tag)] = ext 270 } 271 } 272 273 func encodeExtension(extension *ExtensionDesc, value interface{}) ([]byte, error) { 274 u := getMarshalInfo(reflect.TypeOf(extension.ExtendedType)) 275 ei := u.getExtElemInfo(extension) 276 v := value 277 p := toAddrPointer(&v, ei.isptr) 278 siz := ei.sizer(p, SizeVarint(ei.wiretag)) 279 buf := make([]byte, 0, siz) 280 return ei.marshaler(buf, p, ei.wiretag, false) 281 } 282 283 func decodeExtensionFromBytes(extension *ExtensionDesc, buf []byte) (interface{}, error) { 284 o := 0 285 for o < len(buf) { 286 tag, n := DecodeVarint((buf)[o:]) 287 fieldNum := int32(tag >> 3) 288 wireType := int(tag & 0x7) 289 if o+n > len(buf) { 290 return nil, fmt.Errorf("unable to decode extension") 291 } 292 l, err := size((buf)[o+n:], wireType) 293 if err != nil { 294 return nil, err 295 } 296 if int32(fieldNum) == extension.Field { 297 if o+n+l > len(buf) { 298 return nil, fmt.Errorf("unable to decode extension") 299 } 300 v, err := decodeExtension((buf)[o:o+n+l], extension) 301 if err != nil { 302 return nil, err 303 } 304 return v, nil 305 } 306 o += n + l 307 } 308 return defaultExtensionValue(extension) 309 } 310 311 func (this *Extension) Encode() error { 312 if this.enc == nil { 313 var err error 314 this.enc, err = encodeExtension(this.desc, this.value) 315 if err != nil { 316 return err 317 } 318 } 319 return nil 320 } 321 322 func (this Extension) GoString() string { 323 if err := this.Encode(); err != nil { 324 return fmt.Sprintf("error encoding extension: %v", err) 325 } 326 return fmt.Sprintf("proto.NewExtension(%#v)", this.enc) 327 } 328 329 func SetUnsafeExtension(pb Message, fieldNum int32, value interface{}) error { 330 typ := reflect.TypeOf(pb).Elem() 331 ext, ok := extensionMaps[typ] 332 if !ok { 333 return fmt.Errorf("proto: bad extended type; %s is 
not extendable", typ.String()) 334 } 335 desc, ok := ext[fieldNum] 336 if !ok { 337 return errors.New("proto: bad extension number; not in declared ranges") 338 } 339 return SetExtension(pb, desc, value) 340 } 341 342 func GetUnsafeExtension(pb Message, fieldNum int32) (interface{}, error) { 343 typ := reflect.TypeOf(pb).Elem() 344 ext, ok := extensionMaps[typ] 345 if !ok { 346 return nil, fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) 347 } 348 desc, ok := ext[fieldNum] 349 if !ok { 350 return nil, fmt.Errorf("unregistered field number %d", fieldNum) 351 } 352 return GetExtension(pb, desc) 353 } 354 355 func NewUnsafeXXX_InternalExtensions(m map[int32]Extension) XXX_InternalExtensions { 356 x := &XXX_InternalExtensions{ 357 p: new(struct { 358 mu sync.Mutex 359 extensionMap map[int32]Extension 360 }), 361 } 362 x.p.extensionMap = m 363 return *x 364 } 365 366 func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension { 367 pb := extendable.(extendableProto) 368 return pb.extensionsWrite() 369 } 370 371 func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int { 372 ext := pb.GetExtensions() 373 for offset < len(*ext) { 374 tag, n1 := DecodeVarint((*ext)[offset:]) 375 fieldNum := int32(tag >> 3) 376 wireType := int(tag & 0x7) 377 n2, err := size((*ext)[offset+n1:], wireType) 378 if err != nil { 379 panic(err) 380 } 381 newOffset := offset + n1 + n2 382 if fieldNum == theFieldNum { 383 *ext = append((*ext)[:offset], (*ext)[newOffset:]...) 384 return offset 385 } 386 offset = newOffset 387 } 388 return -1 389 }