github.com/Synthesix/Sia@v1.3.3-0.20180413141344-f863baeed3ca/encoding/marshal.go (about)

// Package encoding converts arbitrary objects into byte slices, and vice
// versa. It also contains helper functions for reading and writing length-
// prefixed data. See doc/Encoding.md for the full encoding specification.
     4  package encoding
     5  
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"reflect"
)
    14  
    15  const (
    16  	// MaxObjectSize refers to the maximum size an object could have.
    17  	// Limited to 12 MB.
    18  	MaxObjectSize = 12e6
    19  
    20  	// MaxSliceSize refers to the maximum size slice could have. Limited
    21  	// to 5 MB.
    22  	MaxSliceSize = 5e6 // 5 MB
    23  )
    24  
    25  var (
    26  	errBadPointer = errors.New("cannot decode into invalid pointer")
    27  )
    28  
    29  // ErrObjectTooLarge is an error when encoded object exceeds size limit.
    30  type ErrObjectTooLarge uint64
    31  
    32  // Error implements the error interface.
    33  func (e ErrObjectTooLarge) Error() string {
    34  	return fmt.Sprintf("encoded object (>= %v bytes) exceeds size limit (%v bytes)", uint64(e), uint64(MaxObjectSize))
    35  }
    36  
    37  // ErrSliceTooLarge is an error when encoded slice is too large.
    38  type ErrSliceTooLarge struct {
    39  	Len      uint64
    40  	ElemSize uint64
    41  }
    42  
    43  // Error implements the error interface.
    44  func (e ErrSliceTooLarge) Error() string {
    45  	return fmt.Sprintf("encoded slice (%v*%v bytes) exceeds size limit (%v bytes)", e.Len, e.ElemSize, uint64(MaxSliceSize))
    46  }
    47  
    48  type (
    49  	// A SiaMarshaler can encode and write itself to a stream.
    50  	SiaMarshaler interface {
    51  		MarshalSia(io.Writer) error
    52  	}
    53  
    54  	// A SiaUnmarshaler can read and decode itself from a stream.
    55  	SiaUnmarshaler interface {
    56  		UnmarshalSia(io.Reader) error
    57  	}
    58  
    59  	// An Encoder writes objects to an output stream.
    60  	Encoder struct {
    61  		w io.Writer
    62  	}
    63  )
    64  
    65  // Encode writes the encoding of v to the stream. For encoding details, see
    66  // the package docstring.
    67  func (e *Encoder) Encode(v interface{}) error {
    68  	return e.encode(reflect.ValueOf(v))
    69  }
    70  
    71  // EncodeAll encodes a variable number of arguments.
    72  func (e *Encoder) EncodeAll(vs ...interface{}) error {
    73  	for _, v := range vs {
    74  		if err := e.Encode(v); err != nil {
    75  			return err
    76  		}
    77  	}
    78  	return nil
    79  }
    80  
    81  // write catches instances where short writes do not return an error.
    82  func (e *Encoder) write(p []byte) error {
    83  	n, err := e.w.Write(p)
    84  	if n != len(p) && err == nil {
    85  		return io.ErrShortWrite
    86  	}
    87  	return err
    88  }
    89  
    90  // Encode writes the encoding of val to the stream. For encoding details, see
    91  // the package docstring.
    92  func (e *Encoder) encode(val reflect.Value) error {
    93  	// check for MarshalSia interface first
    94  	if val.CanInterface() {
    95  		if m, ok := val.Interface().(SiaMarshaler); ok {
    96  			return m.MarshalSia(e.w)
    97  		}
    98  	}
    99  
   100  	switch val.Kind() {
   101  	case reflect.Ptr:
   102  		// write either a 1 or 0
   103  		if err := e.Encode(!val.IsNil()); err != nil {
   104  			return err
   105  		}
   106  		if !val.IsNil() {
   107  			return e.encode(val.Elem())
   108  		}
   109  	case reflect.Bool:
   110  		if val.Bool() {
   111  			return e.write([]byte{1})
   112  		}
   113  
   114  		return e.write([]byte{0})
   115  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   116  		return e.write(EncInt64(val.Int()))
   117  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   118  		return WriteUint64(e.w, val.Uint())
   119  	case reflect.String:
   120  		return WritePrefix(e.w, []byte(val.String()))
   121  	case reflect.Slice:
   122  		// slices are variable length, so prepend the length and then fallthrough to array logic
   123  		if err := WriteInt(e.w, val.Len()); err != nil {
   124  			return err
   125  		}
   126  		if val.Len() == 0 {
   127  			return nil
   128  		}
   129  		fallthrough
   130  	case reflect.Array:
   131  		// special case for byte arrays
   132  		if val.Type().Elem().Kind() == reflect.Uint8 {
   133  			// if the array is addressable, we can optimize a bit here
   134  			if val.CanAddr() {
   135  				return e.write(val.Slice(0, val.Len()).Bytes())
   136  			}
   137  			// otherwise we have to copy into a newly allocated slice
   138  			slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), val.Len(), val.Len())
   139  			reflect.Copy(slice, val)
   140  			return e.write(slice.Bytes())
   141  		}
   142  		// normal slices/arrays are encoded by sequentially encoding their elements
   143  		for i := 0; i < val.Len(); i++ {
   144  			if err := e.encode(val.Index(i)); err != nil {
   145  				return err
   146  			}
   147  		}
   148  		return nil
   149  	case reflect.Struct:
   150  		for i := 0; i < val.NumField(); i++ {
   151  			if err := e.encode(val.Field(i)); err != nil {
   152  				return err
   153  			}
   154  		}
   155  		return nil
   156  	}
   157  
   158  	// Marshalling should never fail. If it panics, you're doing something wrong,
   159  	// like trying to encode a map or an unexported struct field.
   160  	panic("could not marshal type " + val.Type().String())
   161  }
   162  
   163  // NewEncoder returns a new encoder that writes to w.
   164  func NewEncoder(w io.Writer) *Encoder {
   165  	return &Encoder{w}
   166  }
   167  
   168  // Marshal returns the encoding of v. For encoding details, see the package
   169  // docstring.
   170  func Marshal(v interface{}) []byte {
   171  	b := new(bytes.Buffer)
   172  	NewEncoder(b).Encode(v) // no error possible when using a bytes.Buffer
   173  	return b.Bytes()
   174  }
   175  
   176  // MarshalAll encodes all of its inputs and returns their concatenation.
   177  func MarshalAll(vs ...interface{}) []byte {
   178  	b := new(bytes.Buffer)
   179  	enc := NewEncoder(b)
   180  	// Error from EncodeAll is ignored as encoding cannot fail when writing
   181  	// to a bytes.Buffer.
   182  	_ = enc.EncodeAll(vs...)
   183  	return b.Bytes()
   184  }
   185  
   186  // WriteFile writes v to a file. The file will be created if it does not exist.
   187  func WriteFile(filename string, v interface{}) error {
   188  	file, err := os.Create(filename)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	defer file.Close()
   193  	err = NewEncoder(file).Encode(v)
   194  	if err != nil {
   195  		return errors.New("error while writing " + filename + ": " + err.Error())
   196  	}
   197  	return nil
   198  }
   199  
   200  // A Decoder reads and decodes values from an input stream.
   201  type Decoder struct {
   202  	r io.Reader
   203  	n int
   204  }
   205  
   206  // Read implements the io.Reader interface. It also keeps track of the total
   207  // number of bytes decoded, and panics if that number exceeds a global
   208  // maximum.
   209  func (d *Decoder) Read(p []byte) (int, error) {
   210  	n, err := d.r.Read(p)
   211  	// enforce an absolute maximum size limit
   212  	if d.n += n; d.n > MaxObjectSize {
   213  		panic(ErrObjectTooLarge(d.n))
   214  	}
   215  	return n, err
   216  }
   217  
   218  // Decode reads the next encoded value from its input stream and stores it in
   219  // v, which must be a pointer. The decoding rules are the inverse of those
   220  // specified in the package docstring.
   221  func (d *Decoder) Decode(v interface{}) (err error) {
   222  	// v must be a pointer
   223  	pval := reflect.ValueOf(v)
   224  	if pval.Kind() != reflect.Ptr || pval.IsNil() {
   225  		return errBadPointer
   226  	}
   227  
   228  	// catch decoding panics and convert them to errors
   229  	// note that this allows us to skip boundary checks during decoding
   230  	defer func() {
   231  		if r := recover(); r != nil {
   232  			err = fmt.Errorf("could not decode type %s: %v", pval.Elem().Type().String(), r)
   233  		}
   234  	}()
   235  
   236  	// reset the read count
   237  	d.n = 0
   238  
   239  	d.decode(pval.Elem())
   240  	return
   241  }
   242  
   243  // DecodeAll decodes a variable number of arguments.
   244  func (d *Decoder) DecodeAll(vs ...interface{}) error {
   245  	for _, v := range vs {
   246  		if err := d.Decode(v); err != nil {
   247  			return err
   248  		}
   249  	}
   250  	return nil
   251  }
   252  
   253  // readN reads n bytes and panics if the read fails.
   254  func (d *Decoder) readN(n int) []byte {
   255  	if buf, ok := d.r.(*bytes.Buffer); ok {
   256  		b := buf.Next(n)
   257  		if len(b) != n {
   258  			panic(io.ErrUnexpectedEOF)
   259  		}
   260  		if d.n += n; d.n > MaxObjectSize {
   261  			panic(ErrObjectTooLarge(d.n))
   262  		}
   263  		return b
   264  	}
   265  	b := make([]byte, n)
   266  	_, err := io.ReadFull(d, b)
   267  	if err != nil {
   268  		panic(err)
   269  	}
   270  	return b
   271  }
   272  
   273  // decode reads the next encoded value from its input stream and stores it in
   274  // val. The decoding rules are the inverse of those specified in the package
   275  // docstring.
   276  func (d *Decoder) decode(val reflect.Value) {
   277  	// check for UnmarshalSia interface first
   278  	if val.CanAddr() && val.Addr().CanInterface() {
   279  		if u, ok := val.Addr().Interface().(SiaUnmarshaler); ok {
   280  			err := u.UnmarshalSia(d.r)
   281  			if err != nil {
   282  				panic(err)
   283  			}
   284  			return
   285  		}
   286  	}
   287  
   288  	switch val.Kind() {
   289  	case reflect.Ptr:
   290  		var valid bool
   291  		d.decode(reflect.ValueOf(&valid).Elem())
   292  		// nil pointer, nothing to decode
   293  		if !valid {
   294  			return
   295  		}
   296  		// make sure we aren't decoding into nil
   297  		if val.IsNil() {
   298  			val.Set(reflect.New(val.Type().Elem()))
   299  		}
   300  		d.decode(val.Elem())
   301  	case reflect.Bool:
   302  		b := d.readN(1)
   303  		if b[0] > 1 {
   304  			panic("boolean value was not 0 or 1")
   305  		}
   306  		val.SetBool(b[0] == 1)
   307  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   308  		val.SetInt(DecInt64(d.readN(8)))
   309  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   310  		val.SetUint(DecUint64(d.readN(8)))
   311  	case reflect.String:
   312  		strLen := DecUint64(d.readN(8))
   313  		if strLen > MaxSliceSize {
   314  			panic("string is too large")
   315  		}
   316  		val.SetString(string(d.readN(int(strLen))))
   317  	case reflect.Slice:
   318  		// slices are variable length, but otherwise the same as arrays.
   319  		// just have to allocate them first, then we can fallthrough to the array logic.
   320  		sliceLen := DecUint64(d.readN(8))
   321  		// sanity-check the sliceLen, otherwise you can crash a peer by making
   322  		// them allocate a massive slice
   323  		if sliceLen > 1<<31-1 || sliceLen*uint64(val.Type().Elem().Size()) > MaxSliceSize {
   324  			panic(ErrSliceTooLarge{Len: sliceLen, ElemSize: uint64(val.Type().Elem().Size())})
   325  		} else if sliceLen == 0 {
   326  			return
   327  		}
   328  		val.Set(reflect.MakeSlice(val.Type(), int(sliceLen), int(sliceLen)))
   329  		fallthrough
   330  	case reflect.Array:
   331  		// special case for byte arrays (e.g. hashes)
   332  		if val.Type().Elem().Kind() == reflect.Uint8 {
   333  			// convert val to a slice and read into it directly
   334  			b := val.Slice(0, val.Len())
   335  			_, err := io.ReadFull(d, b.Bytes())
   336  			if err != nil {
   337  				panic(err)
   338  			}
   339  			return
   340  		}
   341  		// arrays are unmarshalled by sequentially unmarshalling their elements
   342  		for i := 0; i < val.Len(); i++ {
   343  			d.decode(val.Index(i))
   344  		}
   345  		return
   346  	case reflect.Struct:
   347  		for i := 0; i < val.NumField(); i++ {
   348  			d.decode(val.Field(i))
   349  		}
   350  		return
   351  	default:
   352  		panic("unknown type")
   353  	}
   354  }
   355  
   356  // NewDecoder returns a new decoder that reads from r.
   357  func NewDecoder(r io.Reader) *Decoder {
   358  	return &Decoder{r, 0}
   359  }
   360  
   361  // Unmarshal decodes the encoded value b and stores it in v, which must be a
   362  // pointer. The decoding rules are the inverse of those specified in the
   363  // package docstring for marshaling.
   364  func Unmarshal(b []byte, v interface{}) error {
   365  	r := bytes.NewBuffer(b)
   366  	return NewDecoder(r).Decode(v)
   367  }
   368  
   369  // UnmarshalAll decodes the encoded values in b and stores them in vs, which
   370  // must be pointers.
   371  func UnmarshalAll(b []byte, vs ...interface{}) error {
   372  	dec := NewDecoder(bytes.NewBuffer(b))
   373  	return dec.DecodeAll(vs...)
   374  }
   375  
   376  // ReadFile reads the contents of a file and decodes them into v.
   377  func ReadFile(filename string, v interface{}) error {
   378  	file, err := os.Open(filename)
   379  	if err != nil {
   380  		return err
   381  	}
   382  	defer file.Close()
   383  	err = NewDecoder(file).Decode(v)
   384  	if err != nil {
   385  		return errors.New("error while reading " + filename + ": " + err.Error())
   386  	}
   387  	return nil
   388  }