github.com/wzzhu/tensor@v0.9.24/dense_compat.go (about) 1 // Code generated by genlib2. DO NOT EDIT. 2 3 package tensor 4 5 import ( 6 "fmt" 7 "math" 8 "math/cmplx" 9 "reflect" 10 11 arrow "github.com/apache/arrow/go/arrow" 12 arrowArray "github.com/apache/arrow/go/arrow/array" 13 "github.com/apache/arrow/go/arrow/bitutil" 14 arrowTensor "github.com/apache/arrow/go/arrow/tensor" 15 "github.com/chewxy/math32" 16 "github.com/pkg/errors" 17 "gonum.org/v1/gonum/mat" 18 ) 19 20 func convFromFloat64s(to Dtype, data []float64) interface{} { 21 switch to { 22 case Int: 23 retVal := make([]int, len(data)) 24 for i, v := range data { 25 switch { 26 case math.IsNaN(v), math.IsInf(v, 0): 27 retVal[i] = 0 28 default: 29 retVal[i] = int(v) 30 } 31 } 32 return retVal 33 case Int8: 34 retVal := make([]int8, len(data)) 35 for i, v := range data { 36 switch { 37 case math.IsNaN(v), math.IsInf(v, 0): 38 retVal[i] = 0 39 default: 40 retVal[i] = int8(v) 41 } 42 } 43 return retVal 44 case Int16: 45 retVal := make([]int16, len(data)) 46 for i, v := range data { 47 switch { 48 case math.IsNaN(v), math.IsInf(v, 0): 49 retVal[i] = 0 50 default: 51 retVal[i] = int16(v) 52 } 53 } 54 return retVal 55 case Int32: 56 retVal := make([]int32, len(data)) 57 for i, v := range data { 58 switch { 59 case math.IsNaN(v), math.IsInf(v, 0): 60 retVal[i] = 0 61 default: 62 retVal[i] = int32(v) 63 } 64 } 65 return retVal 66 case Int64: 67 retVal := make([]int64, len(data)) 68 for i, v := range data { 69 switch { 70 case math.IsNaN(v), math.IsInf(v, 0): 71 retVal[i] = 0 72 default: 73 retVal[i] = int64(v) 74 } 75 } 76 return retVal 77 case Uint: 78 retVal := make([]uint, len(data)) 79 for i, v := range data { 80 switch { 81 case math.IsNaN(v), math.IsInf(v, 0): 82 retVal[i] = 0 83 default: 84 retVal[i] = uint(v) 85 } 86 } 87 return retVal 88 case Uint8: 89 retVal := make([]uint8, len(data)) 90 for i, v := range data { 91 switch { 92 case math.IsNaN(v), math.IsInf(v, 0): 93 retVal[i] = 0 94 default: 95 retVal[i] = uint8(v) 96 } 97 } 98 return retVal 99 case Uint16: 100 retVal := make([]uint16, len(data)) 101 for i, v := range data { 102 switch { 103 case math.IsNaN(v), math.IsInf(v, 0): 104 retVal[i] = 0 105 default: 106 retVal[i] = uint16(v) 107 } 108 } 109 return retVal 110 case Uint32: 111 retVal := make([]uint32, len(data)) 112 for i, v := range data { 113 switch { 114 case math.IsNaN(v), math.IsInf(v, 0): 115 retVal[i] = 0 116 default: 117 retVal[i] = uint32(v) 118 } 119 } 120 return retVal 121 case Uint64: 122 retVal := make([]uint64, len(data)) 123 for i, v := range data { 124 switch { 125 case math.IsNaN(v), math.IsInf(v, 0): 126 retVal[i] = 0 127 default: 128 retVal[i] = uint64(v) 129 } 130 } 131 return retVal 132 case Float32: 133 retVal := make([]float32, len(data)) 134 for i, v := range data { 135 switch { 136 case math.IsNaN(v): 137 retVal[i] = math32.NaN() 138 case math.IsInf(v, 1): 139 retVal[i] = math32.Inf(1) 140 case math.IsInf(v, -1): 141 retVal[i] = math32.Inf(-1) 142 default: 143 retVal[i] = float32(v) 144 } 145 } 146 return retVal 147 case Float64: 148 retVal := make([]float64, len(data)) 149 copy(retVal, data) 150 return retVal 151 case Complex64: 152 retVal := make([]complex64, len(data)) 153 for i, v := range data { 154 switch { 155 case math.IsNaN(v): 156 retVal[i] = complex64(cmplx.NaN()) 157 case math.IsInf(v, 0): 158 retVal[i] = complex64(cmplx.Inf()) 159 default: 160 retVal[i] = complex(float32(v), float32(0)) 161 } 162 } 163 return retVal 164 case Complex128: 165 retVal := make([]complex128, len(data)) 166 for i, v := range data { 167 switch { 168 case math.IsNaN(v): 169 retVal[i] = cmplx.NaN() 170 case math.IsInf(v, 0): 171 retVal[i] = cmplx.Inf() 172 default: 173 retVal[i] = complex(v, float64(0)) 174 } 175 } 176 return retVal 177 default: 178 panic("Unsupported Dtype") 179 } 180 } 181 182 func convToFloat64s(t *Dense) (retVal []float64) { 183 retVal = make([]float64, t.len()) 184 switch t.t { 185 case Int: 186 for i, v := range t.Ints() { 187 retVal[i] = float64(v) 188 } 189 return retVal 190 case Int8: 191 for i, v := range t.Int8s() { 192 retVal[i] = float64(v) 193 } 194 return retVal 195 case Int16: 196 for i, v := range t.Int16s() { 197 retVal[i] = float64(v) 198 } 199 return retVal 200 case Int32: 201 for i, v := range t.Int32s() { 202 retVal[i] = float64(v) 203 } 204 return retVal 205 case Int64: 206 for i, v := range t.Int64s() { 207 retVal[i] = float64(v) 208 } 209 return retVal 210 case Uint: 211 for i, v := range t.Uints() { 212 retVal[i] = float64(v) 213 } 214 return retVal 215 case Uint8: 216 for i, v := range t.Uint8s() { 217 retVal[i] = float64(v) 218 } 219 return retVal 220 case Uint16: 221 for i, v := range t.Uint16s() { 222 retVal[i] = float64(v) 223 } 224 return retVal 225 case Uint32: 226 for i, v := range t.Uint32s() { 227 retVal[i] = float64(v) 228 } 229 return retVal 230 case Uint64: 231 for i, v := range t.Uint64s() { 232 retVal[i] = float64(v) 233 } 234 return retVal 235 case Float32: 236 for i, v := range t.Float32s() { 237 switch { 238 case math32.IsNaN(v): 239 retVal[i] = math.NaN() 240 case math32.IsInf(v, 1): 241 retVal[i] = math.Inf(1) 242 case math32.IsInf(v, -1): 243 retVal[i] = math.Inf(-1) 244 default: 245 retVal[i] = float64(v) 246 } 247 } 248 return retVal 249 case Float64: 250 return t.Float64s() 251 return retVal 252 case Complex64: 253 for i, v := range t.Complex64s() { 254 switch { 255 case cmplx.IsNaN(complex128(v)): 256 retVal[i] = math.NaN() 257 case cmplx.IsInf(complex128(v)): 258 retVal[i] = math.Inf(1) 259 default: 260 retVal[i] = float64(real(v)) 261 } 262 } 263 return retVal 264 case Complex128: 265 for i, v := range t.Complex128s() { 266 switch { 267 case cmplx.IsNaN(v): 268 retVal[i] = math.NaN() 269 case cmplx.IsInf(v): 270 retVal[i] = math.Inf(1) 271 default: 272 retVal[i] = real(v) 273 } 274 } 275 return retVal 276 default: 277 panic(fmt.Sprintf("Cannot convert *Dense of %v to []float64", t.t)) 278 } 279 } 280 281 func convToFloat64(x interface{}) float64 { 282 switch xt := x.(type) { 283 case int: 284 return float64(xt) 285 case int8: 286 return float64(xt) 287 case int16: 288 return float64(xt) 289 case int32: 290 return float64(xt) 291 case int64: 292 return float64(xt) 293 case uint: 294 return float64(xt) 295 case uint8: 296 return float64(xt) 297 case uint16: 298 return float64(xt) 299 case uint32: 300 return float64(xt) 301 case uint64: 302 return float64(xt) 303 case float32: 304 return float64(xt) 305 case float64: 306 return float64(xt) 307 case complex64: 308 return float64(real(xt)) 309 case complex128: 310 return real(xt) 311 default: 312 panic("Cannot convert to float64") 313 } 314 } 315 316 // FromMat64 converts a *"gonum/matrix/mat64".Dense into a *tensorf64.Tensor. 317 func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { 318 r, c := m.Dims() 319 fo := ParseFuncOpts(opts...) 320 defer returnOpOpt(fo) 321 toCopy := fo.Safe() 322 as := fo.As() 323 if as.Type == nil { 324 as = Float64 325 } 326 327 switch as.Kind() { 328 case reflect.Int: 329 backing := convFromFloat64s(Int, m.RawMatrix().Data).([]int) 330 retVal := New(WithBacking(backing), WithShape(r, c)) 331 return retVal 332 case reflect.Int8: 333 backing := convFromFloat64s(Int8, m.RawMatrix().Data).([]int8) 334 retVal := New(WithBacking(backing), WithShape(r, c)) 335 return retVal 336 case reflect.Int16: 337 backing := convFromFloat64s(Int16, m.RawMatrix().Data).([]int16) 338 retVal := New(WithBacking(backing), WithShape(r, c)) 339 return retVal 340 case reflect.Int32: 341 backing := convFromFloat64s(Int32, m.RawMatrix().Data).([]int32) 342 retVal := New(WithBacking(backing), WithShape(r, c)) 343 return retVal 344 case reflect.Int64: 345 backing := convFromFloat64s(Int64, m.RawMatrix().Data).([]int64) 346 retVal := New(WithBacking(backing), WithShape(r, c)) 347 return retVal 348 case reflect.Uint: 349 backing := convFromFloat64s(Uint, m.RawMatrix().Data).([]uint) 350 retVal := New(WithBacking(backing), WithShape(r, c)) 351 return retVal 352 case reflect.Uint8: 353 backing := convFromFloat64s(Uint8, m.RawMatrix().Data).([]uint8) 354 retVal := New(WithBacking(backing), WithShape(r, c)) 355 return retVal 356 case reflect.Uint16: 357 backing := convFromFloat64s(Uint16, m.RawMatrix().Data).([]uint16) 358 retVal := New(WithBacking(backing), WithShape(r, c)) 359 return retVal 360 case reflect.Uint32: 361 backing := convFromFloat64s(Uint32, m.RawMatrix().Data).([]uint32) 362 retVal := New(WithBacking(backing), WithShape(r, c)) 363 return retVal 364 case reflect.Uint64: 365 backing := convFromFloat64s(Uint64, m.RawMatrix().Data).([]uint64) 366 retVal := New(WithBacking(backing), WithShape(r, c)) 367 return retVal 368 case reflect.Float32: 369 backing := convFromFloat64s(Float32, m.RawMatrix().Data).([]float32) 370 retVal := New(WithBacking(backing), WithShape(r, c)) 371 return retVal 372 case reflect.Float64: 373 var backing []float64 374 if toCopy { 375 backing = make([]float64, len(m.RawMatrix().Data)) 376 copy(backing, m.RawMatrix().Data) 377 } else { 378 backing = m.RawMatrix().Data 379 } 380 retVal := New(WithBacking(backing), WithShape(r, c)) 381 return retVal 382 case reflect.Complex64: 383 backing := convFromFloat64s(Complex64, m.RawMatrix().Data).([]complex64) 384 retVal := New(WithBacking(backing), WithShape(r, c)) 385 return retVal 386 case reflect.Complex128: 387 backing := convFromFloat64s(Complex128, m.RawMatrix().Data).([]complex128) 388 retVal := New(WithBacking(backing), WithShape(r, c)) 389 return retVal 390 default: 391 panic(fmt.Sprintf("Unsupported Dtype - cannot convert float64 to %v", as)) 392 } 393 panic("Unreachable") 394 } 395 396 // ToMat64 converts a *Dense to a *mat.Dense. All the values are converted into float64s. 397 // This function will only convert matrices. Anything *Dense with dimensions larger than 2 will cause an error. 398 func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { 399 // checks: 400 if !t.IsNativelyAccessible() { 401 return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") 402 } 403 404 if !t.IsMatrix() { 405 // error 406 return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) 407 } 408 409 fo := ParseFuncOpts(opts...) 410 defer returnOpOpt(fo) 411 toCopy := fo.Safe() 412 413 // fix dims 414 r := t.Shape()[0] 415 c := t.Shape()[1] 416 417 var data []float64 418 switch { 419 case t.t == Float64 && toCopy && !t.IsMaterializable(): 420 data = make([]float64, t.len()) 421 copy(data, t.Float64s()) 422 case !t.IsMaterializable(): 423 data = convToFloat64s(t) 424 default: 425 it := newFlatIterator(&t.AP) 426 var next int 427 for next, err = it.Next(); err == nil; next, err = it.Next() { 428 if err = handleNoOp(err); err != nil { 429 return 430 } 431 data = append(data, convToFloat64(t.Get(next))) 432 } 433 err = nil 434 435 } 436 437 retVal = mat.NewDense(r, c, data) 438 return 439 } 440 441 // FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. 442 func FromArrowArray(a arrowArray.Interface) *Dense { 443 a.Retain() 444 defer a.Release() 445 446 r := a.Len() 447 448 // TODO(poopoothegorilla): instead of creating bool ValidMask maybe 449 // bitmapBytes can be used from arrow API 450 mask := make([]bool, r) 451 for i := 0; i < r; i++ { 452 mask[i] = a.IsNull(i) 453 } 454 455 switch a.DataType() { 456 case arrow.BinaryTypes.String: 457 backing := make([]string, r) 458 for i := 0; i < r; i++ { 459 backing[i] = a.(*arrowArray.String).Value(i) 460 } 461 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 462 return retVal 463 case arrow.FixedWidthTypes.Boolean: 464 backing := make([]bool, r) 465 for i := 0; i < r; i++ { 466 backing[i] = a.(*arrowArray.Boolean).Value(i) 467 } 468 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 469 return retVal 470 case arrow.PrimitiveTypes.Int8: 471 backing := a.(*arrowArray.Int8).Int8Values() 472 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 473 return retVal 474 case arrow.PrimitiveTypes.Int16: 475 backing := a.(*arrowArray.Int16).Int16Values() 476 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 477 return retVal 478 case arrow.PrimitiveTypes.Int32: 479 backing := a.(*arrowArray.Int32).Int32Values() 480 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 481 return retVal 482 case arrow.PrimitiveTypes.Int64: 483 backing := a.(*arrowArray.Int64).Int64Values() 484 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 485 return retVal 486 case arrow.PrimitiveTypes.Uint8: 487 backing := a.(*arrowArray.Uint8).Uint8Values() 488 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 489 return retVal 490 case arrow.PrimitiveTypes.Uint16: 491 backing := a.(*arrowArray.Uint16).Uint16Values() 492 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 493 return retVal 494 case arrow.PrimitiveTypes.Uint32: 495 backing := a.(*arrowArray.Uint32).Uint32Values() 496 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 497 return retVal 498 case arrow.PrimitiveTypes.Uint64: 499 backing := a.(*arrowArray.Uint64).Uint64Values() 500 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 501 return retVal 502 case arrow.PrimitiveTypes.Float32: 503 backing := a.(*arrowArray.Float32).Float32Values() 504 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 505 return retVal 506 case arrow.PrimitiveTypes.Float64: 507 backing := a.(*arrowArray.Float64).Float64Values() 508 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 509 return retVal 510 default: 511 panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) 512 } 513 514 panic("Unreachable") 515 } 516 517 // FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. 518 func FromArrowTensor(a arrowTensor.Interface) *Dense { 519 a.Retain() 520 defer a.Release() 521 522 if !a.IsContiguous() { 523 panic("Non-contiguous data is Unsupported") 524 } 525 526 var shape []int 527 for _, val := range a.Shape() { 528 shape = append(shape, int(val)) 529 } 530 531 l := a.Len() 532 validMask := a.Data().Buffers()[0].Bytes() 533 dataOffset := a.Data().Offset() 534 mask := make([]bool, l) 535 for i := 0; i < l; i++ { 536 mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) 537 } 538 539 switch a.DataType() { 540 case arrow.PrimitiveTypes.Int8: 541 backing := a.(*arrowTensor.Int8).Int8Values() 542 if a.IsColMajor() { 543 return New(WithShape(shape...), AsFortran(backing, mask)) 544 } 545 546 return New(WithShape(shape...), WithBacking(backing, mask)) 547 case arrow.PrimitiveTypes.Int16: 548 backing := a.(*arrowTensor.Int16).Int16Values() 549 if a.IsColMajor() { 550 return New(WithShape(shape...), AsFortran(backing, mask)) 551 } 552 553 return New(WithShape(shape...), WithBacking(backing, mask)) 554 case arrow.PrimitiveTypes.Int32: 555 backing := a.(*arrowTensor.Int32).Int32Values() 556 if a.IsColMajor() { 557 return New(WithShape(shape...), AsFortran(backing, mask)) 558 } 559 560 return New(WithShape(shape...), WithBacking(backing, mask)) 561 case arrow.PrimitiveTypes.Int64: 562 backing := a.(*arrowTensor.Int64).Int64Values() 563 if a.IsColMajor() { 564 return New(WithShape(shape...), AsFortran(backing, mask)) 565 } 566 567 return New(WithShape(shape...), WithBacking(backing, mask)) 568 case arrow.PrimitiveTypes.Uint8: 569 backing := a.(*arrowTensor.Uint8).Uint8Values() 570 if a.IsColMajor() { 571 return New(WithShape(shape...), AsFortran(backing, mask)) 572 } 573 574 return New(WithShape(shape...), WithBacking(backing, mask)) 575 case arrow.PrimitiveTypes.Uint16: 576 backing := a.(*arrowTensor.Uint16).Uint16Values() 577 if a.IsColMajor() { 578 return New(WithShape(shape...), AsFortran(backing, mask)) 579 } 580 581 return New(WithShape(shape...), WithBacking(backing, mask)) 582 case arrow.PrimitiveTypes.Uint32: 583 backing := a.(*arrowTensor.Uint32).Uint32Values() 584 if a.IsColMajor() { 585 return New(WithShape(shape...), AsFortran(backing, mask)) 586 } 587 588 return New(WithShape(shape...), WithBacking(backing, mask)) 589 case arrow.PrimitiveTypes.Uint64: 590 backing := a.(*arrowTensor.Uint64).Uint64Values() 591 if a.IsColMajor() { 592 return New(WithShape(shape...), AsFortran(backing, mask)) 593 } 594 595 return New(WithShape(shape...), WithBacking(backing, mask)) 596 case arrow.PrimitiveTypes.Float32: 597 backing := a.(*arrowTensor.Float32).Float32Values() 598 if a.IsColMajor() { 599 return New(WithShape(shape...), AsFortran(backing, mask)) 600 } 601 602 return New(WithShape(shape...), WithBacking(backing, mask)) 603 case arrow.PrimitiveTypes.Float64: 604 backing := a.(*arrowTensor.Float64).Float64Values() 605 if a.IsColMajor() { 606 return New(WithShape(shape...), AsFortran(backing, mask)) 607 } 608 609 return New(WithShape(shape...), WithBacking(backing, mask)) 610 default: 611 panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) 612 } 613 614 panic("Unreachable") 615 }