github.com/nebulouslabs/sia@v1.3.7/encoding/marshal.go (about)

     1  // Package encoding converts arbitrary objects into byte slices, and vice
     2  // versa. It also contains helper functions for reading and writing length-
     3  // prefixed data. See doc/Encoding.md for the full encoding specification.
     4  package encoding
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"os"
    13  	"reflect"
    14  )
    15  
const (
	// MaxObjectSize is the maximum number of bytes a Decoder reads while
	// decoding a single object (12 MB). The Decoder compares its running
	// byte count against this limit after each read. It is an untyped
	// constant written in floating-point notation, so the Error methods
	// in this file convert it to uint64 before formatting it.
	MaxObjectSize = 12e6

	// MaxSliceSize is the maximum encoded size, in bytes, of a single
	// length-prefixed slice (5 MB). Decoder.NextPrefix compares
	// prefix*elemSize against this limit.
	MaxSliceSize = 5e6 // 5 MB
)
    25  
var (
	// errBadPointer is returned by Decoder.Decode when its argument is not a
	// non-nil pointer.
	errBadPointer = errors.New("cannot decode into invalid pointer")
)
    29  
    30  // ErrObjectTooLarge is an error when encoded object exceeds size limit.
    31  type ErrObjectTooLarge uint64
    32  
    33  // Error implements the error interface.
    34  func (e ErrObjectTooLarge) Error() string {
    35  	return fmt.Sprintf("encoded object (>= %v bytes) exceeds size limit (%v bytes)", uint64(e), uint64(MaxObjectSize))
    36  }
    37  
    38  // ErrSliceTooLarge is an error when encoded slice is too large.
    39  type ErrSliceTooLarge struct {
    40  	Len      uint64
    41  	ElemSize uint64
    42  }
    43  
    44  // Error implements the error interface.
    45  func (e ErrSliceTooLarge) Error() string {
    46  	return fmt.Sprintf("encoded slice (%v*%v bytes) exceeds size limit (%v bytes)", e.Len, e.ElemSize, uint64(MaxSliceSize))
    47  }
    48  
type (
	// A SiaMarshaler can encode and write itself to a stream. When a value
	// implements SiaMarshaler, the Encoder calls MarshalSia instead of
	// applying its reflection-based encoding rules.
	SiaMarshaler interface {
		MarshalSia(io.Writer) error
	}

	// A SiaUnmarshaler can read and decode itself from a stream. When the
	// address of a value implements SiaUnmarshaler, the Decoder calls
	// UnmarshalSia instead of applying its reflection-based decoding rules.
	SiaUnmarshaler interface {
		UnmarshalSia(io.Reader) error
	}
)
    60  
// An Encoder writes objects to an output stream. It also provides helper
// methods for writing custom SiaMarshaler implementations. All of its methods
// become no-ops after the Encoder encounters a Write error.
type Encoder struct {
	w   io.Writer // destination stream
	buf [8]byte   // scratch space for single bytes and 8-byte integers
	err error     // first write error encountered; sticky once set
}
    69  
    70  // Write implements the io.Writer interface.
    71  func (e *Encoder) Write(p []byte) (int, error) {
    72  	if e.err != nil {
    73  		return 0, e.err
    74  	}
    75  	var n int
    76  	n, e.err = e.w.Write(p)
    77  	if n != len(p) && e.err == nil {
    78  		e.err = io.ErrShortWrite
    79  	}
    80  	return n, e.err
    81  }
    82  
    83  // WriteByte implements the io.ByteWriter interface.
    84  func (e *Encoder) WriteByte(b byte) error {
    85  	if e.err != nil {
    86  		return e.err
    87  	}
    88  	e.buf[0] = b
    89  	e.Write(e.buf[:1])
    90  	return e.err
    91  }
    92  
    93  // WriteBool writes b to the underlying io.Writer.
    94  func (e *Encoder) WriteBool(b bool) error {
    95  	if b {
    96  		return e.WriteByte(1)
    97  	}
    98  	return e.WriteByte(0)
    99  }
   100  
   101  // WriteUint64 writes a uint64 value to the underlying io.Writer.
   102  func (e *Encoder) WriteUint64(u uint64) error {
   103  	if e.err != nil {
   104  		return e.err
   105  	}
   106  	binary.LittleEndian.PutUint64(e.buf[:8], u)
   107  	e.Write(e.buf[:8])
   108  	return e.err
   109  }
   110  
   111  // WriteInt writes an int value to the underlying io.Writer.
   112  func (e *Encoder) WriteInt(i int) error {
   113  	return e.WriteUint64(uint64(i))
   114  }
   115  
   116  // WritePrefixedBytes writes p to the underlying io.Writer, prefixed by its length.
   117  func (e *Encoder) WritePrefixedBytes(p []byte) error {
   118  	e.WriteInt(len(p))
   119  	e.Write(p)
   120  	return e.err
   121  }
   122  
// Err returns the first non-nil error encountered by e, or nil if every
// write so far has succeeded.
func (e *Encoder) Err() error {
	return e.err
}
   127  
   128  // Encode writes the encoding of v to the stream. For encoding details, see
   129  // the package docstring.
   130  func (e *Encoder) Encode(v interface{}) error {
   131  	return e.encode(reflect.ValueOf(v))
   132  }
   133  
   134  // EncodeAll encodes a variable number of arguments.
   135  func (e *Encoder) EncodeAll(vs ...interface{}) error {
   136  	for _, v := range vs {
   137  		if err := e.Encode(v); err != nil {
   138  			return err
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  // Encode writes the encoding of val to the stream. For encoding details, see
   145  // the package docstring.
   146  func (e *Encoder) encode(val reflect.Value) error {
   147  	if e.err != nil {
   148  		return e.err
   149  	}
   150  	// check for MarshalSia interface first
   151  	if val.CanInterface() {
   152  		if m, ok := val.Interface().(SiaMarshaler); ok {
   153  			return m.MarshalSia(e.w)
   154  		}
   155  	}
   156  
   157  	switch val.Kind() {
   158  	case reflect.Ptr:
   159  		// write either a 1 or 0
   160  		if err := e.Encode(!val.IsNil()); err != nil {
   161  			return err
   162  		}
   163  		if !val.IsNil() {
   164  			return e.encode(val.Elem())
   165  		}
   166  	case reflect.Bool:
   167  		return e.WriteBool(val.Bool())
   168  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   169  		return e.WriteUint64(uint64(val.Int()))
   170  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   171  		return e.WriteUint64(val.Uint())
   172  	case reflect.String:
   173  		return e.WritePrefixedBytes([]byte(val.String()))
   174  	case reflect.Slice:
   175  		// slices are variable length, so prepend the length and then fallthrough to array logic
   176  		if err := e.WriteInt(val.Len()); err != nil {
   177  			return err
   178  		}
   179  		if val.Len() == 0 {
   180  			return nil
   181  		}
   182  		fallthrough
   183  	case reflect.Array:
   184  		// special case for byte arrays
   185  		if val.Type().Elem().Kind() == reflect.Uint8 {
   186  			// if the array is addressable, we can optimize a bit here
   187  			if val.CanAddr() {
   188  				_, err := e.Write(val.Slice(0, val.Len()).Bytes())
   189  				return err
   190  			}
   191  			// otherwise we have to copy into a newly allocated slice
   192  			slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), val.Len(), val.Len())
   193  			reflect.Copy(slice, val)
   194  			_, err := e.Write(slice.Bytes())
   195  			return err
   196  		}
   197  		// normal slices/arrays are encoded by sequentially encoding their elements
   198  		for i := 0; i < val.Len(); i++ {
   199  			if err := e.encode(val.Index(i)); err != nil {
   200  				return err
   201  			}
   202  		}
   203  		return nil
   204  	case reflect.Struct:
   205  		for i := 0; i < val.NumField(); i++ {
   206  			if err := e.encode(val.Field(i)); err != nil {
   207  				return err
   208  			}
   209  		}
   210  		return nil
   211  	}
   212  
   213  	// Marshalling should never fail. If it panics, you're doing something wrong,
   214  	// like trying to encode a map or an unexported struct field.
   215  	panic("could not marshal type " + val.Type().String())
   216  }
   217  
   218  // NewEncoder converts w to an Encoder.
   219  func NewEncoder(w io.Writer) *Encoder {
   220  	if e, ok := w.(*Encoder); ok {
   221  		return e
   222  	}
   223  	return &Encoder{w: w}
   224  }
   225  
   226  // Marshal returns the encoding of v. For encoding details, see the package
   227  // docstring.
   228  func Marshal(v interface{}) []byte {
   229  	b := new(bytes.Buffer)
   230  	NewEncoder(b).Encode(v) // no error possible when using a bytes.Buffer
   231  	return b.Bytes()
   232  }
   233  
   234  // MarshalAll encodes all of its inputs and returns their concatenation.
   235  func MarshalAll(vs ...interface{}) []byte {
   236  	b := new(bytes.Buffer)
   237  	enc := NewEncoder(b)
   238  	// Error from EncodeAll is ignored as encoding cannot fail when writing
   239  	// to a bytes.Buffer.
   240  	_ = enc.EncodeAll(vs...)
   241  	return b.Bytes()
   242  }
   243  
   244  // WriteFile writes v to a file. The file will be created if it does not exist.
   245  func WriteFile(filename string, v interface{}) error {
   246  	file, err := os.Create(filename)
   247  	if err != nil {
   248  		return err
   249  	}
   250  	defer file.Close()
   251  	err = NewEncoder(file).Encode(v)
   252  	if err != nil {
   253  		return errors.New("error while writing " + filename + ": " + err.Error())
   254  	}
   255  	return nil
   256  }
   257  
// A Decoder reads and decodes values from an input stream. It also provides
// helper methods for writing custom SiaUnmarshaler implementations. These
// methods do not return errors, but instead set the value of d.Err(). Once
// d.Err() is set, future operations become no-ops.
type Decoder struct {
	r   io.Reader // source stream
	buf [8]byte   // scratch space for single bytes and 8-byte integers
	err error     // recorded read error; checked by every helper before reading
	n   int       // total number of bytes read; reset to 0 at the start of each Decode
}
   268  
   269  // Read implements the io.Reader interface.
   270  func (d *Decoder) Read(p []byte) (int, error) {
   271  	if d.err != nil {
   272  		return 0, d.err
   273  	}
   274  	var n int
   275  	n, d.err = d.r.Read(p)
   276  	d.n += n
   277  	if d.n > MaxObjectSize {
   278  		d.err = ErrObjectTooLarge(d.n)
   279  	}
   280  	return n, d.err
   281  }
   282  
   283  // ReadFull is shorthand for io.ReadFull(d, p).
   284  func (d *Decoder) ReadFull(p []byte) {
   285  	if d.err != nil {
   286  		return
   287  	}
   288  	n, err := io.ReadFull(d.r, p)
   289  	if err != nil {
   290  		d.err = err
   291  	}
   292  	d.n += n
   293  	if d.n > MaxObjectSize {
   294  		d.err = ErrObjectTooLarge(d.n)
   295  	}
   296  }
   297  
   298  // ReadPrefixedBytes reads a length-prefix, allocates a byte slice with that length,
   299  // reads into the byte slice, and returns it. If the length prefix exceeds
   300  // encoding.MaxSliceSize, ReadPrefixedBytes returns nil and sets d.Err().
   301  func (d *Decoder) ReadPrefixedBytes() []byte {
   302  	n := d.NextPrefix(1) // if too large, n == 0
   303  	if buf, ok := d.r.(*bytes.Buffer); ok {
   304  		b := buf.Next(int(n))
   305  		d.n += len(b)
   306  		if len(b) < int(n) && d.err == nil {
   307  			d.err = io.ErrUnexpectedEOF
   308  		}
   309  		return b
   310  	}
   311  
   312  	b := make([]byte, n)
   313  	d.ReadFull(b)
   314  	if d.err != nil {
   315  		return nil
   316  	}
   317  	return b
   318  }
   319  
   320  // NextUint64 reads the next 8 bytes and returns them as a uint64.
   321  func (d *Decoder) NextUint64() uint64 {
   322  	d.ReadFull(d.buf[:8])
   323  	if d.err != nil {
   324  		return 0
   325  	}
   326  	return DecUint64(d.buf[:])
   327  }
   328  
   329  // NextBool reads the next byte and returns it as a bool.
   330  func (d *Decoder) NextBool() bool {
   331  	d.ReadFull(d.buf[:1])
   332  	if d.buf[0] > 1 && d.err == nil {
   333  		d.err = errors.New("boolean value was not 0 or 1")
   334  	}
   335  	return d.buf[0] == 1
   336  }
   337  
   338  // NextPrefix is like NextUint64, but performs sanity checks on the prefix.
   339  // Specifically, if the prefix multiplied by elemSize exceeds MaxSliceSize,
   340  // NextPrefix returns 0 and sets d.Err().
   341  func (d *Decoder) NextPrefix(elemSize uintptr) uint64 {
   342  	n := d.NextUint64()
   343  	if n > 1<<31-1 || n*uint64(elemSize) > MaxSliceSize {
   344  		d.err = ErrSliceTooLarge{Len: n, ElemSize: uint64(elemSize)}
   345  		return 0
   346  	}
   347  	return n
   348  }
   349  
// Err returns the error recorded by d, or nil if no read or validation has
// failed so far.
func (d *Decoder) Err() error {
	return d.err
}
   354  
   355  // Decode reads the next encoded value from its input stream and stores it in
   356  // v, which must be a pointer. The decoding rules are the inverse of those
   357  // specified in the package docstring.
   358  func (d *Decoder) Decode(v interface{}) (err error) {
   359  	// v must be a pointer
   360  	pval := reflect.ValueOf(v)
   361  	if pval.Kind() != reflect.Ptr || pval.IsNil() {
   362  		return errBadPointer
   363  	}
   364  
   365  	// catch decoding panics and convert them to errors
   366  	// note that this allows us to skip boundary checks during decoding
   367  	defer func() {
   368  		if r := recover(); r != nil {
   369  			err = fmt.Errorf("could not decode type %s: %v", pval.Elem().Type().String(), r)
   370  		}
   371  	}()
   372  
   373  	// reset the read count
   374  	d.n = 0
   375  
   376  	d.decode(pval.Elem())
   377  	return
   378  }
   379  
   380  // DecodeAll decodes a variable number of arguments.
   381  func (d *Decoder) DecodeAll(vs ...interface{}) error {
   382  	for _, v := range vs {
   383  		if err := d.Decode(v); err != nil {
   384  			return err
   385  		}
   386  	}
   387  	return nil
   388  }
   389  
// decode reads the next encoded value from its input stream and stores it in
// val. The decoding rules are the inverse of those specified in the package
// docstring.
//
// decode does not return an error. It reports failure by panicking: with the
// error returned by a SiaUnmarshaler, with "unknown type" for a kind it does
// not handle, or with d.err once one of the read helpers has recorded an
// error. Decoder.Decode recovers these panics and turns them into an error.
func (d *Decoder) decode(val reflect.Value) {
	// check for UnmarshalSia interface first
	if val.CanAddr() && val.Addr().CanInterface() {
		if u, ok := val.Addr().Interface().(SiaUnmarshaler); ok {
			// The custom unmarshaler is handed the raw reader d.r, so the
			// bytes it consumes are not added to d.n.
			err := u.UnmarshalSia(d.r)
			if err != nil {
				panic(err)
			}
			return
		}
	}

	switch val.Kind() {
	case reflect.Ptr:
		// A pointer is encoded as a presence byte followed by the pointee.
		valid := d.NextBool()
		if !valid {
			// nil pointer, nothing to decode
			break
		}
		// make sure we aren't decoding into nil
		if val.IsNil() {
			val.Set(reflect.New(val.Type().Elem()))
		}
		d.decode(val.Elem())
	case reflect.Bool:
		val.SetBool(d.NextBool())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Every integer kind is read as 8 bytes, regardless of its width.
		val.SetInt(int64(d.NextUint64()))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		val.SetUint(d.NextUint64())
	case reflect.String:
		val.SetString(string(d.ReadPrefixedBytes()))
	case reflect.Slice:
		// slices are variable length, but otherwise the same as arrays.
		// just have to allocate them first, then we can fallthrough to the array logic.
		sliceLen := d.NextPrefix(val.Type().Elem().Size())
		if sliceLen == 0 {
			// Either the encoded slice is empty or NextPrefix rejected the
			// prefix (in which case d.err is set and triggers the panic
			// below). val is left as it was; it is not reset.
			break
		}
		val.Set(reflect.MakeSlice(val.Type(), int(sliceLen), int(sliceLen)))
		fallthrough
	case reflect.Array:
		// special case for byte arrays (e.g. hashes)
		if val.Type().Elem().Kind() == reflect.Uint8 {
			// convert val to a slice and read into it directly
			d.ReadFull(val.Slice(0, val.Len()).Bytes())
			break
		}
		// arrays are unmarshalled by sequentially unmarshalling their elements
		for i := 0; i < val.Len(); i++ {
			d.decode(val.Index(i))
		}
	case reflect.Struct:
		// Fields are decoded in declaration order.
		for i := 0; i < val.NumField(); i++ {
			d.decode(val.Field(i))
		}
	default:
		panic("unknown type")
	}

	// The read helpers record failures in d.err instead of returning them;
	// surface any such failure now so decoding stops.
	if d.err != nil {
		panic(d.err)
	}
}
   457  
   458  // NewDecoder converts r to a Decoder.
   459  func NewDecoder(r io.Reader) *Decoder {
   460  	if d, ok := r.(*Decoder); ok {
   461  		return d
   462  	}
   463  	return &Decoder{r: r}
   464  }
   465  
   466  // Unmarshal decodes the encoded value b and stores it in v, which must be a
   467  // pointer. The decoding rules are the inverse of those specified in the
   468  // package docstring for marshaling.
   469  func Unmarshal(b []byte, v interface{}) error {
   470  	r := bytes.NewBuffer(b)
   471  	return NewDecoder(r).Decode(v)
   472  }
   473  
   474  // UnmarshalAll decodes the encoded values in b and stores them in vs, which
   475  // must be pointers.
   476  func UnmarshalAll(b []byte, vs ...interface{}) error {
   477  	dec := NewDecoder(bytes.NewBuffer(b))
   478  	return dec.DecodeAll(vs...)
   479  }
   480  
   481  // ReadFile reads the contents of a file and decodes them into v.
   482  func ReadFile(filename string, v interface{}) error {
   483  	file, err := os.Open(filename)
   484  	if err != nil {
   485  		return err
   486  	}
   487  	defer file.Close()
   488  	err = NewDecoder(file).Decode(v)
   489  	if err != nil {
   490  		return errors.New("error while reading " + filename + ": " + err.Error())
   491  	}
   492  	return nil
   493  }