go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/fu/tensor.go (about) 1 package fu 2 3 import ( 4 "reflect" 5 "strconv" 6 ) 7 8 type dimension struct{ Channels, Height, Width int } 9 10 func (d dimension) Volume() int { return d.Channels * d.Width * d.Height } 11 func (d dimension) Dimension() (c, h, w int) { return d.Channels, d.Height, d.Width } 12 13 type tensor32f struct { 14 dimension 15 values []float32 16 } 17 18 type tensor64f struct { 19 dimension 20 values []float64 21 } 22 23 type tensor8u struct { 24 dimension 25 values []byte 26 } 27 28 type tensori struct { 29 dimension 30 values []int 31 } 32 33 type tensor8f struct { 34 dimension 35 values []Fixed8 36 } 37 38 func (t tensor32f) ConvertElem(val string, index int) (err error) { 39 t.values[index], err = Fast32f(val) 40 return 41 } 42 43 func (t tensor64f) ConvertElem(val string, index int) (err error) { 44 t.values[index], err = strconv.ParseFloat(val, 64) 45 return 46 } 47 48 func (t tensori) ConvertElem(val string, index int) (err error) { 49 i, err := strconv.ParseInt(val, 10, 64) 50 if err != nil { 51 return 52 } 53 t.values[index] = int(i) 54 return 55 } 56 57 func (t tensor8f) ConvertElem(val string, index int) (err error) { 58 t.values[index], err = Fast8f(val) 59 return 60 } 61 62 func (t tensor8u) ConvertElem(val string, index int) (err error) { 63 i, err := strconv.ParseInt(val, 10, 8) 64 if err != nil { 65 return 66 } 67 t.values[index] = byte(i) 68 return 69 } 70 71 func (t tensori) Index(index int) interface{} { return t.values[index] } 72 func (t tensor8f) Index(index int) interface{} { return t.values[index] } 73 func (t tensor8u) Index(index int) interface{} { return t.values[index] } 74 func (t tensor32f) Index(index int) interface{} { return t.values[index] } 75 func (t tensor64f) Index(index int) interface{} { return t.values[index] } 76 77 func (t tensori) Values() interface{} { return t.values } 78 func (t tensor8f) Values() interface{} { return t.values } 79 func (t tensor8u) Values() interface{} { return t.values } 80 func (t tensor32f) Values() interface{} { return t.values } 81 func (t tensor64f) Values() interface{} { return t.values } 82 83 func (t tensori) Type() reflect.Type { return Int } 84 func (t tensor8f) Type() reflect.Type { return Fixed8Type } 85 func (t tensor8u) Type() reflect.Type { return Byte } 86 func (t tensor32f) Type() reflect.Type { return Float32 } 87 func (t tensor64f) Type() reflect.Type { return Float64 } 88 89 func (t tensori) Magic() byte { return 'i' } 90 func (t tensor8f) Magic() byte { return '8' } 91 func (t tensor8u) Magic() byte { return 'u' } 92 func (t tensor32f) Magic() byte { return 'f' } 93 func (t tensor64f) Magic() byte { return 'F' } 94 95 func (t tensori) HotOne() (j int) { 96 for i, v := range t.values { 97 if t.values[j] < v { 98 j = i 99 } 100 } 101 return 102 } 103 104 func (t tensor8f) HotOne() (j int) { 105 for i, v := range t.values { 106 if t.values[j].int8 < v.int8 { 107 j = i 108 } 109 } 110 return 111 } 112 113 func (t tensor8u) HotOne() (j int) { 114 for i, v := range t.values { 115 if t.values[j] < v { 116 j = i 117 } 118 } 119 return 120 } 121 122 func (t tensor32f) HotOne() (j int) { 123 for i, v := range t.values { 124 if t.values[j] < v { 125 j = i 126 } 127 } 128 return 129 } 130 131 func (t tensor64f) HotOne() (j int) { 132 for i, v := range t.values { 133 if t.values[j] < v { 134 j = i 135 } 136 } 137 return 138 } 139 140 func (t tensori) Extract(r []reflect.Value) { 141 for i, v := range t.values { 142 r[i] = reflect.ValueOf(v) 143 } 144 } 145 146 func (t tensori) Floats32(...bool) (r []float32) { 147 r = make([]float32, len(t.values)) 148 for i, v := range t.values { 149 r[i] = float32(v) 150 } 151 return 152 } 153 154 func (t tensor8f) Extract(r []reflect.Value) { 155 for i, v := range t.values { 156 r[i] = reflect.ValueOf(v) 157 } 158 } 159 160 func (t tensor8f) Floats32(...bool) (r []float32) { 161 r = make([]float32, len(t.values)) 162 for i, v := range t.values { 163 r[i] = v.Float32() 164 } 165 return 166 } 167 168 func (t tensor8u) Extract(r []reflect.Value) { 169 for i, v := range t.values { 170 r[i] = reflect.ValueOf(v) 171 } 172 } 173 174 func (t tensor8u) Floats32(...bool) (r []float32) { 175 r = make([]float32, len(t.values)) 176 for i, v := range t.values { 177 r[i] = float32(v) / 256 178 } 179 return 180 } 181 182 func (t tensor64f) Extract(r []reflect.Value) { 183 for i, v := range t.values { 184 r[i] = reflect.ValueOf(v) 185 } 186 } 187 188 func (t tensor64f) Floats32(...bool) (r []float32) { 189 r = make([]float32, len(t.values)) 190 for i, v := range t.values { 191 r[i] = float32(v) 192 } 193 return 194 } 195 196 func (t tensor32f) Extract(r []reflect.Value) { 197 for i, v := range t.values { 198 r[i] = reflect.ValueOf(v) 199 } 200 } 201 202 func (t tensor32f) Floats32(c ...bool) []float32 { 203 if Fnzb(c...) { 204 r := make([]float32, len(t.values)) 205 copy(r, t.values) 206 return r 207 } 208 return t.values 209 } 210 211 type tensor interface { 212 Dimension() (c, h, w int) 213 Volume() int 214 Type() reflect.Type 215 Magic() byte 216 Values() interface{} 217 Index(index int) interface{} 218 ConvertElem(val string, index int) error 219 HotOne() int 220 Floats32(copy ...bool) []float32 221 Extract([]reflect.Value) 222 } 223 224 type Tensor struct{ tensor } 225 226 // gets base64-encoded compressed stream as a string prefixed by \xE2\x9C\x97` (✗`) 227 func DecodeTensor(string) (t Tensor, err error) { 228 return 229 } 230 231 func (t Tensor) Width() int { 232 _, _, w := t.Dimension() 233 return w 234 } 235 236 func (t Tensor) Height() int { 237 _, h, _ := t.Dimension() 238 return h 239 } 240 241 func (t Tensor) Depth() int { 242 c, _, _ := t.Dimension() 243 return c 244 } 245 246 func (t Tensor) String() (str string) { 247 return t.Encode(false) 248 } 249 250 func (t Tensor) Encode(compress bool) (str string) { 251 //t.Magic() 252 //t.Dimension() 253 //t.Values() 254 //gzip => base64 255 return 256 } 257 258 func MakeFloat64Tensor(channels, height, width int, values []float64, docopy ...bool) Tensor { 259 v := values 260 if values != nil { 261 if len(docopy) > 0 && docopy[0] { 262 v := make([]float64, len(values)) 263 copy(v, values) 264 } 265 } else { 266 v = make([]float64, channels*height*width) 267 } 268 x := tensor64f{ 269 dimension: dimension{ 270 Channels: channels, 271 Height: height, 272 Width: width, 273 }, 274 values: v, 275 } 276 return Tensor{x} 277 } 278 279 func MakeFloat32Tensor(channels, height, width int, values []float32, docopy ...bool) Tensor { 280 v := values 281 if values != nil { 282 if len(docopy) > 0 && docopy[0] { 283 v := make([]float32, len(values)) 284 copy(v, values) 285 } 286 } else { 287 v = make([]float32, channels*height*width) 288 } 289 if width <= 0 { 290 width = len(values) / (channels * height) 291 } 292 x := tensor32f{ 293 dimension: dimension{ 294 Channels: channels, 295 Height: height, 296 Width: width, 297 }, 298 values: v, 299 } 300 return Tensor{x} 301 } 302 303 func MakeByteTensor(channels, height, width int, values []byte, docopy ...bool) Tensor { 304 v := values 305 if values != nil { 306 if len(docopy) > 0 && docopy[0] { 307 v := make([]byte, len(values)) 308 copy(v, values) 309 } 310 } else { 311 v = make([]byte, channels*height*width) 312 } 313 x := tensor8u{ 314 dimension: dimension{ 315 Channels: channels, 316 Height: height, 317 Width: width}, 318 values: v, 319 } 320 return Tensor{x} 321 } 322 323 func MakeFixed8Tensor(channels, height, width int, values []Fixed8, docopy ...bool) Tensor { 324 v := values 325 if values != nil { 326 if len(docopy) > 0 && docopy[0] { 327 v := make([]Fixed8, len(values)) 328 copy(v, values) 329 } 330 } else { 331 v = make([]Fixed8, channels*height*width) 332 } 333 x := tensor8f{ 334 dimension: dimension{ 335 Channels: channels, 336 Height: height, 337 Width: width}, 338 values: v, 339 } 340 return Tensor{x} 341 } 342 343 func MakeIntTensor(channels, height, width int, values []int, docopy ...bool) Tensor { 344 v := values 345 if values != nil { 346 if len(docopy) > 0 && docopy[0] { 347 v := make([]int, len(values)) 348 copy(v, values) 349 } 350 } else { 351 v = make([]int, channels*height*width) 352 } 353 x := tensori{ 354 dimension: dimension{ 355 Channels: channels, 356 Height: height, 357 Width: width}, 358 values: v, 359 } 360 return Tensor{x} 361 } 362 363 var TensorType = reflect.TypeOf(Tensor{})