github.com/lheiskan/zebrapack@v4.1.1-0.20181107023619-e955d028f9bf+incompatible/msgp/extension.go (about) 1 package msgp 2 3 import ( 4 "fmt" 5 "math" 6 ) 7 8 const ( 9 // Complex64Extension is the extension number used for complex64 10 Complex64Extension = 3 11 12 // Complex128Extension is the extension number used for complex128 13 Complex128Extension = 4 14 15 // TimeExtension is the extension number used for time.Time 16 TimeExtension = 5 17 ) 18 19 // our extensions live here 20 var extensionReg = make(map[int8]func() Extension) 21 22 // RegisterExtension registers extensions so that they 23 // can be initialized and returned by methods that 24 // decode `interface{}` values. This should only 25 // be called during initialization. f() should return 26 // a newly-initialized zero value of the extension. Keep in 27 // mind that extensions 3, 4, and 5 are reserved for 28 // complex64, complex128, and time.Time, respectively, 29 // and that MessagePack reserves extension types from -127 to -1. 30 // 31 // For example, if you wanted to register a user-defined struct: 32 // 33 // msgp.RegisterExtension(10, func() msgp.Extension { &MyExtension{} }) 34 // 35 // RegisterExtension will panic if you call it multiple times 36 // with the same 'typ' argument, or if you use a reserved 37 // type (3, 4, or 5). 38 func RegisterExtension(typ int8, f func() Extension) { 39 switch typ { 40 case Complex64Extension, Complex128Extension, TimeExtension: 41 panic(fmt.Sprint("msgp: forbidden extension type:", typ)) 42 } 43 if _, ok := extensionReg[typ]; ok { 44 panic(fmt.Sprint("msgp: RegisterExtension() called with typ", typ, "more than once")) 45 } 46 extensionReg[typ] = f 47 } 48 49 // ExtensionTypeError is an error type returned 50 // when there is a mis-match between an extension type 51 // and the type encoded on the wire 52 type ExtensionTypeError struct { 53 Got int8 54 Want int8 55 } 56 57 // Error implements the error interface 58 func (e ExtensionTypeError) Error() string { 59 return fmt.Sprintf("msgp: error decoding extension: wanted type %d; got type %d", e.Want, e.Got) 60 } 61 62 // Resumable returns 'true' for ExtensionTypeErrors 63 func (e ExtensionTypeError) Resumable() bool { return true } 64 65 func errExt(got int8, wanted int8) error { 66 return ExtensionTypeError{Got: got, Want: wanted} 67 } 68 69 // Extension is the interface fulfilled 70 // by types that want to define their 71 // own binary encoding. 72 type Extension interface { 73 // ExtensionType should return 74 // a int8 that identifies the concrete 75 // type of the extension. (Types <0 are 76 // officially reserved by the MessagePack 77 // specifications.) 78 ExtensionType() int8 79 80 // Len should return the length 81 // of the data to be encoded 82 Len() int 83 84 // MarshalBinaryTo should copy 85 // the data into the supplied slice, 86 // assuming that the slice has length Len() 87 MarshalBinaryTo([]byte) error 88 89 UnmarshalBinary([]byte) error 90 } 91 92 // RawExtension implements the Extension interface 93 type RawExtension struct { 94 Data []byte 95 Type int8 96 } 97 98 // ExtensionType implements Extension.ExtensionType, and returns r.Type 99 func (r *RawExtension) ExtensionType() int8 { return r.Type } 100 101 // Len implements Extension.Len, and returns len(r.Data) 102 func (r *RawExtension) Len() int { return len(r.Data) } 103 104 // MarshalBinaryTo implements Extension.MarshalBinaryTo, 105 // and returns a copy of r.Data 106 func (r *RawExtension) MarshalBinaryTo(d []byte) error { 107 copy(d, r.Data) 108 return nil 109 } 110 111 // UnmarshalBinary implements Extension.UnmarshalBinary, 112 // and sets r.Data to the contents of the provided slice 113 func (r *RawExtension) UnmarshalBinary(b []byte) error { 114 if cap(r.Data) >= len(b) { 115 r.Data = r.Data[0:len(b)] 116 } else { 117 r.Data = make([]byte, len(b)) 118 } 119 copy(r.Data, b) 120 return nil 121 } 122 123 // WriteExtension writes an extension type to the writer 124 func (mw *Writer) WriteExtension(e Extension) error { 125 l := e.Len() 126 var err error 127 switch l { 128 case 0: 129 o, err := mw.require(3) 130 if err != nil { 131 return err 132 } 133 mw.buf[o] = mext8 134 mw.buf[o+1] = 0 135 mw.buf[o+2] = byte(e.ExtensionType()) 136 case 1: 137 o, err := mw.require(2) 138 if err != nil { 139 return err 140 } 141 mw.buf[o] = mfixext1 142 mw.buf[o+1] = byte(e.ExtensionType()) 143 case 2: 144 o, err := mw.require(2) 145 if err != nil { 146 return err 147 } 148 mw.buf[o] = mfixext2 149 mw.buf[o+1] = byte(e.ExtensionType()) 150 case 4: 151 o, err := mw.require(2) 152 if err != nil { 153 return err 154 } 155 mw.buf[o] = mfixext4 156 mw.buf[o+1] = byte(e.ExtensionType()) 157 case 8: 158 o, err := mw.require(2) 159 if err != nil { 160 return err 161 } 162 mw.buf[o] = mfixext8 163 mw.buf[o+1] = byte(e.ExtensionType()) 164 case 16: 165 o, err := mw.require(2) 166 if err != nil { 167 return err 168 } 169 mw.buf[o] = mfixext16 170 mw.buf[o+1] = byte(e.ExtensionType()) 171 default: 172 switch { 173 case l < math.MaxUint8: 174 o, err := mw.require(3) 175 if err != nil { 176 return err 177 } 178 mw.buf[o] = mext8 179 mw.buf[o+1] = byte(uint8(l)) 180 mw.buf[o+2] = byte(e.ExtensionType()) 181 case l < math.MaxUint16: 182 o, err := mw.require(4) 183 if err != nil { 184 return err 185 } 186 mw.buf[o] = mext16 187 big.PutUint16(mw.buf[o+1:], uint16(l)) 188 mw.buf[o+3] = byte(e.ExtensionType()) 189 default: 190 o, err := mw.require(6) 191 if err != nil { 192 return err 193 } 194 mw.buf[o] = mext32 195 big.PutUint32(mw.buf[o+1:], uint32(l)) 196 mw.buf[o+5] = byte(e.ExtensionType()) 197 } 198 } 199 // we can only write directly to the 200 // buffer if we're sure that it 201 // fits the object 202 if l <= mw.bufsize() { 203 o, err := mw.require(l) 204 if err != nil { 205 return err 206 } 207 return e.MarshalBinaryTo(mw.buf[o:]) 208 } 209 // here we create a new buffer 210 // just large enough for the body 211 // and save it as the write buffer 212 err = mw.flush() 213 if err != nil { 214 return err 215 } 216 buf := make([]byte, l) 217 err = e.MarshalBinaryTo(buf) 218 if err != nil { 219 return err 220 } 221 mw.buf = buf 222 mw.wloc = l 223 return nil 224 } 225 226 // peek at the extension type, assuming the next 227 // kind to be read is Extension 228 func (m *Reader) peekExtensionType() (int8, error) { 229 p, err := m.R.Peek(2) 230 if err != nil { 231 return 0, err 232 } 233 spec := sizes[p[0]] 234 if spec.typ != ExtensionType { 235 return 0, badPrefix(ExtensionType, p[0]) 236 } 237 if spec.extra == constsize { 238 return int8(p[1]), nil 239 } 240 size := spec.size 241 p, err = m.R.Peek(int(size)) 242 if err != nil { 243 return 0, err 244 } 245 return int8(p[size-1]), nil 246 } 247 248 // peekExtension peeks at the extension encoding type 249 // (must guarantee at least 1 byte in 'b') 250 func peekExtension(b []byte) (int8, error) { 251 spec := sizes[b[0]] 252 size := spec.size 253 if spec.typ != ExtensionType { 254 return 0, badPrefix(ExtensionType, b[0]) 255 } 256 if len(b) < int(size) { 257 return 0, ErrShortBytes 258 } 259 // for fixed extensions, 260 // the type information is in 261 // the second byte 262 if spec.extra == constsize { 263 return int8(b[1]), nil 264 } 265 // otherwise, it's in the last 266 // part of the prefix 267 return int8(b[size-1]), nil 268 } 269 270 // ReadExtension reads the next object from the reader 271 // as an extension. ReadExtension will fail if the next 272 // object in the stream is not an extension, or if 273 // e.Type() is not the same as the wire type. 274 func (m *Reader) ReadExtension(e Extension) (err error) { 275 var p []byte 276 p, err = m.R.Peek(2) 277 if err != nil { 278 return 279 } 280 lead := p[0] 281 var read int 282 var off int 283 switch lead { 284 case mfixext1: 285 if int8(p[1]) != e.ExtensionType() { 286 err = errExt(int8(p[1]), e.ExtensionType()) 287 return 288 } 289 p, err = m.R.Peek(3) 290 if err != nil { 291 return 292 } 293 err = e.UnmarshalBinary(p[2:]) 294 if err == nil { 295 _, err = m.R.Skip(3) 296 } 297 return 298 299 case mfixext2: 300 if int8(p[1]) != e.ExtensionType() { 301 err = errExt(int8(p[1]), e.ExtensionType()) 302 return 303 } 304 p, err = m.R.Peek(4) 305 if err != nil { 306 return 307 } 308 err = e.UnmarshalBinary(p[2:]) 309 if err == nil { 310 _, err = m.R.Skip(4) 311 } 312 return 313 314 case mfixext4: 315 if int8(p[1]) != e.ExtensionType() { 316 err = errExt(int8(p[1]), e.ExtensionType()) 317 return 318 } 319 p, err = m.R.Peek(6) 320 if err != nil { 321 return 322 } 323 err = e.UnmarshalBinary(p[2:]) 324 if err == nil { 325 _, err = m.R.Skip(6) 326 } 327 return 328 329 case mfixext8: 330 if int8(p[1]) != e.ExtensionType() { 331 err = errExt(int8(p[1]), e.ExtensionType()) 332 return 333 } 334 p, err = m.R.Peek(10) 335 if err != nil { 336 return 337 } 338 err = e.UnmarshalBinary(p[2:]) 339 if err == nil { 340 _, err = m.R.Skip(10) 341 } 342 return 343 344 case mfixext16: 345 if int8(p[1]) != e.ExtensionType() { 346 err = errExt(int8(p[1]), e.ExtensionType()) 347 return 348 } 349 p, err = m.R.Peek(18) 350 if err != nil { 351 return 352 } 353 err = e.UnmarshalBinary(p[2:]) 354 if err == nil { 355 _, err = m.R.Skip(18) 356 } 357 return 358 359 case mext8: 360 p, err = m.R.Peek(3) 361 if err != nil { 362 return 363 } 364 if int8(p[2]) != e.ExtensionType() { 365 err = errExt(int8(p[2]), e.ExtensionType()) 366 return 367 } 368 read = int(uint8(p[1])) 369 off = 3 370 371 case mext16: 372 p, err = m.R.Peek(4) 373 if err != nil { 374 return 375 } 376 if int8(p[3]) != e.ExtensionType() { 377 err = errExt(int8(p[3]), e.ExtensionType()) 378 return 379 } 380 read = int(big.Uint16(p[1:])) 381 off = 4 382 383 case mext32: 384 p, err = m.R.Peek(6) 385 if err != nil { 386 return 387 } 388 if int8(p[5]) != e.ExtensionType() { 389 err = errExt(int8(p[5]), e.ExtensionType()) 390 return 391 } 392 read = int(big.Uint32(p[1:])) 393 off = 6 394 395 default: 396 err = badPrefix(ExtensionType, lead) 397 return 398 } 399 400 p, err = m.R.Peek(read + off) 401 if err != nil { 402 return 403 } 404 err = e.UnmarshalBinary(p[off:]) 405 if err == nil { 406 _, err = m.R.Skip(read + off) 407 } 408 return 409 } 410 411 // AppendExtension appends a MessagePack extension to the provided slice 412 func AppendExtension(b []byte, e Extension) ([]byte, error) { 413 l := e.Len() 414 var o []byte 415 var n int 416 switch l { 417 case 0: 418 o, n = ensure(b, 3) 419 o[n] = mext8 420 o[n+1] = 0 421 o[n+2] = byte(e.ExtensionType()) 422 return o[:n+3], nil 423 case 1: 424 o, n = ensure(b, 3) 425 o[n] = mfixext1 426 o[n+1] = byte(e.ExtensionType()) 427 n += 2 428 case 2: 429 o, n = ensure(b, 4) 430 o[n] = mfixext2 431 o[n+1] = byte(e.ExtensionType()) 432 n += 2 433 case 4: 434 o, n = ensure(b, 6) 435 o[n] = mfixext4 436 o[n+1] = byte(e.ExtensionType()) 437 n += 2 438 case 8: 439 o, n = ensure(b, 10) 440 o[n] = mfixext8 441 o[n+1] = byte(e.ExtensionType()) 442 n += 2 443 case 16: 444 o, n = ensure(b, 18) 445 o[n] = mfixext16 446 o[n+1] = byte(e.ExtensionType()) 447 n += 2 448 } 449 switch { 450 case l < math.MaxUint8: 451 o, n = ensure(b, l+3) 452 o[n] = mext8 453 o[n+1] = byte(uint8(l)) 454 o[n+2] = byte(e.ExtensionType()) 455 n += 3 456 case l < math.MaxUint16: 457 o, n = ensure(b, l+4) 458 o[n] = mext16 459 big.PutUint16(o[n+1:], uint16(l)) 460 o[n+3] = byte(e.ExtensionType()) 461 n += 4 462 default: 463 o, n = ensure(b, l+6) 464 o[n] = mext32 465 big.PutUint32(o[n+1:], uint32(l)) 466 o[n+5] = byte(e.ExtensionType()) 467 n += 6 468 } 469 return o, e.MarshalBinaryTo(o[n:]) 470 } 471 472 // ReadExtensionBytes reads an extension from 'b' into 'e' 473 // and returns any remaining bytes. 474 // Possible errors: 475 // - ErrShortBytes ('b' not long enough) 476 // - ExtensionTypeErorr{} (wire type not the same as e.Type()) 477 // - TypeErorr{} (next object not an extension) 478 // - InvalidPrefixError 479 // - An umarshal error returned from e.UnmarshalBinary 480 func (nbs *NilBitsStack) ReadExtensionBytes(b []byte, e Extension) ([]byte, error) { 481 if nbs != nil && nbs.AlwaysNil { 482 return b, nil 483 } 484 485 l := len(b) 486 if l < 3 { 487 return b, ErrShortBytes 488 } 489 lead := b[0] 490 var ( 491 sz int // size of 'data' 492 off int // offset of 'data' 493 typ int8 494 ) 495 switch lead { 496 case mfixext1: 497 typ = int8(b[1]) 498 sz = 1 499 off = 2 500 case mfixext2: 501 typ = int8(b[1]) 502 sz = 2 503 off = 2 504 case mfixext4: 505 typ = int8(b[1]) 506 sz = 4 507 off = 2 508 case mfixext8: 509 typ = int8(b[1]) 510 sz = 8 511 off = 2 512 case mfixext16: 513 typ = int8(b[1]) 514 sz = 16 515 off = 2 516 case mext8: 517 sz = int(uint8(b[1])) 518 typ = int8(b[2]) 519 off = 3 520 if sz == 0 { 521 return b[3:], e.UnmarshalBinary(b[3:3]) 522 } 523 case mext16: 524 if l < 4 { 525 return b, ErrShortBytes 526 } 527 sz = int(big.Uint16(b[1:])) 528 typ = int8(b[3]) 529 off = 4 530 case mext32: 531 if l < 6 { 532 return b, ErrShortBytes 533 } 534 sz = int(big.Uint32(b[1:])) 535 typ = int8(b[5]) 536 off = 6 537 default: 538 return b, badPrefix(ExtensionType, lead) 539 } 540 541 if typ != e.ExtensionType() { 542 return b, errExt(typ, e.ExtensionType()) 543 } 544 545 // the data of the extension starts 546 // at 'off' and is 'sz' bytes long 547 if len(b[off:]) < sz { 548 return b, ErrShortBytes 549 } 550 tot := off + sz 551 return b[tot:], e.UnmarshalBinary(b[off:tot]) 552 }