github.com/lheiskan/zebrapack@v4.1.1-0.20181107023619-e955d028f9bf+incompatible/msgp/extension.go (about)

     1  package msgp
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  )
     7  
     8  const (
     9  	// Complex64Extension is the extension number used for complex64
    10  	Complex64Extension = 3
    11  
    12  	// Complex128Extension is the extension number used for complex128
    13  	Complex128Extension = 4
    14  
    15  	// TimeExtension is the extension number used for time.Time
    16  	TimeExtension = 5
    17  )
    18  
    19  // our extensions live here
    20  var extensionReg = make(map[int8]func() Extension)
    21  
    22  // RegisterExtension registers extensions so that they
    23  // can be initialized and returned by methods that
    24  // decode `interface{}` values. This should only
    25  // be called during initialization. f() should return
    26  // a newly-initialized zero value of the extension. Keep in
    27  // mind that extensions 3, 4, and 5 are reserved for
    28  // complex64, complex128, and time.Time, respectively,
    29  // and that MessagePack reserves extension types from -127 to -1.
    30  //
    31  // For example, if you wanted to register a user-defined struct:
    32  //
    33  //  msgp.RegisterExtension(10, func() msgp.Extension { &MyExtension{} })
    34  //
    35  // RegisterExtension will panic if you call it multiple times
    36  // with the same 'typ' argument, or if you use a reserved
    37  // type (3, 4, or 5).
    38  func RegisterExtension(typ int8, f func() Extension) {
    39  	switch typ {
    40  	case Complex64Extension, Complex128Extension, TimeExtension:
    41  		panic(fmt.Sprint("msgp: forbidden extension type:", typ))
    42  	}
    43  	if _, ok := extensionReg[typ]; ok {
    44  		panic(fmt.Sprint("msgp: RegisterExtension() called with typ", typ, "more than once"))
    45  	}
    46  	extensionReg[typ] = f
    47  }
    48  
    49  // ExtensionTypeError is an error type returned
    50  // when there is a mis-match between an extension type
    51  // and the type encoded on the wire
    52  type ExtensionTypeError struct {
    53  	Got  int8
    54  	Want int8
    55  }
    56  
    57  // Error implements the error interface
    58  func (e ExtensionTypeError) Error() string {
    59  	return fmt.Sprintf("msgp: error decoding extension: wanted type %d; got type %d", e.Want, e.Got)
    60  }
    61  
    62  // Resumable returns 'true' for ExtensionTypeErrors
    63  func (e ExtensionTypeError) Resumable() bool { return true }
    64  
    65  func errExt(got int8, wanted int8) error {
    66  	return ExtensionTypeError{Got: got, Want: wanted}
    67  }
    68  
    69  // Extension is the interface fulfilled
    70  // by types that want to define their
    71  // own binary encoding.
    72  type Extension interface {
    73  	// ExtensionType should return
    74  	// a int8 that identifies the concrete
    75  	// type of the extension. (Types <0 are
    76  	// officially reserved by the MessagePack
    77  	// specifications.)
    78  	ExtensionType() int8
    79  
    80  	// Len should return the length
    81  	// of the data to be encoded
    82  	Len() int
    83  
    84  	// MarshalBinaryTo should copy
    85  	// the data into the supplied slice,
    86  	// assuming that the slice has length Len()
    87  	MarshalBinaryTo([]byte) error
    88  
    89  	UnmarshalBinary([]byte) error
    90  }
    91  
    92  // RawExtension implements the Extension interface
    93  type RawExtension struct {
    94  	Data []byte
    95  	Type int8
    96  }
    97  
    98  // ExtensionType implements Extension.ExtensionType, and returns r.Type
    99  func (r *RawExtension) ExtensionType() int8 { return r.Type }
   100  
   101  // Len implements Extension.Len, and returns len(r.Data)
   102  func (r *RawExtension) Len() int { return len(r.Data) }
   103  
   104  // MarshalBinaryTo implements Extension.MarshalBinaryTo,
   105  // and returns a copy of r.Data
   106  func (r *RawExtension) MarshalBinaryTo(d []byte) error {
   107  	copy(d, r.Data)
   108  	return nil
   109  }
   110  
   111  // UnmarshalBinary implements Extension.UnmarshalBinary,
   112  // and sets r.Data to the contents of the provided slice
   113  func (r *RawExtension) UnmarshalBinary(b []byte) error {
   114  	if cap(r.Data) >= len(b) {
   115  		r.Data = r.Data[0:len(b)]
   116  	} else {
   117  		r.Data = make([]byte, len(b))
   118  	}
   119  	copy(r.Data, b)
   120  	return nil
   121  }
   122  
   123  // WriteExtension writes an extension type to the writer
   124  func (mw *Writer) WriteExtension(e Extension) error {
   125  	l := e.Len()
   126  	var err error
   127  	switch l {
   128  	case 0:
   129  		o, err := mw.require(3)
   130  		if err != nil {
   131  			return err
   132  		}
   133  		mw.buf[o] = mext8
   134  		mw.buf[o+1] = 0
   135  		mw.buf[o+2] = byte(e.ExtensionType())
   136  	case 1:
   137  		o, err := mw.require(2)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		mw.buf[o] = mfixext1
   142  		mw.buf[o+1] = byte(e.ExtensionType())
   143  	case 2:
   144  		o, err := mw.require(2)
   145  		if err != nil {
   146  			return err
   147  		}
   148  		mw.buf[o] = mfixext2
   149  		mw.buf[o+1] = byte(e.ExtensionType())
   150  	case 4:
   151  		o, err := mw.require(2)
   152  		if err != nil {
   153  			return err
   154  		}
   155  		mw.buf[o] = mfixext4
   156  		mw.buf[o+1] = byte(e.ExtensionType())
   157  	case 8:
   158  		o, err := mw.require(2)
   159  		if err != nil {
   160  			return err
   161  		}
   162  		mw.buf[o] = mfixext8
   163  		mw.buf[o+1] = byte(e.ExtensionType())
   164  	case 16:
   165  		o, err := mw.require(2)
   166  		if err != nil {
   167  			return err
   168  		}
   169  		mw.buf[o] = mfixext16
   170  		mw.buf[o+1] = byte(e.ExtensionType())
   171  	default:
   172  		switch {
   173  		case l < math.MaxUint8:
   174  			o, err := mw.require(3)
   175  			if err != nil {
   176  				return err
   177  			}
   178  			mw.buf[o] = mext8
   179  			mw.buf[o+1] = byte(uint8(l))
   180  			mw.buf[o+2] = byte(e.ExtensionType())
   181  		case l < math.MaxUint16:
   182  			o, err := mw.require(4)
   183  			if err != nil {
   184  				return err
   185  			}
   186  			mw.buf[o] = mext16
   187  			big.PutUint16(mw.buf[o+1:], uint16(l))
   188  			mw.buf[o+3] = byte(e.ExtensionType())
   189  		default:
   190  			o, err := mw.require(6)
   191  			if err != nil {
   192  				return err
   193  			}
   194  			mw.buf[o] = mext32
   195  			big.PutUint32(mw.buf[o+1:], uint32(l))
   196  			mw.buf[o+5] = byte(e.ExtensionType())
   197  		}
   198  	}
   199  	// we can only write directly to the
   200  	// buffer if we're sure that it
   201  	// fits the object
   202  	if l <= mw.bufsize() {
   203  		o, err := mw.require(l)
   204  		if err != nil {
   205  			return err
   206  		}
   207  		return e.MarshalBinaryTo(mw.buf[o:])
   208  	}
   209  	// here we create a new buffer
   210  	// just large enough for the body
   211  	// and save it as the write buffer
   212  	err = mw.flush()
   213  	if err != nil {
   214  		return err
   215  	}
   216  	buf := make([]byte, l)
   217  	err = e.MarshalBinaryTo(buf)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	mw.buf = buf
   222  	mw.wloc = l
   223  	return nil
   224  }
   225  
   226  // peek at the extension type, assuming the next
   227  // kind to be read is Extension
   228  func (m *Reader) peekExtensionType() (int8, error) {
   229  	p, err := m.R.Peek(2)
   230  	if err != nil {
   231  		return 0, err
   232  	}
   233  	spec := sizes[p[0]]
   234  	if spec.typ != ExtensionType {
   235  		return 0, badPrefix(ExtensionType, p[0])
   236  	}
   237  	if spec.extra == constsize {
   238  		return int8(p[1]), nil
   239  	}
   240  	size := spec.size
   241  	p, err = m.R.Peek(int(size))
   242  	if err != nil {
   243  		return 0, err
   244  	}
   245  	return int8(p[size-1]), nil
   246  }
   247  
   248  // peekExtension peeks at the extension encoding type
   249  // (must guarantee at least 1 byte in 'b')
   250  func peekExtension(b []byte) (int8, error) {
   251  	spec := sizes[b[0]]
   252  	size := spec.size
   253  	if spec.typ != ExtensionType {
   254  		return 0, badPrefix(ExtensionType, b[0])
   255  	}
   256  	if len(b) < int(size) {
   257  		return 0, ErrShortBytes
   258  	}
   259  	// for fixed extensions,
   260  	// the type information is in
   261  	// the second byte
   262  	if spec.extra == constsize {
   263  		return int8(b[1]), nil
   264  	}
   265  	// otherwise, it's in the last
   266  	// part of the prefix
   267  	return int8(b[size-1]), nil
   268  }
   269  
   270  // ReadExtension reads the next object from the reader
   271  // as an extension. ReadExtension will fail if the next
   272  // object in the stream is not an extension, or if
   273  // e.Type() is not the same as the wire type.
   274  func (m *Reader) ReadExtension(e Extension) (err error) {
   275  	var p []byte
   276  	p, err = m.R.Peek(2)
   277  	if err != nil {
   278  		return
   279  	}
   280  	lead := p[0]
   281  	var read int
   282  	var off int
   283  	switch lead {
   284  	case mfixext1:
   285  		if int8(p[1]) != e.ExtensionType() {
   286  			err = errExt(int8(p[1]), e.ExtensionType())
   287  			return
   288  		}
   289  		p, err = m.R.Peek(3)
   290  		if err != nil {
   291  			return
   292  		}
   293  		err = e.UnmarshalBinary(p[2:])
   294  		if err == nil {
   295  			_, err = m.R.Skip(3)
   296  		}
   297  		return
   298  
   299  	case mfixext2:
   300  		if int8(p[1]) != e.ExtensionType() {
   301  			err = errExt(int8(p[1]), e.ExtensionType())
   302  			return
   303  		}
   304  		p, err = m.R.Peek(4)
   305  		if err != nil {
   306  			return
   307  		}
   308  		err = e.UnmarshalBinary(p[2:])
   309  		if err == nil {
   310  			_, err = m.R.Skip(4)
   311  		}
   312  		return
   313  
   314  	case mfixext4:
   315  		if int8(p[1]) != e.ExtensionType() {
   316  			err = errExt(int8(p[1]), e.ExtensionType())
   317  			return
   318  		}
   319  		p, err = m.R.Peek(6)
   320  		if err != nil {
   321  			return
   322  		}
   323  		err = e.UnmarshalBinary(p[2:])
   324  		if err == nil {
   325  			_, err = m.R.Skip(6)
   326  		}
   327  		return
   328  
   329  	case mfixext8:
   330  		if int8(p[1]) != e.ExtensionType() {
   331  			err = errExt(int8(p[1]), e.ExtensionType())
   332  			return
   333  		}
   334  		p, err = m.R.Peek(10)
   335  		if err != nil {
   336  			return
   337  		}
   338  		err = e.UnmarshalBinary(p[2:])
   339  		if err == nil {
   340  			_, err = m.R.Skip(10)
   341  		}
   342  		return
   343  
   344  	case mfixext16:
   345  		if int8(p[1]) != e.ExtensionType() {
   346  			err = errExt(int8(p[1]), e.ExtensionType())
   347  			return
   348  		}
   349  		p, err = m.R.Peek(18)
   350  		if err != nil {
   351  			return
   352  		}
   353  		err = e.UnmarshalBinary(p[2:])
   354  		if err == nil {
   355  			_, err = m.R.Skip(18)
   356  		}
   357  		return
   358  
   359  	case mext8:
   360  		p, err = m.R.Peek(3)
   361  		if err != nil {
   362  			return
   363  		}
   364  		if int8(p[2]) != e.ExtensionType() {
   365  			err = errExt(int8(p[2]), e.ExtensionType())
   366  			return
   367  		}
   368  		read = int(uint8(p[1]))
   369  		off = 3
   370  
   371  	case mext16:
   372  		p, err = m.R.Peek(4)
   373  		if err != nil {
   374  			return
   375  		}
   376  		if int8(p[3]) != e.ExtensionType() {
   377  			err = errExt(int8(p[3]), e.ExtensionType())
   378  			return
   379  		}
   380  		read = int(big.Uint16(p[1:]))
   381  		off = 4
   382  
   383  	case mext32:
   384  		p, err = m.R.Peek(6)
   385  		if err != nil {
   386  			return
   387  		}
   388  		if int8(p[5]) != e.ExtensionType() {
   389  			err = errExt(int8(p[5]), e.ExtensionType())
   390  			return
   391  		}
   392  		read = int(big.Uint32(p[1:]))
   393  		off = 6
   394  
   395  	default:
   396  		err = badPrefix(ExtensionType, lead)
   397  		return
   398  	}
   399  
   400  	p, err = m.R.Peek(read + off)
   401  	if err != nil {
   402  		return
   403  	}
   404  	err = e.UnmarshalBinary(p[off:])
   405  	if err == nil {
   406  		_, err = m.R.Skip(read + off)
   407  	}
   408  	return
   409  }
   410  
   411  // AppendExtension appends a MessagePack extension to the provided slice
   412  func AppendExtension(b []byte, e Extension) ([]byte, error) {
   413  	l := e.Len()
   414  	var o []byte
   415  	var n int
   416  	switch l {
   417  	case 0:
   418  		o, n = ensure(b, 3)
   419  		o[n] = mext8
   420  		o[n+1] = 0
   421  		o[n+2] = byte(e.ExtensionType())
   422  		return o[:n+3], nil
   423  	case 1:
   424  		o, n = ensure(b, 3)
   425  		o[n] = mfixext1
   426  		o[n+1] = byte(e.ExtensionType())
   427  		n += 2
   428  	case 2:
   429  		o, n = ensure(b, 4)
   430  		o[n] = mfixext2
   431  		o[n+1] = byte(e.ExtensionType())
   432  		n += 2
   433  	case 4:
   434  		o, n = ensure(b, 6)
   435  		o[n] = mfixext4
   436  		o[n+1] = byte(e.ExtensionType())
   437  		n += 2
   438  	case 8:
   439  		o, n = ensure(b, 10)
   440  		o[n] = mfixext8
   441  		o[n+1] = byte(e.ExtensionType())
   442  		n += 2
   443  	case 16:
   444  		o, n = ensure(b, 18)
   445  		o[n] = mfixext16
   446  		o[n+1] = byte(e.ExtensionType())
   447  		n += 2
   448  	}
   449  	switch {
   450  	case l < math.MaxUint8:
   451  		o, n = ensure(b, l+3)
   452  		o[n] = mext8
   453  		o[n+1] = byte(uint8(l))
   454  		o[n+2] = byte(e.ExtensionType())
   455  		n += 3
   456  	case l < math.MaxUint16:
   457  		o, n = ensure(b, l+4)
   458  		o[n] = mext16
   459  		big.PutUint16(o[n+1:], uint16(l))
   460  		o[n+3] = byte(e.ExtensionType())
   461  		n += 4
   462  	default:
   463  		o, n = ensure(b, l+6)
   464  		o[n] = mext32
   465  		big.PutUint32(o[n+1:], uint32(l))
   466  		o[n+5] = byte(e.ExtensionType())
   467  		n += 6
   468  	}
   469  	return o, e.MarshalBinaryTo(o[n:])
   470  }
   471  
   472  // ReadExtensionBytes reads an extension from 'b' into 'e'
   473  // and returns any remaining bytes.
   474  // Possible errors:
   475  // - ErrShortBytes ('b' not long enough)
   476  // - ExtensionTypeErorr{} (wire type not the same as e.Type())
   477  // - TypeErorr{} (next object not an extension)
   478  // - InvalidPrefixError
   479  // - An umarshal error returned from e.UnmarshalBinary
   480  func (nbs *NilBitsStack) ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
   481  	if nbs != nil && nbs.AlwaysNil {
   482  		return b, nil
   483  	}
   484  
   485  	l := len(b)
   486  	if l < 3 {
   487  		return b, ErrShortBytes
   488  	}
   489  	lead := b[0]
   490  	var (
   491  		sz  int // size of 'data'
   492  		off int // offset of 'data'
   493  		typ int8
   494  	)
   495  	switch lead {
   496  	case mfixext1:
   497  		typ = int8(b[1])
   498  		sz = 1
   499  		off = 2
   500  	case mfixext2:
   501  		typ = int8(b[1])
   502  		sz = 2
   503  		off = 2
   504  	case mfixext4:
   505  		typ = int8(b[1])
   506  		sz = 4
   507  		off = 2
   508  	case mfixext8:
   509  		typ = int8(b[1])
   510  		sz = 8
   511  		off = 2
   512  	case mfixext16:
   513  		typ = int8(b[1])
   514  		sz = 16
   515  		off = 2
   516  	case mext8:
   517  		sz = int(uint8(b[1]))
   518  		typ = int8(b[2])
   519  		off = 3
   520  		if sz == 0 {
   521  			return b[3:], e.UnmarshalBinary(b[3:3])
   522  		}
   523  	case mext16:
   524  		if l < 4 {
   525  			return b, ErrShortBytes
   526  		}
   527  		sz = int(big.Uint16(b[1:]))
   528  		typ = int8(b[3])
   529  		off = 4
   530  	case mext32:
   531  		if l < 6 {
   532  			return b, ErrShortBytes
   533  		}
   534  		sz = int(big.Uint32(b[1:]))
   535  		typ = int8(b[5])
   536  		off = 6
   537  	default:
   538  		return b, badPrefix(ExtensionType, lead)
   539  	}
   540  
   541  	if typ != e.ExtensionType() {
   542  		return b, errExt(typ, e.ExtensionType())
   543  	}
   544  
   545  	// the data of the extension starts
   546  	// at 'off' and is 'sz' bytes long
   547  	if len(b[off:]) < sz {
   548  		return b, ErrShortBytes
   549  	}
   550  	tot := off + sz
   551  	return b[tot:], e.UnmarshalBinary(b[off:tot])
   552  }