github.com/avahowell/sia@v0.5.1-beta.0.20160524050156-83dcc3d37c94/encoding/marshal.go (about) 1 // Package encoding converts arbitrary objects into byte slices, and vis 2 // versa. It also contains helper functions for reading and writing length- 3 // prefixed data. See doc/Encoding.md for the full encoding specification. 4 package encoding 5 6 import ( 7 "bytes" 8 "errors" 9 "fmt" 10 "io" 11 "os" 12 "reflect" 13 ) 14 15 const ( 16 maxDecodeLen = 12e6 // 12 MB 17 maxSliceLen = 5e6 // 5 MB 18 ) 19 20 var ( 21 errBadPointer = errors.New("cannot decode into invalid pointer") 22 ) 23 24 type ( 25 // GenericMarshaler marshals objects into byte slices and unmarshals byte 26 // slices into objects. 27 GenericMarshaler interface { 28 Marshal(interface{}) []byte 29 Unmarshal([]byte, interface{}) error 30 } 31 32 // A SiaMarshaler can encode and write itself to a stream. 33 SiaMarshaler interface { 34 MarshalSia(io.Writer) error 35 } 36 37 // A SiaUnmarshaler can read and decode itself from a stream. 38 SiaUnmarshaler interface { 39 UnmarshalSia(io.Reader) error 40 } 41 42 // StdGenericMarshaler is an implementation of GenericMarshaler that uses 43 // the encoding.Marshal and encoding.Unmarshal functions to perform 44 // its marshaling/unmarshaling. 45 StdGenericMarshaler struct{} 46 47 // An Encoder writes objects to an output stream. 48 Encoder struct { 49 w io.Writer 50 } 51 ) 52 53 // Encode writes the encoding of v to the stream. For encoding details, see 54 // the package docstring. 55 func (e *Encoder) Encode(v interface{}) error { 56 return e.encode(reflect.ValueOf(v)) 57 } 58 59 // EncodeAll encodes a variable number of arguments. 60 func (e *Encoder) EncodeAll(vs ...interface{}) error { 61 for _, v := range vs { 62 if err := e.Encode(v); err != nil { 63 return err 64 } 65 } 66 return nil 67 } 68 69 // write catches instances where short writes do not return an error. 70 func (e *Encoder) write(p []byte) error { 71 n, err := e.w.Write(p) 72 if n != len(p) && err == nil { 73 return io.ErrShortWrite 74 } 75 return err 76 } 77 78 // Encode writes the encoding of val to the stream. For encoding details, see 79 // the package docstring. 80 func (e *Encoder) encode(val reflect.Value) error { 81 // check for MarshalSia interface first 82 if val.CanInterface() { 83 if m, ok := val.Interface().(SiaMarshaler); ok { 84 return m.MarshalSia(e.w) 85 } 86 } 87 88 switch val.Kind() { 89 case reflect.Ptr: 90 // write either a 1 or 0 91 if err := e.Encode(!val.IsNil()); err != nil { 92 return err 93 } 94 if !val.IsNil() { 95 return e.encode(val.Elem()) 96 } 97 case reflect.Bool: 98 if val.Bool() { 99 return e.write([]byte{1}) 100 } else { 101 return e.write([]byte{0}) 102 } 103 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 104 return e.write(EncInt64(val.Int())) 105 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 106 return e.write(EncUint64(val.Uint())) 107 case reflect.String: 108 return WritePrefix(e.w, []byte(val.String())) 109 case reflect.Slice: 110 // slices are variable length, so prepend the length and then fallthrough to array logic 111 if err := e.write(EncUint64(uint64(val.Len()))); err != nil { 112 return err 113 } 114 if val.Len() == 0 { 115 return nil 116 } 117 fallthrough 118 case reflect.Array: 119 // special case for byte arrays 120 if val.Type().Elem().Kind() == reflect.Uint8 { 121 // if the array is addressable, we can optimize a bit here 122 if val.CanAddr() { 123 return e.write(val.Slice(0, val.Len()).Bytes()) 124 } 125 // otherwise we have to copy into a newly allocated slice 126 slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), val.Len(), val.Len()) 127 reflect.Copy(slice, val) 128 return e.write(slice.Bytes()) 129 } 130 // normal slices/arrays are encoded by sequentially encoding their elements 131 for i := 0; i < val.Len(); i++ { 132 if err := e.encode(val.Index(i)); err != nil { 133 return err 134 } 135 } 136 return nil 137 case reflect.Struct: 138 for i := 0; i < val.NumField(); i++ { 139 if err := e.encode(val.Field(i)); err != nil { 140 return err 141 } 142 } 143 return nil 144 } 145 146 // Marshalling should never fail. If it panics, you're doing something wrong, 147 // like trying to encode a map or an unexported struct field. 148 panic("could not marshal type " + val.Type().String()) 149 } 150 151 // NewEncoder returns a new encoder that writes to w. 152 func NewEncoder(w io.Writer) *Encoder { 153 return &Encoder{w} 154 } 155 156 // Marshal returns the encoding of v. For encoding details, see the package 157 // docstring. 158 func Marshal(v interface{}) []byte { 159 b := new(bytes.Buffer) 160 NewEncoder(b).Encode(v) // no error possible when using a bytes.Buffer 161 return b.Bytes() 162 } 163 164 // MarshalAll encodes all of its inputs and returns their concatenation. 165 func MarshalAll(vs ...interface{}) []byte { 166 b := new(bytes.Buffer) 167 enc := NewEncoder(b) 168 // can't encode with EncodeAll (type information is lost) 169 for _, v := range vs { 170 enc.Encode(v) 171 } 172 return b.Bytes() 173 } 174 175 // WriteFile writes v to a file. The file will be created if it does not exist. 176 func WriteFile(filename string, v interface{}) error { 177 file, err := os.Create(filename) 178 if err != nil { 179 return err 180 } 181 defer file.Close() 182 err = NewEncoder(file).Encode(v) 183 if err != nil { 184 return errors.New("error while writing " + filename + ": " + err.Error()) 185 } 186 return nil 187 } 188 189 // A Decoder reads and decodes values from an input stream. 190 type Decoder struct { 191 r io.Reader 192 n int 193 } 194 195 // Read implements the io.Reader interface. It also keeps track of the total 196 // number of bytes decoded, and panics if that number exceeds a global 197 // maximum. 198 func (d *Decoder) Read(p []byte) (int, error) { 199 n, err := d.r.Read(p) 200 // enforce an absolute maximum size limit 201 if d.n += n; d.n > maxDecodeLen { 202 panic("encoded type exceeds size limit") 203 } 204 return n, err 205 } 206 207 // Decode reads the next encoded value from its input stream and stores it in 208 // v, which must be a pointer. The decoding rules are the inverse of those 209 // specified in the package docstring. 210 func (d *Decoder) Decode(v interface{}) (err error) { 211 // v must be a pointer 212 pval := reflect.ValueOf(v) 213 if pval.Kind() != reflect.Ptr || pval.IsNil() { 214 return errBadPointer 215 } 216 217 // catch decoding panics and convert them to errors 218 // note that this allows us to skip boundary checks during decoding 219 defer func() { 220 if r := recover(); r != nil { 221 err = fmt.Errorf("could not decode type %s: %v", pval.Elem().Type().String(), r) 222 } 223 }() 224 225 // reset the read count 226 d.n = 0 227 228 d.decode(pval.Elem()) 229 return 230 } 231 232 // DecodeAll decodes a variable number of arguments. 233 func (d *Decoder) DecodeAll(vs ...interface{}) error { 234 for _, v := range vs { 235 if err := d.Decode(v); err != nil { 236 return err 237 } 238 } 239 return nil 240 } 241 242 // readN reads n bytes and panics if the read fails. 243 func (d *Decoder) readN(n int) []byte { 244 b := make([]byte, n) 245 _, err := io.ReadFull(d, b) 246 if err != nil { 247 panic(err) 248 } 249 return b 250 } 251 252 // readPrefix reads a length-prefixed byte slice and panics if the read fails. 253 func (d *Decoder) readPrefix() []byte { 254 b, err := ReadPrefix(d, maxSliceLen) 255 if err != nil { 256 panic(err) 257 } 258 return b 259 } 260 261 // decode reads the next encoded value from its input stream and stores it in 262 // val. The decoding rules are the inverse of those specified in the package 263 // docstring. 264 func (d *Decoder) decode(val reflect.Value) { 265 // check for UnmarshalSia interface first 266 if val.CanAddr() && val.Addr().CanInterface() { 267 if u, ok := val.Addr().Interface().(SiaUnmarshaler); ok { 268 err := u.UnmarshalSia(d) 269 if err != nil { 270 panic(err) 271 } 272 return 273 } 274 } 275 276 switch val.Kind() { 277 case reflect.Ptr: 278 var valid bool 279 d.decode(reflect.ValueOf(&valid).Elem()) 280 // nil pointer, nothing to decode 281 if !valid { 282 return 283 } 284 // make sure we aren't decoding into nil 285 if val.IsNil() { 286 val.Set(reflect.New(val.Type().Elem())) 287 } 288 d.decode(val.Elem()) 289 case reflect.Bool: 290 b := d.readN(1) 291 if b[0] > 1 { 292 panic("boolean value was not 0 or 1") 293 } 294 val.SetBool(b[0] == 1) 295 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 296 val.SetInt(DecInt64(d.readN(8))) 297 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 298 val.SetUint(DecUint64(d.readN(8))) 299 case reflect.String: 300 val.SetString(string(d.readPrefix())) 301 case reflect.Slice: 302 // slices are variable length, but otherwise the same as arrays. 303 // just have to allocate them first, then we can fallthrough to the array logic. 304 sliceLen := DecUint64(d.readN(8)) 305 // sanity-check the sliceLen, otherwise you can crash a peer by making 306 // them allocate a massive slice 307 if sliceLen > 1<<31-1 || sliceLen*uint64(val.Type().Elem().Size()) > maxSliceLen { 308 panic("slice is too large") 309 } else if sliceLen == 0 { 310 return 311 } 312 val.Set(reflect.MakeSlice(val.Type(), int(sliceLen), int(sliceLen))) 313 fallthrough 314 case reflect.Array: 315 // special case for byte arrays (e.g. hashes) 316 if val.Type().Elem().Kind() == reflect.Uint8 { 317 // convert val to a slice and read into it directly 318 b := val.Slice(0, val.Len()) 319 _, err := io.ReadFull(d, b.Bytes()) 320 if err != nil { 321 panic(err) 322 } 323 return 324 } 325 // arrays are unmarshalled by sequentially unmarshalling their elements 326 for i := 0; i < val.Len(); i++ { 327 d.decode(val.Index(i)) 328 } 329 return 330 case reflect.Struct: 331 for i := 0; i < val.NumField(); i++ { 332 d.decode(val.Field(i)) 333 } 334 return 335 default: 336 panic("unknown type") 337 } 338 } 339 340 // NewDecoder returns a new decoder that reads from r. 341 func NewDecoder(r io.Reader) *Decoder { 342 return &Decoder{r, 0} 343 } 344 345 // Unmarshal decodes the encoded value b and stores it in v, which must be a 346 // pointer. The decoding rules are the inverse of those specified in the 347 // package docstring for marshaling. 348 func Unmarshal(b []byte, v interface{}) error { 349 r := bytes.NewReader(b) 350 return NewDecoder(r).Decode(v) 351 } 352 353 // UnmarshalAll decodes the encoded values in b and stores them in vs, which 354 // must be pointers. 355 func UnmarshalAll(b []byte, vs ...interface{}) error { 356 dec := NewDecoder(bytes.NewReader(b)) 357 // can't use DecodeAll (type information is lost) 358 for _, v := range vs { 359 if err := dec.Decode(v); err != nil { 360 return err 361 } 362 } 363 return nil 364 } 365 366 // ReadFile reads the contents of a file and decodes them into v. 367 func ReadFile(filename string, v interface{}) error { 368 file, err := os.Open(filename) 369 if err != nil { 370 return err 371 } 372 defer file.Close() 373 err = NewDecoder(file).Decode(v) 374 if err != nil { 375 return errors.New("error while reading " + filename + ": " + err.Error()) 376 } 377 return nil 378 } 379 380 // Marshal returns the encoding of v. For encoding details, see the package 381 // docstring. 382 func (m StdGenericMarshaler) Marshal(v interface{}) []byte { 383 return Marshal(v) 384 } 385 386 // Unmarshal decodes the encoded value b and stores it in v, which must be a 387 // pointer. The decoding rules are the inverse of those specified in the 388 // package docstring for marshaling. 389 func (m StdGenericMarshaler) Unmarshal(b []byte, v interface{}) error { 390 return Unmarshal(b, v) 391 }