github.com/avahowell/sia@v0.5.1-beta.0.20160524050156-83dcc3d37c94/encoding/marshal.go (about)

// Package encoding converts arbitrary objects into byte slices, and vice
// versa. It also contains helper functions for reading and writing
// length-prefixed data. See doc/Encoding.md for the full encoding
// specification.
     4  package encoding
     5  
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"reflect"
)
    14  
    15  const (
    16  	maxDecodeLen = 12e6 // 12 MB
    17  	maxSliceLen  = 5e6  // 5 MB
    18  )
    19  
    20  var (
    21  	errBadPointer = errors.New("cannot decode into invalid pointer")
    22  )
    23  
    24  type (
    25  	// GenericMarshaler marshals objects into byte slices and unmarshals byte
    26  	// slices into objects.
    27  	GenericMarshaler interface {
    28  		Marshal(interface{}) []byte
    29  		Unmarshal([]byte, interface{}) error
    30  	}
    31  
    32  	// A SiaMarshaler can encode and write itself to a stream.
    33  	SiaMarshaler interface {
    34  		MarshalSia(io.Writer) error
    35  	}
    36  
    37  	// A SiaUnmarshaler can read and decode itself from a stream.
    38  	SiaUnmarshaler interface {
    39  		UnmarshalSia(io.Reader) error
    40  	}
    41  
    42  	// StdGenericMarshaler is an implementation of GenericMarshaler that uses
    43  	// the encoding.Marshal and encoding.Unmarshal functions to perform
    44  	// its marshaling/unmarshaling.
    45  	StdGenericMarshaler struct{}
    46  
    47  	// An Encoder writes objects to an output stream.
    48  	Encoder struct {
    49  		w io.Writer
    50  	}
    51  )
    52  
    53  // Encode writes the encoding of v to the stream. For encoding details, see
    54  // the package docstring.
    55  func (e *Encoder) Encode(v interface{}) error {
    56  	return e.encode(reflect.ValueOf(v))
    57  }
    58  
    59  // EncodeAll encodes a variable number of arguments.
    60  func (e *Encoder) EncodeAll(vs ...interface{}) error {
    61  	for _, v := range vs {
    62  		if err := e.Encode(v); err != nil {
    63  			return err
    64  		}
    65  	}
    66  	return nil
    67  }
    68  
    69  // write catches instances where short writes do not return an error.
    70  func (e *Encoder) write(p []byte) error {
    71  	n, err := e.w.Write(p)
    72  	if n != len(p) && err == nil {
    73  		return io.ErrShortWrite
    74  	}
    75  	return err
    76  }
    77  
    78  // Encode writes the encoding of val to the stream. For encoding details, see
    79  // the package docstring.
    80  func (e *Encoder) encode(val reflect.Value) error {
    81  	// check for MarshalSia interface first
    82  	if val.CanInterface() {
    83  		if m, ok := val.Interface().(SiaMarshaler); ok {
    84  			return m.MarshalSia(e.w)
    85  		}
    86  	}
    87  
    88  	switch val.Kind() {
    89  	case reflect.Ptr:
    90  		// write either a 1 or 0
    91  		if err := e.Encode(!val.IsNil()); err != nil {
    92  			return err
    93  		}
    94  		if !val.IsNil() {
    95  			return e.encode(val.Elem())
    96  		}
    97  	case reflect.Bool:
    98  		if val.Bool() {
    99  			return e.write([]byte{1})
   100  		} else {
   101  			return e.write([]byte{0})
   102  		}
   103  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   104  		return e.write(EncInt64(val.Int()))
   105  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   106  		return e.write(EncUint64(val.Uint()))
   107  	case reflect.String:
   108  		return WritePrefix(e.w, []byte(val.String()))
   109  	case reflect.Slice:
   110  		// slices are variable length, so prepend the length and then fallthrough to array logic
   111  		if err := e.write(EncUint64(uint64(val.Len()))); err != nil {
   112  			return err
   113  		}
   114  		if val.Len() == 0 {
   115  			return nil
   116  		}
   117  		fallthrough
   118  	case reflect.Array:
   119  		// special case for byte arrays
   120  		if val.Type().Elem().Kind() == reflect.Uint8 {
   121  			// if the array is addressable, we can optimize a bit here
   122  			if val.CanAddr() {
   123  				return e.write(val.Slice(0, val.Len()).Bytes())
   124  			}
   125  			// otherwise we have to copy into a newly allocated slice
   126  			slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), val.Len(), val.Len())
   127  			reflect.Copy(slice, val)
   128  			return e.write(slice.Bytes())
   129  		}
   130  		// normal slices/arrays are encoded by sequentially encoding their elements
   131  		for i := 0; i < val.Len(); i++ {
   132  			if err := e.encode(val.Index(i)); err != nil {
   133  				return err
   134  			}
   135  		}
   136  		return nil
   137  	case reflect.Struct:
   138  		for i := 0; i < val.NumField(); i++ {
   139  			if err := e.encode(val.Field(i)); err != nil {
   140  				return err
   141  			}
   142  		}
   143  		return nil
   144  	}
   145  
   146  	// Marshalling should never fail. If it panics, you're doing something wrong,
   147  	// like trying to encode a map or an unexported struct field.
   148  	panic("could not marshal type " + val.Type().String())
   149  }
   150  
   151  // NewEncoder returns a new encoder that writes to w.
   152  func NewEncoder(w io.Writer) *Encoder {
   153  	return &Encoder{w}
   154  }
   155  
   156  // Marshal returns the encoding of v. For encoding details, see the package
   157  // docstring.
   158  func Marshal(v interface{}) []byte {
   159  	b := new(bytes.Buffer)
   160  	NewEncoder(b).Encode(v) // no error possible when using a bytes.Buffer
   161  	return b.Bytes()
   162  }
   163  
   164  // MarshalAll encodes all of its inputs and returns their concatenation.
   165  func MarshalAll(vs ...interface{}) []byte {
   166  	b := new(bytes.Buffer)
   167  	enc := NewEncoder(b)
   168  	// can't encode with EncodeAll (type information is lost)
   169  	for _, v := range vs {
   170  		enc.Encode(v)
   171  	}
   172  	return b.Bytes()
   173  }
   174  
   175  // WriteFile writes v to a file. The file will be created if it does not exist.
   176  func WriteFile(filename string, v interface{}) error {
   177  	file, err := os.Create(filename)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	defer file.Close()
   182  	err = NewEncoder(file).Encode(v)
   183  	if err != nil {
   184  		return errors.New("error while writing " + filename + ": " + err.Error())
   185  	}
   186  	return nil
   187  }
   188  
   189  // A Decoder reads and decodes values from an input stream.
   190  type Decoder struct {
   191  	r io.Reader
   192  	n int
   193  }
   194  
   195  // Read implements the io.Reader interface. It also keeps track of the total
   196  // number of bytes decoded, and panics if that number exceeds a global
   197  // maximum.
   198  func (d *Decoder) Read(p []byte) (int, error) {
   199  	n, err := d.r.Read(p)
   200  	// enforce an absolute maximum size limit
   201  	if d.n += n; d.n > maxDecodeLen {
   202  		panic("encoded type exceeds size limit")
   203  	}
   204  	return n, err
   205  }
   206  
   207  // Decode reads the next encoded value from its input stream and stores it in
   208  // v, which must be a pointer. The decoding rules are the inverse of those
   209  // specified in the package docstring.
   210  func (d *Decoder) Decode(v interface{}) (err error) {
   211  	// v must be a pointer
   212  	pval := reflect.ValueOf(v)
   213  	if pval.Kind() != reflect.Ptr || pval.IsNil() {
   214  		return errBadPointer
   215  	}
   216  
   217  	// catch decoding panics and convert them to errors
   218  	// note that this allows us to skip boundary checks during decoding
   219  	defer func() {
   220  		if r := recover(); r != nil {
   221  			err = fmt.Errorf("could not decode type %s: %v", pval.Elem().Type().String(), r)
   222  		}
   223  	}()
   224  
   225  	// reset the read count
   226  	d.n = 0
   227  
   228  	d.decode(pval.Elem())
   229  	return
   230  }
   231  
   232  // DecodeAll decodes a variable number of arguments.
   233  func (d *Decoder) DecodeAll(vs ...interface{}) error {
   234  	for _, v := range vs {
   235  		if err := d.Decode(v); err != nil {
   236  			return err
   237  		}
   238  	}
   239  	return nil
   240  }
   241  
   242  // readN reads n bytes and panics if the read fails.
   243  func (d *Decoder) readN(n int) []byte {
   244  	b := make([]byte, n)
   245  	_, err := io.ReadFull(d, b)
   246  	if err != nil {
   247  		panic(err)
   248  	}
   249  	return b
   250  }
   251  
   252  // readPrefix reads a length-prefixed byte slice and panics if the read fails.
   253  func (d *Decoder) readPrefix() []byte {
   254  	b, err := ReadPrefix(d, maxSliceLen)
   255  	if err != nil {
   256  		panic(err)
   257  	}
   258  	return b
   259  }
   260  
   261  // decode reads the next encoded value from its input stream and stores it in
   262  // val. The decoding rules are the inverse of those specified in the package
   263  // docstring.
   264  func (d *Decoder) decode(val reflect.Value) {
   265  	// check for UnmarshalSia interface first
   266  	if val.CanAddr() && val.Addr().CanInterface() {
   267  		if u, ok := val.Addr().Interface().(SiaUnmarshaler); ok {
   268  			err := u.UnmarshalSia(d)
   269  			if err != nil {
   270  				panic(err)
   271  			}
   272  			return
   273  		}
   274  	}
   275  
   276  	switch val.Kind() {
   277  	case reflect.Ptr:
   278  		var valid bool
   279  		d.decode(reflect.ValueOf(&valid).Elem())
   280  		// nil pointer, nothing to decode
   281  		if !valid {
   282  			return
   283  		}
   284  		// make sure we aren't decoding into nil
   285  		if val.IsNil() {
   286  			val.Set(reflect.New(val.Type().Elem()))
   287  		}
   288  		d.decode(val.Elem())
   289  	case reflect.Bool:
   290  		b := d.readN(1)
   291  		if b[0] > 1 {
   292  			panic("boolean value was not 0 or 1")
   293  		}
   294  		val.SetBool(b[0] == 1)
   295  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   296  		val.SetInt(DecInt64(d.readN(8)))
   297  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   298  		val.SetUint(DecUint64(d.readN(8)))
   299  	case reflect.String:
   300  		val.SetString(string(d.readPrefix()))
   301  	case reflect.Slice:
   302  		// slices are variable length, but otherwise the same as arrays.
   303  		// just have to allocate them first, then we can fallthrough to the array logic.
   304  		sliceLen := DecUint64(d.readN(8))
   305  		// sanity-check the sliceLen, otherwise you can crash a peer by making
   306  		// them allocate a massive slice
   307  		if sliceLen > 1<<31-1 || sliceLen*uint64(val.Type().Elem().Size()) > maxSliceLen {
   308  			panic("slice is too large")
   309  		} else if sliceLen == 0 {
   310  			return
   311  		}
   312  		val.Set(reflect.MakeSlice(val.Type(), int(sliceLen), int(sliceLen)))
   313  		fallthrough
   314  	case reflect.Array:
   315  		// special case for byte arrays (e.g. hashes)
   316  		if val.Type().Elem().Kind() == reflect.Uint8 {
   317  			// convert val to a slice and read into it directly
   318  			b := val.Slice(0, val.Len())
   319  			_, err := io.ReadFull(d, b.Bytes())
   320  			if err != nil {
   321  				panic(err)
   322  			}
   323  			return
   324  		}
   325  		// arrays are unmarshalled by sequentially unmarshalling their elements
   326  		for i := 0; i < val.Len(); i++ {
   327  			d.decode(val.Index(i))
   328  		}
   329  		return
   330  	case reflect.Struct:
   331  		for i := 0; i < val.NumField(); i++ {
   332  			d.decode(val.Field(i))
   333  		}
   334  		return
   335  	default:
   336  		panic("unknown type")
   337  	}
   338  }
   339  
   340  // NewDecoder returns a new decoder that reads from r.
   341  func NewDecoder(r io.Reader) *Decoder {
   342  	return &Decoder{r, 0}
   343  }
   344  
   345  // Unmarshal decodes the encoded value b and stores it in v, which must be a
   346  // pointer. The decoding rules are the inverse of those specified in the
   347  // package docstring for marshaling.
   348  func Unmarshal(b []byte, v interface{}) error {
   349  	r := bytes.NewReader(b)
   350  	return NewDecoder(r).Decode(v)
   351  }
   352  
   353  // UnmarshalAll decodes the encoded values in b and stores them in vs, which
   354  // must be pointers.
   355  func UnmarshalAll(b []byte, vs ...interface{}) error {
   356  	dec := NewDecoder(bytes.NewReader(b))
   357  	// can't use DecodeAll (type information is lost)
   358  	for _, v := range vs {
   359  		if err := dec.Decode(v); err != nil {
   360  			return err
   361  		}
   362  	}
   363  	return nil
   364  }
   365  
   366  // ReadFile reads the contents of a file and decodes them into v.
   367  func ReadFile(filename string, v interface{}) error {
   368  	file, err := os.Open(filename)
   369  	if err != nil {
   370  		return err
   371  	}
   372  	defer file.Close()
   373  	err = NewDecoder(file).Decode(v)
   374  	if err != nil {
   375  		return errors.New("error while reading " + filename + ": " + err.Error())
   376  	}
   377  	return nil
   378  }
   379  
   380  // Marshal returns the encoding of v. For encoding details, see the package
   381  // docstring.
   382  func (m StdGenericMarshaler) Marshal(v interface{}) []byte {
   383  	return Marshal(v)
   384  }
   385  
   386  // Unmarshal decodes the encoded value b and stores it in v, which must be a
   387  // pointer. The decoding rules are the inverse of those specified in the
   388  // package docstring for marshaling.
   389  func (m StdGenericMarshaler) Unmarshal(b []byte, v interface{}) error {
   390  	return Unmarshal(b, v)
   391  }