go-hep.org/x/hep@v0.38.1/groot/rnpy/arrow.go (about) 1 // Copyright ©2019 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rnpy 6 7 import ( 8 "fmt" 9 "io" 10 "reflect" 11 "sync/atomic" 12 13 "codeberg.org/sbinet/npyio/npy" 14 "git.sr.ht/~sbinet/go-arrow" 15 "git.sr.ht/~sbinet/go-arrow/array" 16 "git.sr.ht/~sbinet/go-arrow/arrio" 17 "git.sr.ht/~sbinet/go-arrow/memory" 18 ) 19 20 var ( 21 boolType = reflect.TypeOf(true) 22 uint8Type = reflect.TypeOf((*uint8)(nil)).Elem() 23 uint16Type = reflect.TypeOf((*uint16)(nil)).Elem() 24 uint32Type = reflect.TypeOf((*uint32)(nil)).Elem() 25 uint64Type = reflect.TypeOf((*uint64)(nil)).Elem() 26 int8Type = reflect.TypeOf((*int8)(nil)).Elem() 27 int16Type = reflect.TypeOf((*int16)(nil)).Elem() 28 int32Type = reflect.TypeOf((*int32)(nil)).Elem() 29 int64Type = reflect.TypeOf((*int64)(nil)).Elem() 30 float32Type = reflect.TypeOf((*float32)(nil)).Elem() 31 float64Type = reflect.TypeOf((*float64)(nil)).Elem() 32 33 // complex64Type = reflect.TypeOf((*complex64)(nil)).Elem() 34 // complex128Type = reflect.TypeOf((*complex128)(nil)).Elem() 35 ) 36 37 // Record is an in-memory Arrow Record backed by a NumPy data file. 38 type Record struct { 39 refs int64 40 41 mem memory.Allocator 42 43 schema *arrow.Schema 44 nrows int64 45 ncols int64 46 47 cols []array.Interface 48 } 49 50 // NewRecord returns an Arrow Record from a NumPy data file reader. 51 func NewRecord(npy *npy.Reader) *Record { 52 var ( 53 mem = memory.NewGoAllocator() 54 schema = schemaFrom(npy) 55 shape = make([]int, len(npy.Header.Descr.Shape)) 56 ) 57 58 copy(shape, npy.Header.Descr.Shape) 59 if npy.Header.Descr.Fortran { 60 a := shape 61 for i := len(a)/2 - 1; i >= 0; i-- { 62 opp := len(a) - 1 - i 63 a[i], a[opp] = a[opp], a[i] 64 } 65 shape = a 66 } 67 nrows := int64(shape[0]) 68 69 rec := &Record{ 70 refs: 1, 71 mem: mem, 72 schema: schema, 73 nrows: nrows, 74 ncols: 1, 75 } 76 77 nelem := int64(1) 78 for _, v := range shape { 79 nelem *= int64(v) 80 } 81 82 bldr := builderFrom(mem, schema.Field(0).Type, nrows) 83 defer bldr.Release() 84 85 rec.read(npy, nelem, bldr) 86 87 return rec 88 } 89 90 // Retain increases the reference count by 1. 91 // Retain may be called simultaneously from multiple goroutines. 92 func (rec *Record) Retain() { 93 atomic.AddInt64(&rec.refs, 1) 94 } 95 96 // Release decreases the reference count by 1. 97 // When the reference count goes to zero, the memory is freed. 98 // Release may be called simultaneously from multiple goroutines. 99 func (rec *Record) Release() { 100 if atomic.LoadInt64(&rec.refs) <= 0 { 101 panic("groot/rarrow: too many releases") 102 } 103 104 if atomic.AddInt64(&rec.refs, -1) == 0 { 105 for i := range rec.cols { 106 rec.cols[i].Release() 107 } 108 rec.cols = nil 109 } 110 } 111 112 func (rec *Record) Schema() *arrow.Schema { return rec.schema } 113 func (rec *Record) NumRows() int64 { return rec.nrows } 114 func (rec *Record) NumCols() int64 { return rec.ncols } 115 func (rec *Record) Columns() []array.Interface { return rec.cols } 116 func (rec *Record) Column(i int) array.Interface { return rec.cols[i] } 117 func (rec *Record) ColumnName(i int) string { return rec.schema.Field(i).Name } 118 119 // NewSlice constructs a zero-copy slice of the record with the indicated 120 // indices i and j, corresponding to array[i:j]. 121 // The returned record must be Release()'d after use. 122 // 123 // NewSlice panics if the slice is outside the valid range of the record array. 124 // NewSlice panics if j < i. 125 func (rec *Record) NewSlice(i, j int64) array.Record { 126 panic("not implemented") 127 } 128 129 func (rec *Record) read(r *npy.Reader, nelem int64, bldr array.Builder) { 130 rt := dtypeFrom(rec.schema.Field(0).Type) 131 rv := reflect.New(reflect.SliceOf(rt)).Elem() 132 rv.Set(reflect.MakeSlice(rv.Type(), int(nelem), int(nelem))) 133 134 err := r.Read(rv.Addr().Interface()) 135 if err != nil { 136 panic(fmt.Errorf("npy2root: could not read numpy data: %w", err)) 137 } 138 139 ch := make(chan any, nelem/2) 140 go func() { 141 defer close(ch) 142 for i := range rv.Len() { 143 ch <- rv.Index(i).Interface() 144 } 145 }() 146 147 for i := int64(0); i < rec.nrows; i++ { 148 appendData(bldr, ch, rec.schema.Field(0).Type) 149 } 150 151 rec.cols = append(rec.cols, bldr.NewArray()) 152 } 153 154 func schemaFrom(npy *npy.Reader) *arrow.Schema { 155 var ( 156 hdr = npy.Header 157 dtype arrow.DataType 158 ) 159 switch hdr.Descr.Type { 160 case "b1", "<b1", "|b1", "bool": 161 dtype = arrow.FixedWidthTypes.Boolean 162 163 case "u1", "<u1", "|u1", "uint8": 164 dtype = arrow.PrimitiveTypes.Uint8 165 166 case "u2", "<u2", "|u2", ">u2", "uint16": 167 dtype = arrow.PrimitiveTypes.Uint16 168 169 case "u4", "<u4", "|u4", ">u4", "uint32": 170 dtype = arrow.PrimitiveTypes.Uint32 171 172 case "u8", "<u8", "|u8", ">u8", "uint64": 173 dtype = arrow.PrimitiveTypes.Uint64 174 175 case "i1", "<i1", "|i1", ">i1", "int8": 176 dtype = arrow.PrimitiveTypes.Int8 177 178 case "i2", "<i2", "|i2", ">i2", "int16": 179 dtype = arrow.PrimitiveTypes.Int16 180 181 case "i4", "<i4", "|i4", ">i4", "int32": 182 dtype = arrow.PrimitiveTypes.Int32 183 184 case "i8", "<i8", "|i8", ">i8", "int64": 185 dtype = arrow.PrimitiveTypes.Int64 186 187 case "f4", "<f4", "|f4", ">f4", "float32": 188 dtype = arrow.PrimitiveTypes.Float32 189 190 case "f8", "<f8", "|f8", ">f8", "float64": 191 dtype = arrow.PrimitiveTypes.Float64 192 193 // case "c8", "<c8", "|c8", ">c8", "complex64": 194 // panic(fmt.Errorf("npy2root: complex64 not supported")) 195 // 196 // case "c16", "<c16", "|c16", ">c16", "complex128": 197 // panic(fmt.Errorf("npy2root: complex128 not supported")) 198 199 default: 200 panic(fmt.Errorf("npy2root: invalid dtype descriptor %q", hdr.Descr.Type)) 201 } 202 203 shape := make([]int, len(hdr.Descr.Shape)) 204 copy(shape, hdr.Descr.Shape) 205 if hdr.Descr.Fortran { 206 a := shape 207 for i := len(a)/2 - 1; i >= 0; i-- { 208 opp := len(a) - 1 - i 209 a[i], a[opp] = a[opp], a[i] 210 } 211 shape = a 212 } 213 214 switch len(shape) { 215 case 1: 216 // scalar 217 218 case 2: 219 // 1d-array 220 dtype = arrow.FixedSizeListOf(int32(shape[1]), dtype) 221 222 case 3, 4, 5: 223 // 2,3d-array 224 for i := range shape[1:] { 225 dtype = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dtype) 226 } 227 228 default: 229 panic(fmt.Errorf("npy2root: invalid shape descriptor %v", hdr.Descr.Shape)) 230 } 231 232 schema := arrow.NewSchema([]arrow.Field{{Name: "numpy", Type: dtype}}, nil) 233 return schema 234 } 235 236 func builderFrom(mem memory.Allocator, dt arrow.DataType, size int64) array.Builder { 237 var bldr array.Builder 238 switch dt := dt.(type) { 239 case *arrow.BooleanType: 240 bldr = array.NewBooleanBuilder(mem) 241 case *arrow.Int8Type: 242 bldr = array.NewInt8Builder(mem) 243 case *arrow.Int16Type: 244 bldr = array.NewInt16Builder(mem) 245 case *arrow.Int32Type: 246 bldr = array.NewInt32Builder(mem) 247 case *arrow.Int64Type: 248 bldr = array.NewInt64Builder(mem) 249 case *arrow.Uint8Type: 250 bldr = array.NewUint8Builder(mem) 251 case *arrow.Uint16Type: 252 bldr = array.NewUint16Builder(mem) 253 case *arrow.Uint32Type: 254 bldr = array.NewUint32Builder(mem) 255 case *arrow.Uint64Type: 256 bldr = array.NewUint64Builder(mem) 257 case *arrow.Float32Type: 258 bldr = array.NewFloat32Builder(mem) 259 case *arrow.Float64Type: 260 bldr = array.NewFloat64Builder(mem) 261 // case *arrow.BinaryType: 262 // bldr = array.NewBinaryBuilder(mem, dt) 263 // case *arrow.StringType: 264 // bldr = array.NewStringBuilder(mem) 265 case *arrow.FixedSizeListType: 266 bldr = array.NewFixedSizeListBuilder(mem, dt.Len(), dt.Elem()) 267 default: 268 panic(fmt.Errorf("npy2root: invalid Arrow type %v", dt)) 269 } 270 bldr.Reserve(int(size)) 271 return bldr 272 } 273 274 func dtypeFrom(dt arrow.DataType) reflect.Type { 275 switch dt := dt.(type) { 276 case *arrow.BooleanType: 277 return boolType 278 case *arrow.Int8Type: 279 return int8Type 280 case *arrow.Int16Type: 281 return int16Type 282 case *arrow.Int32Type: 283 return int32Type 284 case *arrow.Int64Type: 285 return int64Type 286 case *arrow.Uint8Type: 287 return uint8Type 288 case *arrow.Uint16Type: 289 return uint16Type 290 case *arrow.Uint32Type: 291 return uint32Type 292 case *arrow.Uint64Type: 293 return uint64Type 294 case *arrow.Float32Type: 295 return float32Type 296 case *arrow.Float64Type: 297 return float64Type 298 // case *arrow.BinaryType: 299 // bldr = array.NewBinaryBuilder(mem, dt) 300 // case *arrow.StringType: 301 // bldr = array.NewStringBuilder(mem) 302 case *arrow.FixedSizeListType: 303 return dtypeFrom(dt.Elem()) 304 default: 305 panic(fmt.Errorf("npy2root: invalid Arrow type %v", dt)) 306 } 307 } 308 309 func appendData(bldr array.Builder, ch <-chan any, dt arrow.DataType) { 310 switch bldr := bldr.(type) { 311 case *array.BooleanBuilder: 312 v := <-ch 313 bldr.Append(v.(bool)) 314 case *array.Int8Builder: 315 v := <-ch 316 bldr.Append(v.(int8)) 317 case *array.Int16Builder: 318 v := <-ch 319 bldr.Append(v.(int16)) 320 case *array.Int32Builder: 321 v := <-ch 322 bldr.Append(v.(int32)) 323 case *array.Int64Builder: 324 v := <-ch 325 bldr.Append(v.(int64)) 326 case *array.Uint8Builder: 327 v := <-ch 328 bldr.Append(v.(uint8)) 329 case *array.Uint16Builder: 330 v := <-ch 331 bldr.Append(v.(uint16)) 332 case *array.Uint32Builder: 333 v := <-ch 334 bldr.Append(v.(uint32)) 335 case *array.Uint64Builder: 336 v := <-ch 337 bldr.Append(v.(uint64)) 338 case *array.Float32Builder: 339 v := <-ch 340 bldr.Append(v.(float32)) 341 case *array.Float64Builder: 342 v := <-ch 343 bldr.Append(v.(float64)) 344 case *array.FixedSizeListBuilder: 345 dt := dt.(*arrow.FixedSizeListType) 346 sub := bldr.ValueBuilder() 347 n := int(dt.Len()) 348 sub.Reserve(n) 349 bldr.Append(true) 350 for range n { 351 appendData(sub, ch, dt.Elem()) 352 } 353 default: 354 panic(fmt.Errorf("npy2root: invalid Arrow builder type %T", bldr)) 355 } 356 } 357 358 type RecordReader struct { 359 recs []array.Record 360 cur int 361 } 362 363 func NewRecordReader(recs ...array.Record) *RecordReader { 364 return &RecordReader{ 365 recs: recs, 366 cur: 0, 367 } 368 } 369 370 func (rr *RecordReader) Read() (array.Record, error) { 371 if rr.cur >= len(rr.recs) { 372 return nil, io.EOF 373 } 374 rec := rr.recs[rr.cur] 375 rr.cur++ 376 return rec, nil 377 } 378 379 var ( 380 _ array.Record = (*Record)(nil) 381 _ arrio.Reader = (*RecordReader)(nil) 382 )