// github.com/wzzhu/tensor@v0.9.24/genlib2/dense_io.go

package main

import (
	"fmt"
	"io"
	"text/template"
)

const writeNpyRaw = `
type binaryWriter struct {
	io.Writer
	err error
	seq int
}

func (w *binaryWriter) w(x interface{}) {
	if w.err != nil {
		return
	}

	w.err = binary.Write(w, binary.LittleEndian, x)
	w.seq++
}

func (w *binaryWriter) Err() error {
	if w.err == nil {
		return nil
	}
	return errors.Wrapf(w.err, "Sequence %d", w.seq)
}

type binaryReader struct {
	io.Reader
	err error
	seq int
}

func (r *binaryReader) Read(data interface{}) {
	if r.err != nil {
		return
	}
	r.err = binary.Read(r.Reader, binary.LittleEndian, data)
	r.seq++
}

func (r *binaryReader) Err() error {
	if r.err == nil {
		return nil
	}
	return errors.Wrapf(r.err, "Sequence %d", r.seq)
}

// WriteNpy writes the *Tensor as a numpy compatible serialized file.
//
// The format is very well documented here:
// http://docs.scipy.org/doc/numpy/neps/npy-format.html
//
// Gorgonia specifically uses Version 1.0, as 65535 bytes should be more than enough for the headers.
// The values are written in little endian order, because let's face it -
// 90% of the world's computers are running on x86+ processors.
//
// This method does not close the writer. Closing (if needed) is deferred to the caller.
// If the tensor is masked, invalid values are replaced by the default fill value.
func (t *Dense) WriteNpy(w io.Writer) (err error) {
	var npdt string
	if npdt, err = t.t.numpyDtype(); err != nil {
		return
	}

	var header string
	if t.Dims() == 1 {
		// when t is a 1D vector, numpy expects "(N,)" instead of "(N)" which t.Shape() returns.
		header = "{'descr': '<%v', 'fortran_order': False, 'shape': (%d,)}"
		header = fmt.Sprintf(header, npdt, t.Shape()[0])
	} else {
		header = "{'descr': '<%v', 'fortran_order': False, 'shape': %v}"
		header = fmt.Sprintf(header, npdt, t.Shape())
	}
	padding := 16 - ((10 + len(header)) % 16)
	if padding > 0 {
		header = header + strings.Repeat(" ", padding)
	}
	bw := binaryWriter{Writer: w}
	bw.Write([]byte("\x93NUMPY")) // stupid magic
	bw.w(byte(1))                 // major version
	bw.w(byte(0))                 // minor version
	bw.w(uint16(len(header)))     // 2 bytes (uint16) to denote header length
	if err = bw.Err(); err != nil {
		return err
	}
	bw.Write([]byte(header))

	bw.seq = 0
	if t.IsMasked() {
		fillval := t.FillValue()
		it := FlatMaskedIteratorFromDense(t)
		for i, err := it.Next(); err == nil; i, err = it.Next() {
			if t.mask[i] {
				bw.w(fillval)
			} else {
				bw.w(t.Get(i))
			}
		}
	} else {
		for i := 0; i < t.len(); i++ {
			bw.w(t.Get(i))
		}
	}

	return bw.Err()
}
`
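
// Illustrative sketch (not part of the generated code, file names are hypothetical):
// for a 2×3 float64 *Dense, the generated WriteNpy above emits the 6-byte magic
// "\x93NUMPY", the version bytes 1 and 0, a little-endian uint16 header length,
// and then a space-padded header such as
//
//	{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3)}
//
// followed by the 6 float64 values in little-endian order. Usage is simply:
//
//	f, err := os.Create("x.npy")
//	if err != nil { /* handle error */ }
//	defer f.Close()
//	err = t.WriteNpy(f)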

const writeCSVRaw = `// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV.
// If tensor is masked, invalid values are replaced by the default fill value.
func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) {
	// checks:
	if !t.IsMatrix() {
		// error
		err = errors.Errorf("Cannot write *Dense to CSV. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape())
		return
	}
	format := "%v"
	if len(formats) > 0 {
		format = formats[0]
	}

	cw := csv.NewWriter(w)
	it := IteratorFromDense(t)
	coord := it.Coord()

	// rows := t.Shape()[0]
	cols := t.Shape()[1]
	record := make([]string, 0, cols)
	var i, k, lastCol int
	isMasked := t.IsMasked()
	fillval := t.FillValue()
	fillstr := fmt.Sprintf(format, fillval)
	for i, err = it.Next(); err == nil; i, err = it.Next() {
		record = append(record, fmt.Sprintf(format, t.Get(i)))
		if isMasked {
			if t.mask[i] {
				record[k] = fillstr
			}
			k++
		}
		if lastCol == cols-1 {
			if err = cw.Write(record); err != nil {
				// TODO: wrap errors
				return
			}
			cw.Flush()
			record = record[:0]
		}

		// cleanup
		switch {
		case t.IsRowVec():
			// lastRow = coord[len(coord)-2]
			lastCol = coord[len(coord)-1]
		case t.IsColVec():
			// lastRow = coord[len(coord)-1]
			lastCol = coord[len(coord)-2]
		case t.IsVector():
			lastCol = coord[len(coord)-1]
		default:
			// lastRow = coord[len(coord)-2]
			lastCol = coord[len(coord)-1]
		}
	}
	return nil
}

`

const gobEncodeRaw = `// GobEncode implements gob.GobEncoder
func (t *Dense) GobEncode() (p []byte, err error) {
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)

	if err = encoder.Encode(t.Shape()); err != nil {
		return
	}

	if err = encoder.Encode(t.Strides()); err != nil {
		return
	}

	if err = encoder.Encode(t.AP.o); err != nil {
		return
	}

	if err = encoder.Encode(t.AP.Δ); err != nil {
		return
	}

	if err = encoder.Encode(t.mask); err != nil {
		return
	}

	data := t.Data()
	if err = encoder.Encode(&data); err != nil {
		return
	}

	return buf.Bytes(), err
}
`

const gobDecodeRaw = `// GobDecode implements gob.GobDecoder
func (t *Dense) GobDecode(p []byte) (err error) {
	buf := bytes.NewBuffer(p)
	decoder := gob.NewDecoder(buf)

	var shape Shape
	if err = decoder.Decode(&shape); err != nil {
		return
	}

	var strides []int
	if err = decoder.Decode(&strides); err != nil {
		return
	}

	var o DataOrder
	var tr Triangle
	if err = decoder.Decode(&o); err == nil {
		if err = decoder.Decode(&tr); err != nil {
			return
		}
	}

	t.AP.Init(shape, strides)
	t.AP.o = o
	t.AP.Δ = tr

	var mask []bool
	if err = decoder.Decode(&mask); err != nil {
		return
	}

	var data interface{}
	if err = decoder.Decode(&data); err != nil {
		return
	}

	t.fromSlice(data)
	t.addMask(mask)
	t.fix()
	if t.e == nil {
		t.e = StdEng{}
	}
	return t.sanity()
}
`
const npyDescRE = `var npyDescRE = regexp.MustCompile(` + "`" + `'descr':` + `\` + `s*'([^']*)'` + "`" + ")"
const rowOrderRE = `var rowOrderRE = regexp.MustCompile(` + "`" + `'fortran_order':\s*(False|True)` + "`)"
const shapeRE = `var shapeRE = regexp.MustCompile(` + "`" + `'shape':\s*\(([^\(]*)\)` + "`)"
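
// Illustrative note (not emitted into the generated file): the three consts above
// are printed verbatim into the generated dense_io.go, where the resulting regexps
// pick apart a numpy header such as
//
//	{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3)}
//
// yielding "<f8" from npyDescRE, "False" from rowOrderRE and "2, 3" from shapeRE.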

const readNpyRaw = `// ReadNpy reads NumPy formatted files into a *Dense
func (t *Dense) ReadNpy(r io.Reader) (err error) {
	br := binaryReader{Reader: r}
	var magic [6]byte
	if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" {
		return errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:]))
	}

	var version, minor byte
	if br.Read(&version); version != 1 {
		return errors.New("Only version 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)")
	}

	if br.Read(&minor); minor != 0 {
		return errors.New("Only version 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)")
	}

	var headerLen uint16
	br.Read(&headerLen)
	header := make([]byte, int(headerLen))
	br.Read(header)
	if err = br.Err(); err != nil {
		return
	}

	// extract stuff from header
	var match [][]byte
	if match = npyDescRE.FindSubmatch(header); match == nil {
		return errors.New("No dtype information in npy file")
	}

	// TODO: check for endianness. For now we assume everything is little endian
	if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil {
		return
	}

	if match = rowOrderRE.FindSubmatch(header); match == nil {
		return errors.New("No Row Order information found in the numpy file")
	}
	if string(match[1]) != "False" {
		return errors.New("Cannot yet read from Fortran Ordered Numpy files")
	}

	if match = shapeRE.FindSubmatch(header); match == nil {
		return errors.New("No shape information found in npy file")
	}
	sizesStr := strings.Split(string(match[1]), ",")

	var shape Shape
	for _, s := range sizesStr {
		s = strings.Trim(s, " ")
		if len(s) == 0 {
			break
		}
		var size int
		if size, err = strconv.Atoi(s); err != nil {
			return
		}
		shape = append(shape, size)
	}

	size := shape.TotalSize()
	if t.e == nil {
		t.e = StdEng{}
	}
	t.makeArray(size)

	switch t.t.Kind() {
	{{range .Kinds -}}
	case reflect.{{reflectKind .}}:
		data := t.{{sliceOf .}}
		for i := 0; i < size; i++ {
			br.Read(&data[i])
		}
	{{end -}}
	}
	if err = br.Err(); err != nil {
		return err
	}

	t.AP.zeroWithDims(len(shape))
	t.setShape(shape...)
	t.fix()
	return t.sanity()
}
`
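
// Illustrative expansion (assuming sliceOf expands to the usual typed accessor,
// e.g. Float64s()): for float64 the {{range .Kinds}} block in readNpyRaw above
// generates roughly
//
//	case reflect.Float64:
//		data := t.Float64s()
//		for i := 0; i < size; i++ {
//			br.Read(&data[i])
//		}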

const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice.
// If into is nil, then a backing slice will be created.
func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) {
	var err error
	switch to.Kind() {
	{{range .Kinds -}}
	{{if isNumber . -}}
	{{if isOrd . -}}
	case reflect.{{reflectKind .}}:
		retVal := make([]{{asType .}}, len(record))
		var backing []{{asType .}}
		if into == nil {
			backing = make([]{{asType .}}, 0, len(record))
		} else {
			backing = into.([]{{asType .}})
		}

		for i, v := range record {
			{{if eq .String "float64" -}}
			if retVal[i], err = strconv.ParseFloat(v, 64); err != nil {
				return nil, err
			}
			{{else if eq .String "float32" -}}
			var f float64
			if f, err = strconv.ParseFloat(v, 32); err != nil {
				return nil, err
			}
			retVal[i] = float32(f)
			{{else if hasPrefix .String "int" -}}
			var i64 int64
			if i64, err = strconv.ParseInt(v, 10, {{bitSizeOf .}}); err != nil {
				return nil, err
			}
			retVal[i] = {{asType .}}(i64)
			{{else if hasPrefix .String "uint" -}}
			var u uint64
			if u, err = strconv.ParseUint(v, 10, {{bitSizeOf .}}); err != nil {
				return nil, err
			}
			retVal[i] = {{asType .}}(u)
			{{end -}}
		}
		backing = append(backing, retVal...)
		return backing, nil
	{{end -}}
	{{end -}}
	{{end -}}
	case reflect.String:
		var backing []string
		if into == nil {
			backing = make([]string, 0, len(record))
		} else {
			backing = into.([]string)
		}
		backing = append(backing, record...)
		return backing, nil
	default:
		return nil, errors.Errorf(methodNYI, "convFromStrs", to)
	}
}

// ReadCSV reads a CSV into a *Dense. It will override the underlying data.
//
// BUG(chewxy): reading CSV doesn't handle CSVs with different columns per row yet.
func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) {
	fo := ParseFuncOpts(opts...)
	as := fo.As()
	if as.Type == nil {
		as = Float64
	}

	cr := csv.NewReader(r)

	var record []string
	var rows, cols int
	var backing interface{}
	for {
		record, err = cr.Read()
		if err == io.EOF {
			break
		} else if err != nil {
			return
		}
		if backing, err = convFromStrs(as, record, backing); err != nil {
			return
		}
		cols = len(record)
		rows++
	}
	t.fromSlice(backing)
	t.AP.zero()
	t.AP.SetShape(rows, cols)
	return nil
	return errors.Errorf("not yet handled")
}
`
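
// Illustrative sketch (not part of the generated code): the generated ReadCSV
// treats each CSV record as one row, so
//
//	data := strings.NewReader("1,2,3\n4,5,6") // hypothetical input
//	var t Dense
//	err := t.ReadCSV(data) // backing defaults to float64
//
// would leave t with shape (2, 3). Passing an As(...) FuncOpt selects a
// different Dtype for the parsed backing slice.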

var fbEncodeDecodeRaw = `// FBEncode encodes to a byte slice using flatbuffers.
//
// Only natively accessible data can be encoded
func (t *Dense) FBEncode() ([]byte, error) {
	builder := flatbuffers.NewBuilder(1024)

	fb.DenseStartShapeVector(builder, len(t.shape))
	for i := len(t.shape) - 1; i >= 0; i-- {
		builder.PrependInt32(int32(t.shape[i]))
	}
	shape := builder.EndVector(len(t.shape))

	fb.DenseStartStridesVector(builder, len(t.strides))
	for i := len(t.strides) - 1; i >= 0; i-- {
		builder.PrependInt32(int32(t.strides[i]))
	}
	strides := builder.EndVector(len(t.strides))

	var o uint32
	switch {
	case t.o.IsRowMajor() && t.o.IsContiguous():
		o = 0
	case t.o.IsRowMajor() && !t.o.IsContiguous():
		o = 1
	case t.o.IsColMajor() && t.o.IsContiguous():
		o = 2
	case t.o.IsColMajor() && !t.o.IsContiguous():
		o = 3
	}

	var triangle int32
	switch t.Δ {
	case NotTriangle:
		triangle = fb.TriangleNOT_TRIANGLE
	case Upper:
		triangle = fb.TriangleUPPER
	case Lower:
		triangle = fb.TriangleLOWER
	case Symmetric:
		triangle = fb.TriangleSYMMETRIC
	}

	dt := builder.CreateString(t.Dtype().String())
	data := t.byteSlice()

	fb.DenseStartDataVector(builder, len(data))
	for i := len(data) - 1; i >= 0; i-- {
		builder.PrependUint8(data[i])
	}
	databyte := builder.EndVector(len(data))

	fb.DenseStart(builder)
	fb.DenseAddShape(builder, shape)
	fb.DenseAddStrides(builder, strides)
	fb.DenseAddO(builder, o)
	fb.DenseAddT(builder, triangle)
	fb.DenseAddType(builder, dt)
	fb.DenseAddData(builder, databyte)
	serialized := fb.DenseEnd(builder)
	builder.Finish(serialized)

	return builder.FinishedBytes(), nil
}

// FBDecode decodes a byteslice from a flatbuffer table into a *Dense
func (t *Dense) FBDecode(buf []byte) error {
	serialized := fb.GetRootAsDense(buf, 0)

	o := serialized.O()
	switch o {
	case 0:
		t.o = 0
	case 1:
		t.o = MakeDataOrder(NonContiguous)
	case 2:
		t.o = MakeDataOrder(ColMajor)
	case 3:
		t.o = MakeDataOrder(ColMajor, NonContiguous)
	}

	tri := serialized.T()
	switch tri {
	case fb.TriangleNOT_TRIANGLE:
		t.Δ = NotTriangle
	case fb.TriangleUPPER:
		t.Δ = Upper
	case fb.TriangleLOWER:
		t.Δ = Lower
	case fb.TriangleSYMMETRIC:
		t.Δ = Symmetric
	}

	t.shape = Shape(BorrowInts(serialized.ShapeLength()))
	for i := 0; i < serialized.ShapeLength(); i++ {
		t.shape[i] = int(int32(serialized.Shape(i)))
	}

	t.strides = BorrowInts(serialized.StridesLength())
	for i := 0; i < serialized.StridesLength(); i++ {
		t.strides[i] = int(serialized.Strides(i))
	}
	typ := string(serialized.Type())
	for _, dt := range allTypes.set {
		if dt.String() == typ {
			t.t = dt
			break
		}
	}

	if t.e == nil {
		t.e = StdEng{}
	}
	t.makeArray(t.shape.TotalSize())

	// allocated data. Now time to actually copy over the data
	db := t.byteSlice()
	copy(db, serialized.DataBytes())
	t.fix()
	return t.sanity()
}
`
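
// Illustrative round trip (not part of the generated code): the generated
// FBEncode/FBDecode pair is symmetric, so, assuming a populated *Dense t,
//
//	buf, err := t.FBEncode()
//	// ... handle err ...
//	t2 := new(Dense)
//	err = t2.FBDecode(buf)
//
// leaves t2 with the same shape, strides, data order, triangle flag and data
// as t. PBEncode/PBDecode below follow the same pattern over protobuf.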

var pbEncodeDecodeRaw = `// PBEncode encodes the Dense into a protobuf byte slice.
func (t *Dense) PBEncode() ([]byte, error) {
	var toSerialize pb.Dense
	toSerialize.Shape = make([]int32, len(t.shape))
	for i, v := range t.shape {
		toSerialize.Shape[i] = int32(v)
	}
	toSerialize.Strides = make([]int32, len(t.strides))
	for i, v := range t.strides {
		toSerialize.Strides[i] = int32(v)
	}

	switch {
	case t.o.IsRowMajor() && t.o.IsContiguous():
		toSerialize.O = pb.RowMajorContiguous
	case t.o.IsRowMajor() && !t.o.IsContiguous():
		toSerialize.O = pb.RowMajorNonContiguous
	case t.o.IsColMajor() && t.o.IsContiguous():
		toSerialize.O = pb.ColMajorContiguous
	case t.o.IsColMajor() && !t.o.IsContiguous():
		toSerialize.O = pb.ColMajorNonContiguous
	}
	toSerialize.T = pb.Triangle(t.Δ)
	toSerialize.Type = t.t.String()
	data := t.byteSlice()
	toSerialize.Data = make([]byte, len(data))
	copy(toSerialize.Data, data)
	return toSerialize.Marshal()
}

// PBDecode unmarshals a protobuf byteslice into a *Dense.
func (t *Dense) PBDecode(buf []byte) error {
	var toSerialize pb.Dense
	if err := toSerialize.Unmarshal(buf); err != nil {
		return err
	}
	t.shape = make(Shape, len(toSerialize.Shape))
	for i, v := range toSerialize.Shape {
		t.shape[i] = int(v)
	}
	t.strides = make([]int, len(toSerialize.Strides))
	for i, v := range toSerialize.Strides {
		t.strides[i] = int(v)
	}

	switch toSerialize.O {
	case pb.RowMajorContiguous:
	case pb.RowMajorNonContiguous:
		t.o = MakeDataOrder(NonContiguous)
	case pb.ColMajorContiguous:
		t.o = MakeDataOrder(ColMajor)
	case pb.ColMajorNonContiguous:
		t.o = MakeDataOrder(ColMajor, NonContiguous)
	}
	t.Δ = Triangle(toSerialize.T)
	typ := string(toSerialize.Type)
	for _, dt := range allTypes.set {
		if dt.String() == typ {
			t.t = dt
			break
		}
	}

	if t.e == nil {
		t.e = StdEng{}
	}
	t.makeArray(t.shape.TotalSize())

	// allocated data. Now time to actually copy over the data
	db := t.byteSlice()
	copy(db, toSerialize.Data)
	return t.sanity()
}
`

var (
	readNpy   *template.Template
	gobEncode *template.Template
	gobDecode *template.Template
	readCSV   *template.Template
)

func init() {
	readNpy = template.Must(template.New("readNpy").Funcs(funcs).Parse(readNpyRaw))
	readCSV = template.Must(template.New("readCSV").Funcs(funcs).Parse(readCSVRaw))
	gobEncode = template.Must(template.New("gobEncode").Funcs(funcs).Parse(gobEncodeRaw))
	gobDecode = template.Must(template.New("gobDecode").Funcs(funcs).Parse(gobDecodeRaw))
}

func generateDenseIO(f io.Writer, generic Kinds) {
	mk := Kinds{Kinds: filter(generic.Kinds, isNumber)}

	fmt.Fprint(f, "/* GOB SERIALIZATION */\n\n")
	gobEncode.Execute(f, mk)
	gobDecode.Execute(f, mk)
	fmt.Fprint(f, "\n")

	fmt.Fprint(f, "/* NPY SERIALIZATION */\n\n")
	fmt.Fprintln(f, npyDescRE)
	fmt.Fprintln(f, rowOrderRE)
	fmt.Fprintln(f, shapeRE)
	fmt.Fprintln(f, writeNpyRaw)
	readNpy.Execute(f, mk)
	fmt.Fprint(f, "\n")

	fmt.Fprint(f, "/* CSV SERIALIZATION */\n\n")
	fmt.Fprintln(f, writeCSVRaw)
	readCSV.Execute(f, mk)
	fmt.Fprint(f, "\n")

	fmt.Fprint(f, "/* FB SERIALIZATION */\n\n")
	fmt.Fprintln(f, fbEncodeDecodeRaw)
	fmt.Fprint(f, "\n")

	fmt.Fprint(f, "/* PB SERIALIZATION */\n\n")
	fmt.Fprintln(f, pbEncodeDecodeRaw)
	fmt.Fprint(f, "\n")

}
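
// Illustrative note (an assumption about the surrounding generator, not code from
// this file): generateDenseIO is one of genlib2's per-file generators; a
// hypothetical driver would call it with the writer for the target file and the
// full Kinds list, e.g.
//
//	var buf bytes.Buffer
//	generateDenseIO(&buf, allKinds) // allKinds is hypothetical here
//	// gofmt the buffer and write it out as the tensor package's dense_io.go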