go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/fu/tensor.go (about)

     1  package fu
     2  
     3  import (
     4  	"reflect"
     5  	"strconv"
     6  )
     7  
     // dimension is the (channels, height, width) shape embedded by every
// concrete tensor storage in this file.
type dimension struct {
	Channels int
	Height   int
	Width    int
}

// Volume returns the number of elements described by the shape,
// Channels*Height*Width.
func (d dimension) Volume() int {
	return d.Channels * d.Height * d.Width
}

// Dimension returns the shape as (channels, height, width).
func (d dimension) Dimension() (c, h, w int) {
	c, h, w = d.Channels, d.Height, d.Width
	return
}
    12  
    13  type tensor32f struct {
    14  	dimension
    15  	values []float32
    16  }
    17  
    18  type tensor64f struct {
    19  	dimension
    20  	values []float64
    21  }
    22  
    23  type tensor8u struct {
    24  	dimension
    25  	values []byte
    26  }
    27  
    28  type tensori struct {
    29  	dimension
    30  	values []int
    31  }
    32  
    33  type tensor8f struct {
    34  	dimension
    35  	values []Fixed8
    36  }
    37  
    38  func (t tensor32f) ConvertElem(val string, index int) (err error) {
    39  	t.values[index], err = Fast32f(val)
    40  	return
    41  }
    42  
    43  func (t tensor64f) ConvertElem(val string, index int) (err error) {
    44  	t.values[index], err = strconv.ParseFloat(val, 64)
    45  	return
    46  }
    47  
    48  func (t tensori) ConvertElem(val string, index int) (err error) {
    49  	i, err := strconv.ParseInt(val, 10, 64)
    50  	if err != nil {
    51  		return
    52  	}
    53  	t.values[index] = int(i)
    54  	return
    55  }
    56  
    57  func (t tensor8f) ConvertElem(val string, index int) (err error) {
    58  	t.values[index], err = Fast8f(val)
    59  	return
    60  }
    61  
    62  func (t tensor8u) ConvertElem(val string, index int) (err error) {
    63  	i, err := strconv.ParseInt(val, 10, 8)
    64  	if err != nil {
    65  		return
    66  	}
    67  	t.values[index] = byte(i)
    68  	return
    69  }
    70  
    71  func (t tensori) Index(index int) interface{}   { return t.values[index] }
    72  func (t tensor8f) Index(index int) interface{}  { return t.values[index] }
    73  func (t tensor8u) Index(index int) interface{}  { return t.values[index] }
    74  func (t tensor32f) Index(index int) interface{} { return t.values[index] }
    75  func (t tensor64f) Index(index int) interface{} { return t.values[index] }
    76  
    77  func (t tensori) Values() interface{}   { return t.values }
    78  func (t tensor8f) Values() interface{}  { return t.values }
    79  func (t tensor8u) Values() interface{}  { return t.values }
    80  func (t tensor32f) Values() interface{} { return t.values }
    81  func (t tensor64f) Values() interface{} { return t.values }
    82  
    83  func (t tensori) Type() reflect.Type   { return Int }
    84  func (t tensor8f) Type() reflect.Type  { return Fixed8Type }
    85  func (t tensor8u) Type() reflect.Type  { return Byte }
    86  func (t tensor32f) Type() reflect.Type { return Float32 }
    87  func (t tensor64f) Type() reflect.Type { return Float64 }
    88  
    89  func (t tensori) Magic() byte   { return 'i' }
    90  func (t tensor8f) Magic() byte  { return '8' }
    91  func (t tensor8u) Magic() byte  { return 'u' }
    92  func (t tensor32f) Magic() byte { return 'f' }
    93  func (t tensor64f) Magic() byte { return 'F' }
    94  
    95  func (t tensori) HotOne() (j int) {
    96  	for i, v := range t.values {
    97  		if t.values[j] < v {
    98  			j = i
    99  		}
   100  	}
   101  	return
   102  }
   103  
   104  func (t tensor8f) HotOne() (j int) {
   105  	for i, v := range t.values {
   106  		if t.values[j].int8 < v.int8 {
   107  			j = i
   108  		}
   109  	}
   110  	return
   111  }
   112  
   113  func (t tensor8u) HotOne() (j int) {
   114  	for i, v := range t.values {
   115  		if t.values[j] < v {
   116  			j = i
   117  		}
   118  	}
   119  	return
   120  }
   121  
   122  func (t tensor32f) HotOne() (j int) {
   123  	for i, v := range t.values {
   124  		if t.values[j] < v {
   125  			j = i
   126  		}
   127  	}
   128  	return
   129  }
   130  
   131  func (t tensor64f) HotOne() (j int) {
   132  	for i, v := range t.values {
   133  		if t.values[j] < v {
   134  			j = i
   135  		}
   136  	}
   137  	return
   138  }
   139  
   140  func (t tensori) Extract(r []reflect.Value) {
   141  	for i, v := range t.values {
   142  		r[i] = reflect.ValueOf(v)
   143  	}
   144  }
   145  
   146  func (t tensori) Floats32(...bool) (r []float32) {
   147  	r = make([]float32, len(t.values))
   148  	for i, v := range t.values {
   149  		r[i] = float32(v)
   150  	}
   151  	return
   152  }
   153  
   154  func (t tensor8f) Extract(r []reflect.Value) {
   155  	for i, v := range t.values {
   156  		r[i] = reflect.ValueOf(v)
   157  	}
   158  }
   159  
   160  func (t tensor8f) Floats32(...bool) (r []float32) {
   161  	r = make([]float32, len(t.values))
   162  	for i, v := range t.values {
   163  		r[i] = v.Float32()
   164  	}
   165  	return
   166  }
   167  
   168  func (t tensor8u) Extract(r []reflect.Value) {
   169  	for i, v := range t.values {
   170  		r[i] = reflect.ValueOf(v)
   171  	}
   172  }
   173  
   174  func (t tensor8u) Floats32(...bool) (r []float32) {
   175  	r = make([]float32, len(t.values))
   176  	for i, v := range t.values {
   177  		r[i] = float32(v) / 256
   178  	}
   179  	return
   180  }
   181  
   182  func (t tensor64f) Extract(r []reflect.Value) {
   183  	for i, v := range t.values {
   184  		r[i] = reflect.ValueOf(v)
   185  	}
   186  }
   187  
   188  func (t tensor64f) Floats32(...bool) (r []float32) {
   189  	r = make([]float32, len(t.values))
   190  	for i, v := range t.values {
   191  		r[i] = float32(v)
   192  	}
   193  	return
   194  }
   195  
   196  func (t tensor32f) Extract(r []reflect.Value) {
   197  	for i, v := range t.values {
   198  		r[i] = reflect.ValueOf(v)
   199  	}
   200  }
   201  
   202  func (t tensor32f) Floats32(c ...bool) []float32 {
   203  	if Fnzb(c...) {
   204  		r := make([]float32, len(t.values))
   205  		copy(r, t.values)
   206  		return r
   207  	}
   208  	return t.values
   209  }
   210  
   // tensor is the behaviour shared by the concrete storages tensor32f,
// tensor64f, tensor8u, tensori and tensor8f.
type tensor interface {
	// Shape.
	Dimension() (c, h, w int)
	Volume() int

	// Element type information.
	Type() reflect.Type
	Magic() byte

	// Element access.
	Values() interface{}
	Index(index int) interface{}
	ConvertElem(val string, index int) error
	Extract([]reflect.Value)

	// Whole-tensor helpers.
	HotOne() int
	Floats32(docopy ...bool) []float32
}

// Tensor wraps one concrete storage behind the tensor interface. Its zero
// value holds no storage, so calling a promoted tensor method on it panics.
type Tensor struct{ tensor }
   225  
   226  //	gets base64-encoded compressed stream as a string prefixed by \xE2\x9C\x97` (✗`)
   227  func DecodeTensor(string) (t Tensor, err error) {
   228  	return
   229  }
   230  
   231  func (t Tensor) Width() int {
   232  	_, _, w := t.Dimension()
   233  	return w
   234  }
   235  
   236  func (t Tensor) Height() int {
   237  	_, h, _ := t.Dimension()
   238  	return h
   239  }
   240  
   241  func (t Tensor) Depth() int {
   242  	c, _, _ := t.Dimension()
   243  	return c
   244  }
   245  
   246  func (t Tensor) String() (str string) {
   247  	return t.Encode(false)
   248  }
   249  
   250  func (t Tensor) Encode(compress bool) (str string) {
   251  	//t.Magic()
   252  	//t.Dimension()
   253  	//t.Values()
   254  	//gzip => base64
   255  	return
   256  }
   257  
   258  func MakeFloat64Tensor(channels, height, width int, values []float64, docopy ...bool) Tensor {
   259  	v := values
   260  	if values != nil {
   261  		if len(docopy) > 0 && docopy[0] {
   262  			v := make([]float64, len(values))
   263  			copy(v, values)
   264  		}
   265  	} else {
   266  		v = make([]float64, channels*height*width)
   267  	}
   268  	x := tensor64f{
   269  		dimension: dimension{
   270  			Channels: channels,
   271  			Height:   height,
   272  			Width:    width,
   273  		},
   274  		values: v,
   275  	}
   276  	return Tensor{x}
   277  }
   278  
   279  func MakeFloat32Tensor(channels, height, width int, values []float32, docopy ...bool) Tensor {
   280  	v := values
   281  	if values != nil {
   282  		if len(docopy) > 0 && docopy[0] {
   283  			v := make([]float32, len(values))
   284  			copy(v, values)
   285  		}
   286  	} else {
   287  		v = make([]float32, channels*height*width)
   288  	}
   289  	if width <= 0 {
   290  		width = len(values) / (channels * height)
   291  	}
   292  	x := tensor32f{
   293  		dimension: dimension{
   294  			Channels: channels,
   295  			Height:   height,
   296  			Width:    width,
   297  		},
   298  		values: v,
   299  	}
   300  	return Tensor{x}
   301  }
   302  
   303  func MakeByteTensor(channels, height, width int, values []byte, docopy ...bool) Tensor {
   304  	v := values
   305  	if values != nil {
   306  		if len(docopy) > 0 && docopy[0] {
   307  			v := make([]byte, len(values))
   308  			copy(v, values)
   309  		}
   310  	} else {
   311  		v = make([]byte, channels*height*width)
   312  	}
   313  	x := tensor8u{
   314  		dimension: dimension{
   315  			Channels: channels,
   316  			Height:   height,
   317  			Width:    width},
   318  		values: v,
   319  	}
   320  	return Tensor{x}
   321  }
   322  
   323  func MakeFixed8Tensor(channels, height, width int, values []Fixed8, docopy ...bool) Tensor {
   324  	v := values
   325  	if values != nil {
   326  		if len(docopy) > 0 && docopy[0] {
   327  			v := make([]Fixed8, len(values))
   328  			copy(v, values)
   329  		}
   330  	} else {
   331  		v = make([]Fixed8, channels*height*width)
   332  	}
   333  	x := tensor8f{
   334  		dimension: dimension{
   335  			Channels: channels,
   336  			Height:   height,
   337  			Width:    width},
   338  		values: v,
   339  	}
   340  	return Tensor{x}
   341  }
   342  
   343  func MakeIntTensor(channels, height, width int, values []int, docopy ...bool) Tensor {
   344  	v := values
   345  	if values != nil {
   346  		if len(docopy) > 0 && docopy[0] {
   347  			v := make([]int, len(values))
   348  			copy(v, values)
   349  		}
   350  	} else {
   351  		v = make([]int, channels*height*width)
   352  	}
   353  	x := tensori{
   354  		dimension: dimension{
   355  			Channels: channels,
   356  			Height:   height,
   357  			Width:    width},
   358  		values: v,
   359  	}
   360  	return Tensor{x}
   361  }
   362  
   363  var TensorType = reflect.TypeOf(Tensor{})