github.com/wzzhu/tensor@v0.9.24/genlib2/dense_compat.go (about) 1 package main 2 3 import ( 4 "io" 5 "text/template" 6 ) 7 8 const importsArrowRaw = `import ( 9 arrowArray "github.com/apache/arrow/go/arrow/array" 10 "github.com/apache/arrow/go/arrow/bitutil" 11 arrowTensor "github.com/apache/arrow/go/arrow/tensor" 12 arrow "github.com/apache/arrow/go/arrow" 13 ) 14 ` 15 16 const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} { 17 switch to { 18 {{range .Kinds -}} 19 {{if isNumber . -}} 20 case {{reflectKind .}}: 21 {{if eq .String "float64" -}} 22 retVal := make([]float64, len(data)) 23 copy(retVal, data) 24 return retVal 25 {{else if eq .String "float32" -}} 26 retVal := make([]float32, len(data)) 27 for i, v := range data { 28 switch { 29 case math.IsNaN(v): 30 retVal[i] = math32.NaN() 31 case math.IsInf(v, 1): 32 retVal[i] = math32.Inf(1) 33 case math.IsInf(v, -1): 34 retVal[i] = math32.Inf(-1) 35 default: 36 retVal[i] = float32(v) 37 } 38 } 39 return retVal 40 {{else if eq .String "complex64" -}} 41 retVal := make([]complex64, len(data)) 42 for i, v := range data { 43 switch { 44 case math.IsNaN(v): 45 retVal[i] = complex64(cmplx.NaN()) 46 case math.IsInf(v, 0): 47 retVal[i] = complex64(cmplx.Inf()) 48 default: 49 retVal[i] = complex(float32(v), float32(0)) 50 } 51 } 52 return retVal 53 {{else if eq .String "complex128" -}} 54 retVal := make([]complex128, len(data)) 55 for i, v := range data { 56 switch { 57 case math.IsNaN(v): 58 retVal[i] = cmplx.NaN() 59 case math.IsInf(v, 0): 60 retVal[i] = cmplx.Inf() 61 default: 62 retVal[i] = complex(v, float64(0)) 63 } 64 } 65 return retVal 66 {{else -}} 67 retVal := make([]{{asType .}}, len(data)) 68 for i, v :=range data{ 69 switch { 70 case math.IsNaN(v), math.IsInf(v, 0): 71 retVal[i] = 0 72 default: 73 retVal[i] = {{asType .}}(v) 74 } 75 } 76 return retVal 77 {{end -}} 78 {{end -}} 79 {{end -}} 80 default: 81 panic("Unsupported Dtype") 82 } 83 } 84 85 func convToFloat64s(t *Dense) (retVal []float64){ 86 retVal = make([]float64, t.len()) 87 switch t.t{ 88 {{range .Kinds -}} 89 {{if isNumber . -}} 90 case {{reflectKind .}}: 91 {{if eq .String "float64" -}} 92 return t.{{sliceOf .}} 93 {{else if eq .String "float32" -}} 94 for i, v := range t.{{sliceOf .}} { 95 switch { 96 case math32.IsNaN(v): 97 retVal[i] = math.NaN() 98 case math32.IsInf(v, 1): 99 retVal[i] = math.Inf(1) 100 case math32.IsInf(v, -1): 101 retVal[i] = math.Inf(-1) 102 default: 103 retVal[i] = float64(v) 104 } 105 } 106 {{else if eq .String "complex64" -}} 107 for i, v := range t.{{sliceOf .}} { 108 switch { 109 case cmplx.IsNaN(complex128(v)): 110 retVal[i] = math.NaN() 111 case cmplx.IsInf(complex128(v)): 112 retVal[i] = math.Inf(1) 113 default: 114 retVal[i] = float64(real(v)) 115 } 116 } 117 {{else if eq .String "complex128" -}} 118 for i, v := range t.{{sliceOf .}} { 119 switch { 120 case cmplx.IsNaN(v): 121 retVal[i] = math.NaN() 122 case cmplx.IsInf(v): 123 retVal[i] = math.Inf(1) 124 default: 125 retVal[i] = real(v) 126 } 127 } 128 {{else -}} 129 for i, v := range t.{{sliceOf .}} { 130 retVal[i]= float64(v) 131 } 132 {{end -}} 133 return retVal 134 {{end -}} 135 {{end -}} 136 default: 137 panic(fmt.Sprintf("Cannot convert *Dense of %v to []float64", t.t)) 138 } 139 } 140 141 func convToFloat64(x interface{}) float64 { 142 switch xt := x.(type) { 143 {{range .Kinds -}} 144 {{if isNumber . -}} 145 case {{asType .}}: 146 {{if eq .String "float64 -"}} 147 return xt 148 {{else if eq .String "complex64" -}} 149 return float64(real(xt)) 150 {{else if eq .String "complex128" -}} 151 return real(xt) 152 {{else -}} 153 return float64(xt) 154 {{end -}} 155 {{end -}} 156 {{end -}} 157 default: 158 panic("Cannot convert to float64") 159 } 160 } 161 ` 162 163 const compatRaw = `// FromMat64 converts a *"gonum/matrix/mat64".Dense into a *tensorf64.Tensor. 164 func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { 165 r, c := m.Dims() 166 fo := ParseFuncOpts(opts...) 167 defer returnOpOpt(fo) 168 toCopy := fo.Safe() 169 as := fo.As() 170 if as.Type == nil { 171 as = Float64 172 } 173 174 switch as.Kind() { 175 {{range .Kinds -}} 176 {{if isNumber . -}} 177 case reflect.{{reflectKind .}}: 178 {{if eq .String "float64" -}} 179 var backing []float64 180 if toCopy { 181 backing = make([]float64, len(m.RawMatrix().Data)) 182 copy(backing, m.RawMatrix().Data) 183 } else { 184 backing = m.RawMatrix().Data 185 } 186 {{else -}} 187 backing := convFromFloat64s({{asType . | title}}, m.RawMatrix().Data).([]{{asType .}}) 188 {{end -}} 189 retVal := New(WithBacking(backing), WithShape(r, c)) 190 return retVal 191 {{end -}} 192 {{end -}} 193 default: 194 panic(fmt.Sprintf("Unsupported Dtype - cannot convert float64 to %v", as)) 195 } 196 panic("Unreachable") 197 } 198 199 200 // ToMat64 converts a *Dense to a *mat.Dense. All the values are converted into float64s. 201 // This function will only convert matrices. Anything *Dense with dimensions larger than 2 will cause an error. 202 func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { 203 // checks: 204 if !t.IsNativelyAccessible() { 205 return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") 206 } 207 208 if !t.IsMatrix() { 209 // error 210 return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) 211 } 212 213 fo := ParseFuncOpts(opts...) 214 defer returnOpOpt(fo) 215 toCopy := fo.Safe() 216 217 // fix dims 218 r := t.Shape()[0] 219 c := t.Shape()[1] 220 221 var data []float64 222 switch { 223 case t.t == Float64 && toCopy && !t.IsMaterializable(): 224 data = make([]float64, t.len()) 225 copy(data, t.Float64s()) 226 case !t.IsMaterializable(): 227 data = convToFloat64s(t) 228 default: 229 it := newFlatIterator(&t.AP) 230 var next int 231 for next, err = it.Next(); err == nil; next, err = it.Next() { 232 if err = handleNoOp(err); err != nil { 233 return 234 } 235 data = append(data, convToFloat64(t.Get(next))) 236 } 237 err = nil 238 239 } 240 241 retVal = mat.NewDense(r, c, data) 242 return 243 } 244 245 246 ` 247 248 type ArrowData struct { 249 BinaryTypes []string 250 FixedWidthTypes []string 251 PrimitiveTypes []string 252 } 253 254 const compatArrowArrayRaw = `// FromArrowArray converts an "arrow/array".Interface into a Tensor of matching DataType. 255 func FromArrowArray(a arrowArray.Interface) *Dense { 256 a.Retain() 257 defer a.Release() 258 259 r := a.Len() 260 261 // TODO(poopoothegorilla): instead of creating bool ValidMask maybe 262 // bitmapBytes can be used from arrow API 263 mask := make([]bool, r) 264 for i := 0; i < r; i++ { 265 mask[i] = a.IsNull(i) 266 } 267 268 switch a.DataType() { 269 {{range .BinaryTypes -}} 270 case arrow.BinaryTypes.{{.}}: 271 {{if eq . "String" -}} 272 backing := make([]string, r) 273 for i := 0; i < r; i++ { 274 backing[i] = a.(*arrowArray.{{.}}).Value(i) 275 } 276 {{else -}} 277 backing := a.(*arrowArray.{{.}}).{{.}}Values() 278 {{end -}} 279 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 280 return retVal 281 {{end -}} 282 {{range .FixedWidthTypes -}} 283 case arrow.FixedWidthTypes.{{.}}: 284 {{if eq . "Boolean" -}} 285 backing := make([]bool, r) 286 for i := 0; i < r; i++ { 287 backing[i] = a.(*arrowArray.{{.}}).Value(i) 288 } 289 {{else -}} 290 backing := a.(*arrowArray.{{.}}).{{.}}Values() 291 {{end -}} 292 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 293 return retVal 294 {{end -}} 295 {{range .PrimitiveTypes -}} 296 case arrow.PrimitiveTypes.{{.}}: 297 backing := a.(*arrowArray.{{.}}).{{.}}Values() 298 retVal := New(WithBacking(backing, mask), WithShape(r, 1)) 299 return retVal 300 {{end -}} 301 default: 302 panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) 303 } 304 305 panic("Unreachable") 306 } 307 ` 308 309 const compatArrowTensorRaw = `// FromArrowTensor converts an "arrow/tensor".Interface into a Tensor of matching DataType. 310 func FromArrowTensor(a arrowTensor.Interface) *Dense { 311 a.Retain() 312 defer a.Release() 313 314 if !a.IsContiguous() { 315 panic("Non-contiguous data is Unsupported") 316 } 317 318 var shape []int 319 for _, val := range a.Shape() { 320 shape = append(shape, int(val)) 321 } 322 323 l := a.Len() 324 validMask := a.Data().Buffers()[0].Bytes() 325 dataOffset := a.Data().Offset() 326 mask := make([]bool, l) 327 for i := 0; i < l; i++ { 328 mask[i] = len(validMask) != 0 && bitutil.BitIsNotSet(validMask, dataOffset+i) 329 } 330 331 switch a.DataType() { 332 {{range .PrimitiveTypes -}} 333 case arrow.PrimitiveTypes.{{.}}: 334 backing := a.(*arrowTensor.{{.}}).{{.}}Values() 335 if a.IsColMajor() { 336 return New(WithShape(shape...), AsFortran(backing, mask)) 337 } 338 339 return New(WithShape(shape...), WithBacking(backing, mask)) 340 {{end -}} 341 default: 342 panic(fmt.Sprintf("Unsupported Arrow DataType - %v", a.DataType())) 343 } 344 345 panic("Unreachable") 346 } 347 ` 348 349 var ( 350 importsArrow *template.Template 351 conversions *template.Template 352 compats *template.Template 353 compatsArrowArray *template.Template 354 compatsArrowTensor *template.Template 355 ) 356 357 func init() { 358 importsArrow = template.Must(template.New("imports_arrow").Funcs(funcs).Parse(importsArrowRaw)) 359 conversions = template.Must(template.New("conversions").Funcs(funcs).Parse(conversionsRaw)) 360 compats = template.Must(template.New("compat").Funcs(funcs).Parse(compatRaw)) 361 compatsArrowArray = template.Must(template.New("compat_arrow_array").Funcs(funcs).Parse(compatArrowArrayRaw)) 362 compatsArrowTensor = template.Must(template.New("compat_arrow_tensor").Funcs(funcs).Parse(compatArrowTensorRaw)) 363 } 364 365 func generateDenseCompat(f io.Writer, generic Kinds) { 366 // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming 367 // collisions 368 importsArrow.Execute(f, generic) 369 conversions.Execute(f, generic) 370 compats.Execute(f, generic) 371 arrowData := ArrowData{ 372 BinaryTypes: arrowBinaryTypes, 373 FixedWidthTypes: arrowFixedWidthTypes, 374 PrimitiveTypes: arrowPrimitiveTypes, 375 } 376 compatsArrowArray.Execute(f, arrowData) 377 compatsArrowTensor.Execute(f, arrowData) 378 }