go-hep.org/x/hep@v0.38.1/groot/rnpy/arrow.go (about)

     1  // Copyright ©2019 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rnpy
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"reflect"
    11  	"sync/atomic"
    12  
    13  	"codeberg.org/sbinet/npyio/npy"
    14  	"git.sr.ht/~sbinet/go-arrow"
    15  	"git.sr.ht/~sbinet/go-arrow/array"
    16  	"git.sr.ht/~sbinet/go-arrow/arrio"
    17  	"git.sr.ht/~sbinet/go-arrow/memory"
    18  )
    19  
    20  var (
    21  	boolType    = reflect.TypeOf(true)
    22  	uint8Type   = reflect.TypeOf((*uint8)(nil)).Elem()
    23  	uint16Type  = reflect.TypeOf((*uint16)(nil)).Elem()
    24  	uint32Type  = reflect.TypeOf((*uint32)(nil)).Elem()
    25  	uint64Type  = reflect.TypeOf((*uint64)(nil)).Elem()
    26  	int8Type    = reflect.TypeOf((*int8)(nil)).Elem()
    27  	int16Type   = reflect.TypeOf((*int16)(nil)).Elem()
    28  	int32Type   = reflect.TypeOf((*int32)(nil)).Elem()
    29  	int64Type   = reflect.TypeOf((*int64)(nil)).Elem()
    30  	float32Type = reflect.TypeOf((*float32)(nil)).Elem()
    31  	float64Type = reflect.TypeOf((*float64)(nil)).Elem()
    32  
    33  // complex64Type  = reflect.TypeOf((*complex64)(nil)).Elem()
    34  // complex128Type = reflect.TypeOf((*complex128)(nil)).Elem()
    35  )
    36  
    37  // Record is an in-memory Arrow Record backed by a NumPy data file.
    38  type Record struct {
    39  	refs int64
    40  
    41  	mem memory.Allocator
    42  
    43  	schema *arrow.Schema
    44  	nrows  int64
    45  	ncols  int64
    46  
    47  	cols []array.Interface
    48  }
    49  
    50  // NewRecord returns an Arrow Record from a NumPy data file reader.
    51  func NewRecord(npy *npy.Reader) *Record {
    52  	var (
    53  		mem    = memory.NewGoAllocator()
    54  		schema = schemaFrom(npy)
    55  		shape  = make([]int, len(npy.Header.Descr.Shape))
    56  	)
    57  
    58  	copy(shape, npy.Header.Descr.Shape)
    59  	if npy.Header.Descr.Fortran {
    60  		a := shape
    61  		for i := len(a)/2 - 1; i >= 0; i-- {
    62  			opp := len(a) - 1 - i
    63  			a[i], a[opp] = a[opp], a[i]
    64  		}
    65  		shape = a
    66  	}
    67  	nrows := int64(shape[0])
    68  
    69  	rec := &Record{
    70  		refs:   1,
    71  		mem:    mem,
    72  		schema: schema,
    73  		nrows:  nrows,
    74  		ncols:  1,
    75  	}
    76  
    77  	nelem := int64(1)
    78  	for _, v := range shape {
    79  		nelem *= int64(v)
    80  	}
    81  
    82  	bldr := builderFrom(mem, schema.Field(0).Type, nrows)
    83  	defer bldr.Release()
    84  
    85  	rec.read(npy, nelem, bldr)
    86  
    87  	return rec
    88  }
    89  
    90  // Retain increases the reference count by 1.
    91  // Retain may be called simultaneously from multiple goroutines.
    92  func (rec *Record) Retain() {
    93  	atomic.AddInt64(&rec.refs, 1)
    94  }
    95  
    96  // Release decreases the reference count by 1.
    97  // When the reference count goes to zero, the memory is freed.
    98  // Release may be called simultaneously from multiple goroutines.
    99  func (rec *Record) Release() {
   100  	if atomic.LoadInt64(&rec.refs) <= 0 {
   101  		panic("groot/rarrow: too many releases")
   102  	}
   103  
   104  	if atomic.AddInt64(&rec.refs, -1) == 0 {
   105  		for i := range rec.cols {
   106  			rec.cols[i].Release()
   107  		}
   108  		rec.cols = nil
   109  	}
   110  }
   111  
   112  func (rec *Record) Schema() *arrow.Schema        { return rec.schema }
   113  func (rec *Record) NumRows() int64               { return rec.nrows }
   114  func (rec *Record) NumCols() int64               { return rec.ncols }
   115  func (rec *Record) Columns() []array.Interface   { return rec.cols }
   116  func (rec *Record) Column(i int) array.Interface { return rec.cols[i] }
   117  func (rec *Record) ColumnName(i int) string      { return rec.schema.Field(i).Name }
   118  
   119  // NewSlice constructs a zero-copy slice of the record with the indicated
   120  // indices i and j, corresponding to array[i:j].
   121  // The returned record must be Release()'d after use.
   122  //
   123  // NewSlice panics if the slice is outside the valid range of the record array.
   124  // NewSlice panics if j < i.
   125  func (rec *Record) NewSlice(i, j int64) array.Record {
   126  	panic("not implemented")
   127  }
   128  
   129  func (rec *Record) read(r *npy.Reader, nelem int64, bldr array.Builder) {
   130  	rt := dtypeFrom(rec.schema.Field(0).Type)
   131  	rv := reflect.New(reflect.SliceOf(rt)).Elem()
   132  	rv.Set(reflect.MakeSlice(rv.Type(), int(nelem), int(nelem)))
   133  
   134  	err := r.Read(rv.Addr().Interface())
   135  	if err != nil {
   136  		panic(fmt.Errorf("npy2root: could not read numpy data: %w", err))
   137  	}
   138  
   139  	ch := make(chan any, nelem/2)
   140  	go func() {
   141  		defer close(ch)
   142  		for i := range rv.Len() {
   143  			ch <- rv.Index(i).Interface()
   144  		}
   145  	}()
   146  
   147  	for i := int64(0); i < rec.nrows; i++ {
   148  		appendData(bldr, ch, rec.schema.Field(0).Type)
   149  	}
   150  
   151  	rec.cols = append(rec.cols, bldr.NewArray())
   152  }
   153  
   154  func schemaFrom(npy *npy.Reader) *arrow.Schema {
   155  	var (
   156  		hdr   = npy.Header
   157  		dtype arrow.DataType
   158  	)
   159  	switch hdr.Descr.Type {
   160  	case "b1", "<b1", "|b1", "bool":
   161  		dtype = arrow.FixedWidthTypes.Boolean
   162  
   163  	case "u1", "<u1", "|u1", "uint8":
   164  		dtype = arrow.PrimitiveTypes.Uint8
   165  
   166  	case "u2", "<u2", "|u2", ">u2", "uint16":
   167  		dtype = arrow.PrimitiveTypes.Uint16
   168  
   169  	case "u4", "<u4", "|u4", ">u4", "uint32":
   170  		dtype = arrow.PrimitiveTypes.Uint32
   171  
   172  	case "u8", "<u8", "|u8", ">u8", "uint64":
   173  		dtype = arrow.PrimitiveTypes.Uint64
   174  
   175  	case "i1", "<i1", "|i1", ">i1", "int8":
   176  		dtype = arrow.PrimitiveTypes.Int8
   177  
   178  	case "i2", "<i2", "|i2", ">i2", "int16":
   179  		dtype = arrow.PrimitiveTypes.Int16
   180  
   181  	case "i4", "<i4", "|i4", ">i4", "int32":
   182  		dtype = arrow.PrimitiveTypes.Int32
   183  
   184  	case "i8", "<i8", "|i8", ">i8", "int64":
   185  		dtype = arrow.PrimitiveTypes.Int64
   186  
   187  	case "f4", "<f4", "|f4", ">f4", "float32":
   188  		dtype = arrow.PrimitiveTypes.Float32
   189  
   190  	case "f8", "<f8", "|f8", ">f8", "float64":
   191  		dtype = arrow.PrimitiveTypes.Float64
   192  
   193  		//	case "c8", "<c8", "|c8", ">c8", "complex64":
   194  		//		panic(fmt.Errorf("npy2root: complex64 not supported"))
   195  		//
   196  		//	case "c16", "<c16", "|c16", ">c16", "complex128":
   197  		//		panic(fmt.Errorf("npy2root: complex128 not supported"))
   198  
   199  	default:
   200  		panic(fmt.Errorf("npy2root: invalid dtype descriptor %q", hdr.Descr.Type))
   201  	}
   202  
   203  	shape := make([]int, len(hdr.Descr.Shape))
   204  	copy(shape, hdr.Descr.Shape)
   205  	if hdr.Descr.Fortran {
   206  		a := shape
   207  		for i := len(a)/2 - 1; i >= 0; i-- {
   208  			opp := len(a) - 1 - i
   209  			a[i], a[opp] = a[opp], a[i]
   210  		}
   211  		shape = a
   212  	}
   213  
   214  	switch len(shape) {
   215  	case 1:
   216  		// scalar
   217  
   218  	case 2:
   219  		// 1d-array
   220  		dtype = arrow.FixedSizeListOf(int32(shape[1]), dtype)
   221  
   222  	case 3, 4, 5:
   223  		// 2,3d-array
   224  		for i := range shape[1:] {
   225  			dtype = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dtype)
   226  		}
   227  
   228  	default:
   229  		panic(fmt.Errorf("npy2root: invalid shape descriptor %v", hdr.Descr.Shape))
   230  	}
   231  
   232  	schema := arrow.NewSchema([]arrow.Field{{Name: "numpy", Type: dtype}}, nil)
   233  	return schema
   234  }
   235  
   236  func builderFrom(mem memory.Allocator, dt arrow.DataType, size int64) array.Builder {
   237  	var bldr array.Builder
   238  	switch dt := dt.(type) {
   239  	case *arrow.BooleanType:
   240  		bldr = array.NewBooleanBuilder(mem)
   241  	case *arrow.Int8Type:
   242  		bldr = array.NewInt8Builder(mem)
   243  	case *arrow.Int16Type:
   244  		bldr = array.NewInt16Builder(mem)
   245  	case *arrow.Int32Type:
   246  		bldr = array.NewInt32Builder(mem)
   247  	case *arrow.Int64Type:
   248  		bldr = array.NewInt64Builder(mem)
   249  	case *arrow.Uint8Type:
   250  		bldr = array.NewUint8Builder(mem)
   251  	case *arrow.Uint16Type:
   252  		bldr = array.NewUint16Builder(mem)
   253  	case *arrow.Uint32Type:
   254  		bldr = array.NewUint32Builder(mem)
   255  	case *arrow.Uint64Type:
   256  		bldr = array.NewUint64Builder(mem)
   257  	case *arrow.Float32Type:
   258  		bldr = array.NewFloat32Builder(mem)
   259  	case *arrow.Float64Type:
   260  		bldr = array.NewFloat64Builder(mem)
   261  		//	case *arrow.BinaryType:
   262  		//		bldr = array.NewBinaryBuilder(mem, dt)
   263  		//	case *arrow.StringType:
   264  		//		bldr = array.NewStringBuilder(mem)
   265  	case *arrow.FixedSizeListType:
   266  		bldr = array.NewFixedSizeListBuilder(mem, dt.Len(), dt.Elem())
   267  	default:
   268  		panic(fmt.Errorf("npy2root: invalid Arrow type %v", dt))
   269  	}
   270  	bldr.Reserve(int(size))
   271  	return bldr
   272  }
   273  
   274  func dtypeFrom(dt arrow.DataType) reflect.Type {
   275  	switch dt := dt.(type) {
   276  	case *arrow.BooleanType:
   277  		return boolType
   278  	case *arrow.Int8Type:
   279  		return int8Type
   280  	case *arrow.Int16Type:
   281  		return int16Type
   282  	case *arrow.Int32Type:
   283  		return int32Type
   284  	case *arrow.Int64Type:
   285  		return int64Type
   286  	case *arrow.Uint8Type:
   287  		return uint8Type
   288  	case *arrow.Uint16Type:
   289  		return uint16Type
   290  	case *arrow.Uint32Type:
   291  		return uint32Type
   292  	case *arrow.Uint64Type:
   293  		return uint64Type
   294  	case *arrow.Float32Type:
   295  		return float32Type
   296  	case *arrow.Float64Type:
   297  		return float64Type
   298  		//	case *arrow.BinaryType:
   299  		//		bldr = array.NewBinaryBuilder(mem, dt)
   300  		//	case *arrow.StringType:
   301  		//		bldr = array.NewStringBuilder(mem)
   302  	case *arrow.FixedSizeListType:
   303  		return dtypeFrom(dt.Elem())
   304  	default:
   305  		panic(fmt.Errorf("npy2root: invalid Arrow type %v", dt))
   306  	}
   307  }
   308  
   309  func appendData(bldr array.Builder, ch <-chan any, dt arrow.DataType) {
   310  	switch bldr := bldr.(type) {
   311  	case *array.BooleanBuilder:
   312  		v := <-ch
   313  		bldr.Append(v.(bool))
   314  	case *array.Int8Builder:
   315  		v := <-ch
   316  		bldr.Append(v.(int8))
   317  	case *array.Int16Builder:
   318  		v := <-ch
   319  		bldr.Append(v.(int16))
   320  	case *array.Int32Builder:
   321  		v := <-ch
   322  		bldr.Append(v.(int32))
   323  	case *array.Int64Builder:
   324  		v := <-ch
   325  		bldr.Append(v.(int64))
   326  	case *array.Uint8Builder:
   327  		v := <-ch
   328  		bldr.Append(v.(uint8))
   329  	case *array.Uint16Builder:
   330  		v := <-ch
   331  		bldr.Append(v.(uint16))
   332  	case *array.Uint32Builder:
   333  		v := <-ch
   334  		bldr.Append(v.(uint32))
   335  	case *array.Uint64Builder:
   336  		v := <-ch
   337  		bldr.Append(v.(uint64))
   338  	case *array.Float32Builder:
   339  		v := <-ch
   340  		bldr.Append(v.(float32))
   341  	case *array.Float64Builder:
   342  		v := <-ch
   343  		bldr.Append(v.(float64))
   344  	case *array.FixedSizeListBuilder:
   345  		dt := dt.(*arrow.FixedSizeListType)
   346  		sub := bldr.ValueBuilder()
   347  		n := int(dt.Len())
   348  		sub.Reserve(n)
   349  		bldr.Append(true)
   350  		for range n {
   351  			appendData(sub, ch, dt.Elem())
   352  		}
   353  	default:
   354  		panic(fmt.Errorf("npy2root: invalid Arrow builder type %T", bldr))
   355  	}
   356  }
   357  
   358  type RecordReader struct {
   359  	recs []array.Record
   360  	cur  int
   361  }
   362  
   363  func NewRecordReader(recs ...array.Record) *RecordReader {
   364  	return &RecordReader{
   365  		recs: recs,
   366  		cur:  0,
   367  	}
   368  }
   369  
   370  func (rr *RecordReader) Read() (array.Record, error) {
   371  	if rr.cur >= len(rr.recs) {
   372  		return nil, io.EOF
   373  	}
   374  	rec := rr.recs[rr.cur]
   375  	rr.cur++
   376  	return rec, nil
   377  }
   378  
   379  var (
   380  	_ array.Record = (*Record)(nil)
   381  	_ arrio.Reader = (*RecordReader)(nil)
   382  )