github.com/wzzhu/tensor@v0.9.24/dense_compat.go (about)

     1  // Code generated by genlib2. DO NOT EDIT.
     2  
     3  package tensor
     4  
     5  import (
     6  	"fmt"
     7  	"math"
     8  	"math/cmplx"
     9  	"reflect"
    10  
    11  	arrow "github.com/apache/arrow/go/arrow"
    12  	arrowArray "github.com/apache/arrow/go/arrow/array"
    13  	"github.com/apache/arrow/go/arrow/bitutil"
    14  	arrowTensor "github.com/apache/arrow/go/arrow/tensor"
    15  	"github.com/chewxy/math32"
    16  	"github.com/pkg/errors"
    17  	"gonum.org/v1/gonum/mat"
    18  )
    19  
    20  func convFromFloat64s(to Dtype, data []float64) interface{} {
    21  	switch to {
    22  	case Int:
    23  		retVal := make([]int, len(data))
    24  		for i, v := range data {
    25  			switch {
    26  			case math.IsNaN(v), math.IsInf(v, 0):
    27  				retVal[i] = 0
    28  			default:
    29  				retVal[i] = int(v)
    30  			}
    31  		}
    32  		return retVal
    33  	case Int8:
    34  		retVal := make([]int8, len(data))
    35  		for i, v := range data {
    36  			switch {
    37  			case math.IsNaN(v), math.IsInf(v, 0):
    38  				retVal[i] = 0
    39  			default:
    40  				retVal[i] = int8(v)
    41  			}
    42  		}
    43  		return retVal
    44  	case Int16:
    45  		retVal := make([]int16, len(data))
    46  		for i, v := range data {
    47  			switch {
    48  			case math.IsNaN(v), math.IsInf(v, 0):
    49  				retVal[i] = 0
    50  			default:
    51  				retVal[i] = int16(v)
    52  			}
    53  		}
    54  		return retVal
    55  	case Int32:
    56  		retVal := make([]int32, len(data))
    57  		for i, v := range data {
    58  			switch {
    59  			case math.IsNaN(v), math.IsInf(v, 0):
    60  				retVal[i] = 0
    61  			default:
    62  				retVal[i] = int32(v)
    63  			}
    64  		}
    65  		return retVal
    66  	case Int64:
    67  		retVal := make([]int64, len(data))
    68  		for i, v := range data {
    69  			switch {
    70  			case math.IsNaN(v), math.IsInf(v, 0):
    71  				retVal[i] = 0
    72  			default:
    73  				retVal[i] = int64(v)
    74  			}
    75  		}
    76  		return retVal
    77  	case Uint:
    78  		retVal := make([]uint, len(data))
    79  		for i, v := range data {
    80  			switch {
    81  			case math.IsNaN(v), math.IsInf(v, 0):
    82  				retVal[i] = 0
    83  			default:
    84  				retVal[i] = uint(v)
    85  			}
    86  		}
    87  		return retVal
    88  	case Uint8:
    89  		retVal := make([]uint8, len(data))
    90  		for i, v := range data {
    91  			switch {
    92  			case math.IsNaN(v), math.IsInf(v, 0):
    93  				retVal[i] = 0
    94  			default:
    95  				retVal[i] = uint8(v)
    96  			}
    97  		}
    98  		return retVal
    99  	case Uint16:
   100  		retVal := make([]uint16, len(data))
   101  		for i, v := range data {
   102  			switch {
   103  			case math.IsNaN(v), math.IsInf(v, 0):
   104  				retVal[i] = 0
   105  			default:
   106  				retVal[i] = uint16(v)
   107  			}
   108  		}
   109  		return retVal
   110  	case Uint32:
   111  		retVal := make([]uint32, len(data))
   112  		for i, v := range data {
   113  			switch {
   114  			case math.IsNaN(v), math.IsInf(v, 0):
   115  				retVal[i] = 0
   116  			default:
   117  				retVal[i] = uint32(v)
   118  			}
   119  		}
   120  		return retVal
   121  	case Uint64:
   122  		retVal := make([]uint64, len(data))
   123  		for i, v := range data {
   124  			switch {
   125  			case math.IsNaN(v), math.IsInf(v, 0):
   126  				retVal[i] = 0
   127  			default:
   128  				retVal[i] = uint64(v)
   129  			}
   130  		}
   131  		return retVal
   132  	case Float32:
   133  		retVal := make([]float32, len(data))
   134  		for i, v := range data {
   135  			switch {
   136  			case math.IsNaN(v):
   137  				retVal[i] = math32.NaN()
   138  			case math.IsInf(v, 1):
   139  				retVal[i] = math32.Inf(1)
   140  			case math.IsInf(v, -1):
   141  				retVal[i] = math32.Inf(-1)
   142  			default:
   143  				retVal[i] = float32(v)
   144  			}
   145  		}
   146  		return retVal
   147  	case Float64:
   148  		retVal := make([]float64, len(data))
   149  		copy(retVal, data)
   150  		return retVal
   151  	case Complex64:
   152  		retVal := make([]complex64, len(data))
   153  		for i, v := range data {
   154  			switch {
   155  			case math.IsNaN(v):
   156  				retVal[i] = complex64(cmplx.NaN())
   157  			case math.IsInf(v, 0):
   158  				retVal[i] = complex64(cmplx.Inf())
   159  			default:
   160  				retVal[i] = complex(float32(v), float32(0))
   161  			}
   162  		}
   163  		return retVal
   164  	case Complex128:
   165  		retVal := make([]complex128, len(data))
   166  		for i, v := range data {
   167  			switch {
   168  			case math.IsNaN(v):
   169  				retVal[i] = cmplx.NaN()
   170  			case math.IsInf(v, 0):
   171  				retVal[i] = cmplx.Inf()
   172  			default:
   173  				retVal[i] = complex(v, float64(0))
   174  			}
   175  		}
   176  		return retVal
   177  	default:
   178  		panic("Unsupported Dtype")
   179  	}
   180  }
   181  
   182  func convToFloat64s(t *Dense) (retVal []float64) {
   183  	retVal = make([]float64, t.len())
   184  	switch t.t {
   185  	case Int:
   186  		for i, v := range t.Ints() {
   187  			retVal[i] = float64(v)
   188  		}
   189  		return retVal
   190  	case Int8:
   191  		for i, v := range t.Int8s() {
   192  			retVal[i] = float64(v)
   193  		}
   194  		return retVal
   195  	case Int16:
   196  		for i, v := range t.Int16s() {
   197  			retVal[i] = float64(v)
   198  		}
   199  		return retVal
   200  	case Int32:
   201  		for i, v := range t.Int32s() {
   202  			retVal[i] = float64(v)
   203  		}
   204  		return retVal
   205  	case Int64:
   206  		for i, v := range t.Int64s() {
   207  			retVal[i] = float64(v)
   208  		}
   209  		return retVal
   210  	case Uint:
   211  		for i, v := range t.Uints() {
   212  			retVal[i] = float64(v)
   213  		}
   214  		return retVal
   215  	case Uint8:
   216  		for i, v := range t.Uint8s() {
   217  			retVal[i] = float64(v)
   218  		}
   219  		return retVal
   220  	case Uint16:
   221  		for i, v := range t.Uint16s() {
   222  			retVal[i] = float64(v)
   223  		}
   224  		return retVal
   225  	case Uint32:
   226  		for i, v := range t.Uint32s() {
   227  			retVal[i] = float64(v)
   228  		}
   229  		return retVal
   230  	case Uint64:
   231  		for i, v := range t.Uint64s() {
   232  			retVal[i] = float64(v)
   233  		}
   234  		return retVal
   235  	case Float32:
   236  		for i, v := range t.Float32s() {
   237  			switch {
   238  			case math32.IsNaN(v):
   239  				retVal[i] = math.NaN()
   240  			case math32.IsInf(v, 1):
   241  				retVal[i] = math.Inf(1)
   242  			case math32.IsInf(v, -1):
   243  				retVal[i] = math.Inf(-1)
   244  			default:
   245  				retVal[i] = float64(v)
   246  			}
   247  		}
   248  		return retVal
   249  	case Float64:
   250  		return t.Float64s()
   251  		return retVal
   252  	case Complex64:
   253  		for i, v := range t.Complex64s() {
   254  			switch {
   255  			case cmplx.IsNaN(complex128(v)):
   256  				retVal[i] = math.NaN()
   257  			case cmplx.IsInf(complex128(v)):
   258  				retVal[i] = math.Inf(1)
   259  			default:
   260  				retVal[i] = float64(real(v))
   261  			}
   262  		}
   263  		return retVal
   264  	case Complex128:
   265  		for i, v := range t.Complex128s() {
   266  			switch {
   267  			case cmplx.IsNaN(v):
   268  				retVal[i] = math.NaN()
   269  			case cmplx.IsInf(v):
   270  				retVal[i] = math.Inf(1)
   271  			default:
   272  				retVal[i] = real(v)
   273  			}
   274  		}
   275  		return retVal
   276  	default:
   277  		panic(fmt.Sprintf("Cannot convert *Dense of %v to []float64", t.t))
   278  	}
   279  }
   280  
   281  func convToFloat64(x interface{}) float64 {
   282  	switch xt := x.(type) {
   283  	case int:
   284  		return float64(xt)
   285  	case int8:
   286  		return float64(xt)
   287  	case int16:
   288  		return float64(xt)
   289  	case int32:
   290  		return float64(xt)
   291  	case int64:
   292  		return float64(xt)
   293  	case uint:
   294  		return float64(xt)
   295  	case uint8:
   296  		return float64(xt)
   297  	case uint16:
   298  		return float64(xt)
   299  	case uint32:
   300  		return float64(xt)
   301  	case uint64:
   302  		return float64(xt)
   303  	case float32:
   304  		return float64(xt)
   305  	case float64:
   306  		return float64(xt)
   307  	case complex64:
   308  		return float64(real(xt))
   309  	case complex128:
   310  		return real(xt)
   311  	default:
   312  		panic("Cannot convert to float64")
   313  	}
   314  }
   315  
   316  // FromMat64 converts a *"gonum/matrix/mat64".Dense into a *tensorf64.Tensor.
   317  func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense {
   318  	r, c := m.Dims()
   319  	fo := ParseFuncOpts(opts...)
   320  	defer returnOpOpt(fo)
   321  	toCopy := fo.Safe()
   322  	as := fo.As()
   323  	if as.Type == nil {
   324  		as = Float64
   325  	}
   326  
   327  	switch as.Kind() {
   328  	case reflect.Int:
   329  		backing := convFromFloat64s(Int, m.RawMatrix().Data).([]int)
   330  		retVal := New(WithBacking(backing), WithShape(r, c))
   331  		return retVal
   332  	case reflect.Int8:
   333  		backing := convFromFloat64s(Int8, m.RawMatrix().Data).([]int8)
   334  		retVal := New(WithBacking(backing), WithShape(r, c))
   335  		return retVal
   336  	case reflect.Int16:
   337  		backing := convFromFloat64s(Int16, m.RawMatrix().Data).([]int16)
   338  		retVal := New(WithBacking(backing), WithShape(r, c))
   339  		return retVal
   340  	case reflect.Int32:
   341  		backing := convFromFloat64s(Int32, m.RawMatrix().Data).([]int32)
   342  		retVal := New(WithBacking(backing), WithShape(r, c))
   343  		return retVal
   344  	case reflect.Int64:
   345  		backing := convFromFloat64s(Int64, m.RawMatrix().Data).([]int64)
   346  		retVal := New(WithBacking(backing), WithShape(r, c))
   347  		return retVal
   348  	case reflect.Uint:
   349  		backing := convFromFloat64s(Uint, m.RawMatrix().Data).([]uint)
   350  		retVal := New(WithBacking(backing), WithShape(r, c))
   351  		return retVal
   352  	case reflect.Uint8:
   353  		backing := convFromFloat64s(Uint8, m.RawMatrix().Data).([]uint8)
   354  		retVal := New(WithBacking(backing), WithShape(r, c))
   355  		return retVal
   356  	case reflect.Uint16:
   357  		backing := convFromFloat64s(Uint16, m.RawMatrix().Data).([]uint16)
   358  		retVal := New(WithBacking(backing), WithShape(r, c))
   359  		return retVal
   360  	case reflect.Uint32:
   361  		backing := convFromFloat64s(Uint32, m.RawMatrix().Data).([]uint32)
   362  		retVal := New(WithBacking(backing), WithShape(r, c))
   363  		return retVal
   364  	case reflect.Uint64:
   365  		backing := convFromFloat64s(Uint64, m.RawMatrix().Data).([]uint64)
   366  		retVal := New(WithBacking(backing), WithShape(r, c))
   367  		return retVal
   368  	case reflect.Float32:
   369  		backing := convFromFloat64s(Float32, m.RawMatrix().Data).([]float32)
   370  		retVal := New(WithBacking(backing), WithShape(r, c))
   371  		return retVal
   372  	case reflect.Float64:
   373  		var backing []float64
   374  		if toCopy {
   375  			backing = make([]float64, len(m.RawMatrix().Data))
   376  			copy(backing, m.RawMatrix().Data)
   377  		} else {
   378  			backing = m.RawMatrix().Data
   379  		}
   380  		retVal := New(WithBacking(backing), WithShape(r, c))
   381  		return retVal
   382  	case reflect.Complex64:
   383  		backing := convFromFloat64s(Complex64, m.RawMatrix().Data).([]complex64)
   384  		retVal := New(WithBacking(backing), WithShape(r, c))
   385  		return retVal
   386  	case reflect.Complex128:
   387  		backing := convFromFloat64s(Complex128, m.RawMatrix().Data).([]complex128)
   388  		retVal := New(WithBacking(backing), WithShape(r, c))
   389  		return retVal
   390  	default:
   391  		panic(fmt.Sprintf("Unsupported Dtype - cannot convert float64 to %v", as))
   392  	}
   393  	panic("Unreachable")
   394  }
   395  
   396  // ToMat64 converts a *Dense to a *mat.Dense. All the values are converted into float64s.
   397  // This function will only convert matrices. Anything *Dense with dimensions larger than 2 will cause an error.
   398  func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) {
   399  	// checks:
   400  	if !t.IsNativelyAccessible() {
   401  		return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible")
   402  	}
   403  
   404  	if !t.IsMatrix() {
   405  		// error
   406  		return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape())
   407  	}
   408  
   409  	fo := ParseFuncOpts(opts...)
   410  	defer returnOpOpt(fo)
   411  	toCopy := fo.Safe()
   412  
   413  	// fix dims
   414  	r := t.Shape()[0]
   415  	c := t.Shape()[1]
   416  
   417  	var data []float64
   418  	switch {
   419  	case t.t == Float64 && toCopy && !t.IsMaterializable():
   420  		data = make([]float64, t.len())
   421  		copy(data, t.Float64s())
   422  	case !t.IsMaterializable():
   423  		data = convToFloat64s(t)
   424  	default:
   425  		it := newFlatIterator(&t.AP)
   426  		var next int
   427  		for next, err = it.Next(); err == nil; next, err = it.Next() {
   428  			if err = handleNoOp(err); err != nil {
   429  				return
   430  			}
   431  			data = append(data, convToFloat64(t.Get(next)))
   432  		}
   433  		err = nil
   434  
   435  	}
   436  
   437  	retVal = mat.NewDense(r, c, data)
   438  	return
   439  }
   440  
   441  // FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType.
   442  func FromArrowArray(a arrowArray.Interface) *Dense {
   443  	a.Retain()
   444  	defer a.Release()
   445  
   446  	r := a.Len()
   447  
   448  	// TODO(poopoothegorilla): instead of creating bool ValidMask maybe
   449  	// bitmapBytes can be used from arrow API
   450  	mask := make([]bool, r)
   451  	for i := 0; i < r; i++ {
   452  		mask[i] = a.IsNull(i)
   453  	}
   454  
   455  	switch a.DataType() {
   456  	case arrow.BinaryTypes.String:
   457  		backing := make([]string, r)
   458  		for i := 0; i < r; i++ {
   459  			backing[i] = a.(*arrowArray.String).Value(i)
   460  		}
   461  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   462  		return retVal
   463  	case arrow.FixedWidthTypes.Boolean:
   464  		backing := make([]bool, r)
   465  		for i := 0; i < r; i++ {
   466  			backing[i] = a.(*arrowArray.Boolean).Value(i)
   467  		}
   468  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   469  		return retVal
   470  	case arrow.PrimitiveTypes.Int8:
   471  		backing := a.(*arrowArray.Int8).Int8Values()
   472  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   473  		return retVal
   474  	case arrow.PrimitiveTypes.Int16:
   475  		backing := a.(*arrowArray.Int16).Int16Values()
   476  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   477  		return retVal
   478  	case arrow.PrimitiveTypes.Int32:
   479  		backing := a.(*arrowArray.Int32).Int32Values()
   480  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   481  		return retVal
   482  	case arrow.PrimitiveTypes.Int64:
   483  		backing := a.(*arrowArray.Int64).Int64Values()
   484  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   485  		return retVal
   486  	case arrow.PrimitiveTypes.Uint8:
   487  		backing := a.(*arrowArray.Uint8).Uint8Values()
   488  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   489  		return retVal
   490  	case arrow.PrimitiveTypes.Uint16:
   491  		backing := a.(*arrowArray.Uint16).Uint16Values()
   492  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   493  		return retVal
   494  	case arrow.PrimitiveTypes.Uint32:
   495  		backing := a.(*arrowArray.Uint32).Uint32Values()
   496  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   497  		return retVal
   498  	case arrow.PrimitiveTypes.Uint64:
   499  		backing := a.(*arrowArray.Uint64).Uint64Values()
   500  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   501  		return retVal
   502  	case arrow.PrimitiveTypes.Float32:
   503  		backing := a.(*arrowArray.Float32).Float32Values()
   504  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   505  		return retVal
   506  	case arrow.PrimitiveTypes.Float64:
   507  		backing := a.(*arrowArray.Float64).Float64Values()
   508  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   509  		return retVal
   510  	default:
   511  		panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType()))
   512  	}
   513  
   514  	panic("Unreachable")
   515  }
   516  
   517  // FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType.
   518  func FromArrowTensor(a arrowTensor.Interface) *Dense {
   519  	a.Retain()
   520  	defer a.Release()
   521  
   522  	if !a.IsContiguous() {
   523  		panic("Non-contiguous data is Unsupported")
   524  	}
   525  
   526  	var shape []int
   527  	for _, val := range a.Shape() {
   528  		shape = append(shape, int(val))
   529  	}
   530  
   531  	l := a.Len()
   532  	validMask := a.Data().Buffers()[0].Bytes()
   533  	dataOffset := a.Data().Offset()
   534  	mask := make([]bool, l)
   535  	for i := 0; i < l; i++ {
   536  		mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i)
   537  	}
   538  
   539  	switch a.DataType() {
   540  	case arrow.PrimitiveTypes.Int8:
   541  		backing := a.(*arrowTensor.Int8).Int8Values()
   542  		if a.IsColMajor() {
   543  			return New(WithShape(shape...), AsFortran(backing, mask))
   544  		}
   545  
   546  		return New(WithShape(shape...), WithBacking(backing, mask))
   547  	case arrow.PrimitiveTypes.Int16:
   548  		backing := a.(*arrowTensor.Int16).Int16Values()
   549  		if a.IsColMajor() {
   550  			return New(WithShape(shape...), AsFortran(backing, mask))
   551  		}
   552  
   553  		return New(WithShape(shape...), WithBacking(backing, mask))
   554  	case arrow.PrimitiveTypes.Int32:
   555  		backing := a.(*arrowTensor.Int32).Int32Values()
   556  		if a.IsColMajor() {
   557  			return New(WithShape(shape...), AsFortran(backing, mask))
   558  		}
   559  
   560  		return New(WithShape(shape...), WithBacking(backing, mask))
   561  	case arrow.PrimitiveTypes.Int64:
   562  		backing := a.(*arrowTensor.Int64).Int64Values()
   563  		if a.IsColMajor() {
   564  			return New(WithShape(shape...), AsFortran(backing, mask))
   565  		}
   566  
   567  		return New(WithShape(shape...), WithBacking(backing, mask))
   568  	case arrow.PrimitiveTypes.Uint8:
   569  		backing := a.(*arrowTensor.Uint8).Uint8Values()
   570  		if a.IsColMajor() {
   571  			return New(WithShape(shape...), AsFortran(backing, mask))
   572  		}
   573  
   574  		return New(WithShape(shape...), WithBacking(backing, mask))
   575  	case arrow.PrimitiveTypes.Uint16:
   576  		backing := a.(*arrowTensor.Uint16).Uint16Values()
   577  		if a.IsColMajor() {
   578  			return New(WithShape(shape...), AsFortran(backing, mask))
   579  		}
   580  
   581  		return New(WithShape(shape...), WithBacking(backing, mask))
   582  	case arrow.PrimitiveTypes.Uint32:
   583  		backing := a.(*arrowTensor.Uint32).Uint32Values()
   584  		if a.IsColMajor() {
   585  			return New(WithShape(shape...), AsFortran(backing, mask))
   586  		}
   587  
   588  		return New(WithShape(shape...), WithBacking(backing, mask))
   589  	case arrow.PrimitiveTypes.Uint64:
   590  		backing := a.(*arrowTensor.Uint64).Uint64Values()
   591  		if a.IsColMajor() {
   592  			return New(WithShape(shape...), AsFortran(backing, mask))
   593  		}
   594  
   595  		return New(WithShape(shape...), WithBacking(backing, mask))
   596  	case arrow.PrimitiveTypes.Float32:
   597  		backing := a.(*arrowTensor.Float32).Float32Values()
   598  		if a.IsColMajor() {
   599  			return New(WithShape(shape...), AsFortran(backing, mask))
   600  		}
   601  
   602  		return New(WithShape(shape...), WithBacking(backing, mask))
   603  	case arrow.PrimitiveTypes.Float64:
   604  		backing := a.(*arrowTensor.Float64).Float64Values()
   605  		if a.IsColMajor() {
   606  			return New(WithShape(shape...), AsFortran(backing, mask))
   607  		}
   608  
   609  		return New(WithShape(shape...), WithBacking(backing, mask))
   610  	default:
   611  		panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType()))
   612  	}
   613  
   614  	panic("Unreachable")
   615  }