github.com/wzzhu/tensor@v0.9.24/genlib2/dense_compat.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"text/template"
     6  )
     7  
// importsArrowRaw is the import block emitted into the generated file for the
// Arrow packages referenced by the Arrow conversion code (FromArrowArray and
// FromArrowTensor). The Arrow array package is aliased to arrowArray to avoid
// naming collisions in the generated package.
const importsArrowRaw = `import (
	arrowArray "github.com/apache/arrow/go/arrow/array"
	"github.com/apache/arrow/go/arrow/bitutil"
	arrowTensor "github.com/apache/arrow/go/arrow/tensor"
	arrow "github.com/apache/arrow/go/arrow"
)
`
    15  
    16  const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} {
    17  	switch to {
    18  	{{range .Kinds -}}
    19  	{{if isNumber . -}}
    20  	case {{reflectKind .}}:
    21  		{{if eq .String "float64" -}}
    22  			retVal := make([]float64, len(data))
    23  			copy(retVal, data)
    24  			return retVal
    25  		{{else if eq .String "float32" -}}
    26  			retVal := make([]float32, len(data))
    27  			for i, v := range data {
    28  				switch {
    29  				case math.IsNaN(v):
    30  					retVal[i] = math32.NaN()
    31  				case math.IsInf(v, 1):
    32  					retVal[i] = math32.Inf(1)
    33  				case math.IsInf(v, -1):
    34  					retVal[i] = math32.Inf(-1)
    35  				default:
    36  					retVal[i] = float32(v)
    37  				}
    38  			}
    39  			return retVal
    40  		{{else if eq .String "complex64" -}}
    41  			retVal := make([]complex64, len(data))
    42  			for i, v := range data {
    43  				switch {
    44  				case math.IsNaN(v):
    45  					retVal[i] = complex64(cmplx.NaN())
    46  				case math.IsInf(v, 0):
    47  					retVal[i] = complex64(cmplx.Inf())
    48  				default:
    49  					retVal[i] = complex(float32(v), float32(0))
    50  				}
    51  			}
    52  			return retVal
    53  		{{else if eq .String "complex128" -}}
    54  			retVal := make([]complex128, len(data))
    55  			for i, v := range data {
    56  				switch {
    57  				case math.IsNaN(v):
    58  					retVal[i] = cmplx.NaN()
    59  				case math.IsInf(v, 0):
    60  					retVal[i] = cmplx.Inf()
    61  				default:
    62  					retVal[i] = complex(v, float64(0))
    63  				}
    64  			}
    65  			return retVal
    66  		{{else -}}
    67  			retVal := make([]{{asType .}}, len(data))
    68  			for i, v :=range data{
    69  				switch {
    70  				case math.IsNaN(v), math.IsInf(v, 0):
    71  					retVal[i] = 0
    72  				default:
    73  					retVal[i] = {{asType .}}(v)
    74  				}
    75  			}
    76  			return retVal
    77  		{{end -}}
    78  	{{end -}}
    79  	{{end -}}
    80  	default:
    81  		panic("Unsupported Dtype")
    82  	}
    83  }
    84  
    85  func convToFloat64s(t *Dense) (retVal []float64){
    86  	retVal = make([]float64, t.len())
    87  	switch t.t{
    88  	{{range .Kinds -}}
    89  	{{if isNumber . -}}
    90  	case {{reflectKind .}}:
    91  		{{if eq .String "float64" -}}
    92  			return t.{{sliceOf .}}
    93  		{{else if eq .String "float32" -}}
    94  			for i, v := range t.{{sliceOf .}} {
    95  				switch {
    96  				case math32.IsNaN(v):
    97  					retVal[i] = math.NaN()
    98  				case math32.IsInf(v, 1):
    99  					retVal[i] = math.Inf(1)
   100  				case math32.IsInf(v, -1):
   101  					retVal[i] = math.Inf(-1)
   102  				default:
   103  					retVal[i] = float64(v)
   104  				}
   105  			}
   106  		{{else if eq .String "complex64" -}}
   107  			for i, v := range t.{{sliceOf .}} {
   108  				switch {
   109  				case cmplx.IsNaN(complex128(v)):
   110  					retVal[i] = math.NaN()
   111  				case cmplx.IsInf(complex128(v)):
   112  					retVal[i] = math.Inf(1)
   113  				default:
   114  					retVal[i] = float64(real(v))
   115  				}
   116  			}
   117  		{{else if eq .String "complex128" -}}
   118  			for i, v := range t.{{sliceOf .}} {
   119  				switch {
   120  				case cmplx.IsNaN(v):
   121  					retVal[i] = math.NaN()
   122  				case cmplx.IsInf(v):
   123  					retVal[i] = math.Inf(1)
   124  				default:
   125  					retVal[i] = real(v)
   126  				}
   127  			}
   128  		{{else -}}
   129  			for i, v := range t.{{sliceOf .}} {
   130  				retVal[i]=  float64(v)
   131  			}
   132  		{{end -}}
   133  		return retVal
   134  	{{end -}}
   135  	{{end -}}
   136  	default:
   137  		panic(fmt.Sprintf("Cannot convert *Dense of %v to []float64", t.t))
   138  	}
   139  }
   140  
   141  func convToFloat64(x interface{}) float64 {
   142  	switch xt := x.(type) {
   143  	{{range .Kinds -}}
   144  	{{if isNumber . -}}
   145  	case {{asType .}}:
   146  		{{if eq .String "float64 -"}}
   147  			return xt
   148  		{{else if eq .String "complex64" -}}
   149  			return float64(real(xt))
   150  		{{else if eq .String "complex128" -}}
   151  			return real(xt)
   152  		{{else -}}
   153  			return float64(xt)
   154  		{{end -}}
   155  	{{end -}}
   156  	{{end -}}
   157  	default:
   158  		panic("Cannot convert to float64")
   159  	}
   160  }
   161  `
   162  
   163  const compatRaw = `// FromMat64 converts a *"gonum/matrix/mat64".Dense into a *tensorf64.Tensor.
   164  func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense {
   165  	r, c := m.Dims()
   166  	fo := ParseFuncOpts(opts...)
   167  	defer returnOpOpt(fo)
   168  	toCopy := fo.Safe()
   169  	as := fo.As()
   170  	if as.Type == nil {
   171  		as = Float64
   172  	}
   173  
   174  	switch as.Kind() {
   175  	{{range .Kinds -}}
   176  	{{if isNumber . -}}
   177  	case reflect.{{reflectKind .}}:
   178  		{{if eq .String "float64" -}}
   179  			var backing []float64
   180  			if toCopy {
   181  				backing = make([]float64, len(m.RawMatrix().Data))
   182  				copy(backing, m.RawMatrix().Data)
   183  			} else {
   184  				backing = m.RawMatrix().Data
   185  			}
   186  		{{else -}}
   187  			backing := convFromFloat64s({{asType . | title}}, m.RawMatrix().Data).([]{{asType .}})
   188  		{{end -}}
   189  		retVal := New(WithBacking(backing), WithShape(r, c))
   190  		return retVal
   191  	{{end -}}
   192  	{{end -}}
   193  	default:
   194  		panic(fmt.Sprintf("Unsupported Dtype - cannot convert float64 to %v", as))
   195  	}
   196  	panic("Unreachable")
   197  }
   198  
   199  
   200  // ToMat64 converts a *Dense to a *mat.Dense. All the values are converted into float64s.
   201  // This function will only convert matrices. Anything *Dense with dimensions larger than 2 will cause an error.
   202  func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) {
   203  	// checks:
   204  	if !t.IsNativelyAccessible() {
   205  		return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible")
   206  	}
   207  
   208  	if !t.IsMatrix() {
   209  		// error
   210  		return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape())
   211  	}
   212  
   213  	fo := ParseFuncOpts(opts...)
   214  	defer returnOpOpt(fo)
   215  	toCopy := fo.Safe()
   216  
   217  	// fix dims
   218  	r := t.Shape()[0]
   219  	c := t.Shape()[1]
   220  
   221  	var data []float64
   222  	switch {
   223  	case t.t == Float64 && toCopy  && !t.IsMaterializable():
   224  		data = make([]float64, t.len())
   225  		copy(data, t.Float64s())
   226  	case !t.IsMaterializable():	
   227  		data = convToFloat64s(t)
   228  	default:
   229  		it := newFlatIterator(&t.AP)
   230  		var next int
   231  		for next, err = it.Next(); err == nil; next, err = it.Next() {
   232  			if err = handleNoOp(err); err != nil {
   233  				return
   234  			}
   235  			data = append(data, convToFloat64(t.Get(next)))
   236  		}
   237  		err = nil
   238  		
   239  	}
   240  
   241  	retVal = mat.NewDense(r, c, data)
   242  	return
   243  }
   244  
   245  
   246  `
   247  
// ArrowData is the template data for the Arrow conversion templates
// (FromArrowArray and FromArrowTensor). Each entry is an Arrow type name that
// the templates use both as a field of the matching arrow.*Types group
// (e.g. arrow.PrimitiveTypes.<name>) and as the concrete array/tensor type
// name (e.g. *arrowArray.<name>).
type ArrowData struct {
	BinaryTypes     []string // names under arrow.BinaryTypes
	FixedWidthTypes []string // names under arrow.FixedWidthTypes
	PrimitiveTypes  []string // names under arrow.PrimitiveTypes
}
   253  
   254  const compatArrowArrayRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType.
   255  func FromArrowArray(a arrowArray.Interface) *Dense {
   256  	a.Retain()
   257  	defer a.Release()
   258  
   259  	r := a.Len()
   260  
   261  	// TODO(poopoothegorilla): instead of creating bool ValidMask maybe
   262  	// bitmapBytes can be used from arrow API
   263  	mask := make([]bool, r)
   264  	for i := 0; i < r; i++ {
   265  		mask[i] = a.IsNull(i)
   266  	}
   267  
   268  	switch a.DataType() {
   269  	{{range .BinaryTypes -}}
   270  	case arrow.BinaryTypes.{{.}}:
   271  		{{if eq . "String" -}}
   272  			backing := make([]string, r)
   273  			for i := 0; i < r; i++ {
   274  				backing[i] = a.(*arrowArray.{{.}}).Value(i)
   275  			}
   276  		{{else -}}
   277  			backing := a.(*arrowArray.{{.}}).{{.}}Values()
   278  		{{end -}}
   279  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   280  		return retVal
   281  	{{end -}}
   282  	{{range .FixedWidthTypes -}}
   283  	case arrow.FixedWidthTypes.{{.}}:
   284  		{{if eq . "Boolean" -}}
   285  			backing := make([]bool, r)
   286  			for i := 0; i < r; i++ {
   287  				backing[i] = a.(*arrowArray.{{.}}).Value(i)
   288  			}
   289  		{{else -}}
   290  			backing := a.(*arrowArray.{{.}}).{{.}}Values()
   291  		{{end -}}
   292  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   293  		return retVal
   294  	{{end -}}
   295  	{{range .PrimitiveTypes -}}
   296  	case arrow.PrimitiveTypes.{{.}}:
   297  		backing := a.(*arrowArray.{{.}}).{{.}}Values()
   298  		retVal := New(WithBacking(backing, mask), WithShape(r, 1))
   299  		return retVal
   300  	{{end -}}
   301  	default:
   302  		panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType()))
   303  	}
   304  
   305  	panic("Unreachable")
   306  }
   307  `
   308  
   309  const compatArrowTensorRaw = `// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType.
   310  func FromArrowTensor(a arrowTensor.Interface) *Dense {
   311  	a.Retain()
   312  	defer a.Release()
   313  
   314  	if !a.IsContiguous() {
   315  		panic("Non-contiguous data is Unsupported")
   316  	}
   317  
   318  	var shape []int
   319  	for _, val := range a.Shape() {
   320  		shape = append(shape, int(val))
   321  	}
   322  
   323  	l := a.Len()
   324  	validMask := a.Data().Buffers()[0].Bytes()
   325  	dataOffset := a.Data().Offset()
   326  	mask := make([]bool, l)
   327  	for i := 0; i < l; i++ {
   328  		mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i)
   329  	}
   330  
   331  	switch a.DataType() {
   332  	{{range .PrimitiveTypes -}}
   333  	case arrow.PrimitiveTypes.{{.}}:
   334  		backing := a.(*arrowTensor.{{.}}).{{.}}Values()
   335  		if a.IsColMajor() {
   336  			return New(WithShape(shape...), AsFortran(backing, mask))
   337  		}
   338  
   339  		return New(WithShape(shape...), WithBacking(backing, mask))
   340  	{{end -}}
   341  	default:
   342  		panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType()))
   343  	}
   344  
   345  	panic("Unreachable")
   346  }
   347  `
   348  
   349  var (
   350  	importsArrow       *template.Template
   351  	conversions        *template.Template
   352  	compats            *template.Template
   353  	compatsArrowArray  *template.Template
   354  	compatsArrowTensor *template.Template
   355  )
   356  
   357  func init() {
   358  	importsArrow = template.Must(template.New("imports_arrow").Funcs(funcs).Parse(importsArrowRaw))
   359  	conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw))
   360  	compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw))
   361  	compatsArrowArray = template.Must(template.New("compat_arrow_array").Funcs(funcs).Parse(compatArrowArrayRaw))
   362  	compatsArrowTensor = template.Must(template.New("compat_arrow_tensor").Funcs(funcs).Parse(compatArrowTensorRaw))
   363  }
   364  
   365  func generateDenseCompat(f io.Writer, generic Kinds) {
   366  	// NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming
   367  	// collisions
   368  	importsArrow.Execute(f, generic)
   369  	conversions.Execute(f, generic)
   370  	compats.Execute(f, generic)
   371  	arrowData := ArrowData{
   372  		BinaryTypes:     arrowBinaryTypes,
   373  		FixedWidthTypes: arrowFixedWidthTypes,
   374  		PrimitiveTypes:  arrowPrimitiveTypes,
   375  	}
   376  	compatsArrowArray.Execute(f, arrowData)
   377  	compatsArrowTensor.Execute(f, arrowData)
   378  }