go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tables/xtensor.go (about) 1 package tables 2 3 import ( 4 "go-ml.dev/pkg/base/fu" 5 "golang.org/x/xerrors" 6 "reflect" 7 ) 8 9 type Xtensor struct{ T reflect.Type } 10 11 func (t *Xtensor) Type() reflect.Type { 12 return fu.TensorType 13 } 14 15 func (t Xtensor) Convert(value string, field *reflect.Value, _, _ int) (_ bool, err error) { 16 z, err := fu.DecodeTensor(value) 17 if err != nil { 18 return 19 } 20 *field = reflect.ValueOf(z) 21 return 22 } 23 24 func tensorOf(field *reflect.Value, tp reflect.Type, width int) (fu.Tensor, error) { 25 if field.IsValid() { 26 return (field.Interface()).(fu.Tensor), nil 27 } 28 var z fu.Tensor 29 switch tp { 30 case fu.Float64: 31 z = fu.MakeFloat64Tensor(1, 1, width, nil) 32 case fu.Float32: 33 z = fu.MakeFloat32Tensor(1, 1, width, nil) 34 case fu.Fixed8Type: 35 z = fu.MakeFixed8Tensor(1, 1, width, nil) 36 case fu.Int: 37 z = fu.MakeIntTensor(1, 1, width, nil) 38 case fu.Byte: 39 z = fu.MakeByteTensor(1, 1, width, nil) 40 default: 41 return z, xerrors.Errorf("unknown tensor value type " + tp.String()) 42 } 43 *field = reflect.ValueOf(z) 44 return z, nil 45 } 46 47 func (t Xtensor) ConvertElm(value string, field *reflect.Value, index, width int) (err error) { 48 z, err := tensorOf(field, t.T, width) 49 if err != nil { 50 return 51 } 52 return z.ConvertElem(value, index) 53 } 54 55 func (Xtensor) Format(x reflect.Value, na bool) string { 56 if na { 57 return "" 58 } 59 if x.Type() == fu.TensorType { 60 return x.String() 61 } 62 panic(xerrors.Errorf("`%v` is not an Xtensor value", x)) 63 }