go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tables/xtensor.go (about)

     1  package tables
     2  
import (
	"fmt"
	"reflect"

	"go-ml.dev/pkg/base/fu"
	"golang.org/x/xerrors"
)
     8  
     9  type Xtensor struct{ T reflect.Type }
    10  
    11  func (t *Xtensor) Type() reflect.Type {
    12  	return fu.TensorType
    13  }
    14  
    15  func (t Xtensor) Convert(value string, field *reflect.Value, _, _ int) (_ bool, err error) {
    16  	z, err := fu.DecodeTensor(value)
    17  	if err != nil {
    18  		return
    19  	}
    20  	*field = reflect.ValueOf(z)
    21  	return
    22  }
    23  
    24  func tensorOf(field *reflect.Value, tp reflect.Type, width int) (fu.Tensor, error) {
    25  	if field.IsValid() {
    26  		return (field.Interface()).(fu.Tensor), nil
    27  	}
    28  	var z fu.Tensor
    29  	switch tp {
    30  	case fu.Float64:
    31  		z = fu.MakeFloat64Tensor(1, 1, width, nil)
    32  	case fu.Float32:
    33  		z = fu.MakeFloat32Tensor(1, 1, width, nil)
    34  	case fu.Fixed8Type:
    35  		z = fu.MakeFixed8Tensor(1, 1, width, nil)
    36  	case fu.Int:
    37  		z = fu.MakeIntTensor(1, 1, width, nil)
    38  	case fu.Byte:
    39  		z = fu.MakeByteTensor(1, 1, width, nil)
    40  	default:
    41  		return z, xerrors.Errorf("unknown tensor value type " + tp.String())
    42  	}
    43  	*field = reflect.ValueOf(z)
    44  	return z, nil
    45  }
    46  
    47  func (t Xtensor) ConvertElm(value string, field *reflect.Value, index, width int) (err error) {
    48  	z, err := tensorOf(field, t.T, width)
    49  	if err != nil {
    50  		return
    51  	}
    52  	return z.ConvertElem(value, index)
    53  }
    54  
    55  func (Xtensor) Format(x reflect.Value, na bool) string {
    56  	if na {
    57  		return ""
    58  	}
    59  	if x.Type() == fu.TensorType {
    60  		return x.String()
    61  	}
    62  	panic(xerrors.Errorf("`%v` is not an Xtensor value", x))
    63  }