go-hep.org/x/hep@v0.38.1/groot/rarrow/rarrow.go (about)

     1  // Copyright ©2019 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package rarrow handles conversion between ROOT and ARROW data models.
     6  package rarrow // import "go-hep.org/x/hep/groot/rarrow"
     7  
     8  import (
     9  	"fmt"
    10  	"reflect"
    11  	"strings"
    12  
    13  	"git.sr.ht/~sbinet/go-arrow"
    14  	"git.sr.ht/~sbinet/go-arrow/array"
    15  	"git.sr.ht/~sbinet/go-arrow/memory"
    16  	"go-hep.org/x/hep/groot/root"
    17  	"go-hep.org/x/hep/groot/rtree"
    18  )
    19  
    20  // SchemaFrom returns an Arrow schema from the provided ROOT tree.
    21  func SchemaFrom(t rtree.Tree) *arrow.Schema {
    22  	fields := make([]arrow.Field, len(t.Branches()))
    23  	for i, b := range t.Branches() {
    24  		fields[i] = fieldFromBranch(b)
    25  	}
    26  
    27  	return arrow.NewSchema(fields, nil) // FIXME(sbinet): add metadata.
    28  }
    29  
    30  func fieldFromBranch(b rtree.Branch) arrow.Field {
    31  	fields := make([]arrow.Field, len(b.Leaves()))
    32  	for i, leaf := range b.Leaves() {
    33  		fields[i] = arrow.Field{
    34  			Name: leaf.Name(),
    35  			Type: dataTypeFromLeaf(leaf),
    36  		}
    37  	}
    38  
    39  	if len(fields) == 1 {
    40  		fields[0].Name = b.Name()
    41  		return fields[0]
    42  	}
    43  
    44  	return arrow.Field{
    45  		Name: b.Name(),
    46  		Type: arrow.StructOf(fields...),
    47  	}
    48  }
    49  
// dataTypeFromLeaf returns the Arrow data type used to represent a ROOT leaf.
//
// The element type is chosen from leaf.Kind(). For integer kinds, the signed
// or unsigned Arrow type is picked according to leaf.IsUnsigned(). Struct and
// slice kinds are converted from leaf.Type() by dataTypeFromGo. The element
// type is then wrapped according to the leaf's dimensions:
//
//   - a leaf with a count leaf (LeafCount() != nil) becomes a List whose
//     elements are nested FixedSizeLists built from leaf.Shape();
//   - otherwise, a leaf with Len() > 1 becomes nested FixedSizeLists built
//     from leaf.Shape(), except that string leaves with at most one dimension
//     stay a single string.
//
// It panics on unsupported leaf kinds and on string leaves with more than one
// dimension.
func dataTypeFromLeaf(leaf rtree.Leaf) arrow.DataType {
	var (
		unsigned = leaf.IsUnsigned()
		kind     = leaf.Kind()
		typ      = leaf.Type()
		dt       arrow.DataType
	)

	// Pick the Arrow type of a single element.
	switch kind {
	case reflect.Bool:
		dt = arrow.FixedWidthTypes.Boolean
	case reflect.Int8, reflect.Uint8:
		switch {
		case unsigned:
			dt = arrow.PrimitiveTypes.Uint8
		default:
			dt = arrow.PrimitiveTypes.Int8
		}
	case reflect.Int16, reflect.Uint16:
		switch {
		case unsigned:
			dt = arrow.PrimitiveTypes.Uint16
		default:
			dt = arrow.PrimitiveTypes.Int16
		}
	case reflect.Int32, reflect.Uint32:
		switch {
		case unsigned:
			dt = arrow.PrimitiveTypes.Uint32
		default:
			dt = arrow.PrimitiveTypes.Int32
		}
	case reflect.Int64, reflect.Uint64:
		switch {
		case unsigned:
			dt = arrow.PrimitiveTypes.Uint64
		default:
			dt = arrow.PrimitiveTypes.Int64
		}
	case reflect.Float32:
		dt = arrow.PrimitiveTypes.Float32
	case reflect.Float64:
		dt = arrow.PrimitiveTypes.Float64
	case reflect.String:
		dt = arrow.BinaryTypes.String

	case reflect.Struct:
		// Composite element: derive the Arrow type from the Go type.
		dt = dataTypeFromGo(typ)

	case reflect.Slice:
		// Composite element: derive the Arrow type from the Go type.
		dt = dataTypeFromGo(typ)

	default:
		panic(fmt.Errorf("not implemented %#v (kind=%v)", leaf, kind))
	}

	// Wrap the element type according to the leaf's dimensions.
	switch {
	case leaf.LeafCount() != nil:
		// Variable-length leaf: any inner fixed dimensions come from
		// leaf.Shape(), and the outermost, variable dimension becomes a List.
		shape := leaf.Shape()
		switch leaf.(type) {
		case *rtree.LeafF16, *rtree.LeafD32:
			// workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149
			// NOTE(review): presumably Shape() is not reliable for these leaf
			// types, so no inner fixed-size dimensions are applied here.
			shape = nil
		}
		// Wrap the innermost dimension first, so shape[0] ends up as the
		// outermost FixedSizeList.
		for i := range shape {
			dt = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dt)
		}
		dt = arrow.ListOf(dt)
	case leaf.Len() > 1:
		// Fixed-size leaf holding more than one element.
		shape := leaf.Shape()
		switch leaf.Kind() {
		case reflect.String:
			switch dims := len(shape); dims {
			case 0, 1:
				// interpret as a single string
			default:
				// FIXME(sbinet): properly handle [N]string (but ROOT doesn't support that.)
				// see: https://root-forum.cern.ch/t/char-t-in-a-branch/5591/2
				// etype = reflect.ArrayOf(leaf.Len(), etype)
				panic(fmt.Errorf("groot/rtree: invalid number of dimensions (%d)", dims))
			}
		default:
			switch leaf.(type) {
			case *rtree.LeafF16, *rtree.LeafD32:
				// workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149
				// Treat these leaves as a flat 1-D array of Len() elements.
				shape = []int{leaf.Len()}
			}
			// Wrap the innermost dimension first, so shape[0] ends up as the
			// outermost FixedSizeList.
			for i := range shape {
				dt = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dt)
			}
		}
	}

	return dt
}
   145  
   146  func dataTypeFromGo(typ reflect.Type) arrow.DataType {
   147  	switch typ.Kind() {
   148  	case reflect.Bool:
   149  		return arrow.FixedWidthTypes.Boolean
   150  	case reflect.Int8:
   151  		return arrow.PrimitiveTypes.Int8
   152  	case reflect.Int16:
   153  		return arrow.PrimitiveTypes.Int16
   154  	case reflect.Int32:
   155  		return arrow.PrimitiveTypes.Int32
   156  	case reflect.Int64:
   157  		return arrow.PrimitiveTypes.Int64
   158  	case reflect.Uint8:
   159  		return arrow.PrimitiveTypes.Uint8
   160  	case reflect.Uint16:
   161  		return arrow.PrimitiveTypes.Uint16
   162  	case reflect.Uint32:
   163  		return arrow.PrimitiveTypes.Uint32
   164  	case reflect.Uint64:
   165  		return arrow.PrimitiveTypes.Uint64
   166  	case reflect.Float32:
   167  		return arrow.PrimitiveTypes.Float32
   168  	case reflect.Float64:
   169  		return arrow.PrimitiveTypes.Float64
   170  	case reflect.Slice:
   171  		// special case []byte
   172  		if typ.Elem().Kind() == reflect.Uint8 {
   173  			return arrow.BinaryTypes.Binary
   174  		}
   175  		return arrow.ListOf(dataTypeFromGo(typ.Elem()))
   176  	case reflect.Array:
   177  		return arrow.FixedSizeListOf(int32(typ.Len()), dataTypeFromGo(typ.Elem()))
   178  	case reflect.String:
   179  		return arrow.BinaryTypes.String
   180  
   181  	case reflect.Struct:
   182  		fields := make([]arrow.Field, typ.NumField())
   183  		for i := range fields {
   184  			f := typ.Field(i)
   185  			name := f.Name
   186  			if v, ok := f.Tag.Lookup("groot"); ok {
   187  				name = v
   188  			}
   189  			if idx := strings.Index(name, "["); idx > 0 {
   190  				name = name[:idx]
   191  			}
   192  			fields[i] = arrow.Field{
   193  				Name: name,
   194  				Type: dataTypeFromGo(f.Type),
   195  			}
   196  		}
   197  		return arrow.StructOf(fields...)
   198  
   199  	default:
   200  		panic(fmt.Errorf("rarrow: unsupported Go type %v", typ))
   201  	}
   202  }
   203  
   204  func builderFrom(mem memory.Allocator, dt arrow.DataType, size int64) array.Builder {
   205  	var bldr array.Builder
   206  	switch dt := dt.(type) {
   207  	case *arrow.BooleanType:
   208  		bldr = array.NewBooleanBuilder(mem)
   209  	case *arrow.Int8Type:
   210  		bldr = array.NewInt8Builder(mem)
   211  	case *arrow.Int16Type:
   212  		bldr = array.NewInt16Builder(mem)
   213  	case *arrow.Int32Type:
   214  		bldr = array.NewInt32Builder(mem)
   215  	case *arrow.Int64Type:
   216  		bldr = array.NewInt64Builder(mem)
   217  	case *arrow.Uint8Type:
   218  		bldr = array.NewUint8Builder(mem)
   219  	case *arrow.Uint16Type:
   220  		bldr = array.NewUint16Builder(mem)
   221  	case *arrow.Uint32Type:
   222  		bldr = array.NewUint32Builder(mem)
   223  	case *arrow.Uint64Type:
   224  		bldr = array.NewUint64Builder(mem)
   225  	case *arrow.Float32Type:
   226  		bldr = array.NewFloat32Builder(mem)
   227  	case *arrow.Float64Type:
   228  		bldr = array.NewFloat64Builder(mem)
   229  	case *arrow.BinaryType:
   230  		bldr = array.NewBinaryBuilder(mem, dt)
   231  	case *arrow.StringType:
   232  		bldr = array.NewStringBuilder(mem)
   233  	case *arrow.ListType:
   234  		bldr = array.NewListBuilder(mem, dt.Elem())
   235  	case *arrow.FixedSizeListType:
   236  		bldr = array.NewFixedSizeListBuilder(mem, dt.Len(), dt.Elem())
   237  	case *arrow.StructType:
   238  		bldr = array.NewStructBuilder(mem, dt)
   239  	default:
   240  		panic(fmt.Errorf("groot/rarrow: invalid Arrow type %v", dt))
   241  	}
   242  	bldr.Reserve(int(size))
   243  	return bldr
   244  }
   245  
   246  func appendData(bldr array.Builder, v rtree.ReadVar, dt arrow.DataType) {
   247  	switch bldr := bldr.(type) {
   248  	case *array.BooleanBuilder:
   249  		bldr.Append(*v.Value.(*bool))
   250  	case *array.Int8Builder:
   251  		bldr.Append(*v.Value.(*int8))
   252  	case *array.Int16Builder:
   253  		bldr.Append(*v.Value.(*int16))
   254  	case *array.Int32Builder:
   255  		bldr.Append(*v.Value.(*int32))
   256  	case *array.Int64Builder:
   257  		bldr.Append(*v.Value.(*int64))
   258  	case *array.Uint8Builder:
   259  		bldr.Append(*v.Value.(*uint8))
   260  	case *array.Uint16Builder:
   261  		bldr.Append(*v.Value.(*uint16))
   262  	case *array.Uint32Builder:
   263  		bldr.Append(*v.Value.(*uint32))
   264  	case *array.Uint64Builder:
   265  		bldr.Append(*v.Value.(*uint64))
   266  	case *array.Float32Builder:
   267  		switch ptr := v.Value.(type) {
   268  		case *float32:
   269  			bldr.Append(*ptr)
   270  		case *root.Float16:
   271  			bldr.Append(float32(*ptr))
   272  		}
   273  	case *array.Float64Builder:
   274  		switch ptr := v.Value.(type) {
   275  		case *float64:
   276  			bldr.Append(*ptr)
   277  		case *root.Double32:
   278  			bldr.Append(float64(*ptr))
   279  		}
   280  	case *array.StringBuilder:
   281  		bldr.Append(*v.Value.(*string))
   282  
   283  	case *array.ListBuilder:
   284  		sub := bldr.ValueBuilder()
   285  		v := reflect.ValueOf(v.Value).Elem()
   286  		sub.Reserve(v.Len())
   287  		bldr.Append(true)
   288  		for i := range v.Len() {
   289  			appendValue(sub, v.Index(i).Interface())
   290  		}
   291  
   292  	case *array.FixedSizeListBuilder:
   293  		sub := bldr.ValueBuilder()
   294  		v := reflect.ValueOf(v.Value).Elem()
   295  		sub.Reserve(v.Len())
   296  		bldr.Append(true)
   297  		for i := range v.Len() {
   298  			appendValue(sub, v.Index(i).Interface())
   299  		}
   300  
   301  	case *array.StructBuilder:
   302  		bldr.Append(true)
   303  		v := reflect.ValueOf(v.Value).Elem()
   304  		for i := range bldr.NumField() {
   305  			f := bldr.FieldBuilder(i)
   306  			appendValue(f, v.Field(i).Interface())
   307  		}
   308  
   309  	default:
   310  		panic(fmt.Errorf("groot/rarrow: invalid Arrow builder type %T", bldr))
   311  	}
   312  }
   313  
   314  func appendValue(bldr array.Builder, v any) {
   315  	switch b := bldr.(type) {
   316  	case *array.BooleanBuilder:
   317  		b.Append(v.(bool))
   318  	case *array.Int8Builder:
   319  		b.Append(v.(int8))
   320  	case *array.Int16Builder:
   321  		b.Append(v.(int16))
   322  	case *array.Int32Builder:
   323  		b.Append(v.(int32))
   324  	case *array.Int64Builder:
   325  		b.Append(v.(int64))
   326  	case *array.Uint8Builder:
   327  		b.Append(v.(uint8))
   328  	case *array.Uint16Builder:
   329  		b.Append(v.(uint16))
   330  	case *array.Uint32Builder:
   331  		b.Append(v.(uint32))
   332  	case *array.Uint64Builder:
   333  		b.Append(v.(uint64))
   334  	case *array.Float32Builder:
   335  		switch v := v.(type) {
   336  		case float32:
   337  			b.Append(v)
   338  		case root.Float16:
   339  			b.Append(float32(v))
   340  		}
   341  	case *array.Float64Builder:
   342  		switch v := v.(type) {
   343  		case float64:
   344  			b.Append(v)
   345  		case root.Double32:
   346  			b.Append(float64(v))
   347  		}
   348  	case *array.StringBuilder:
   349  		b.Append(v.(string))
   350  
   351  	case *array.ListBuilder:
   352  		b.Append(true)
   353  		sub := b.ValueBuilder()
   354  		v := reflect.ValueOf(v)
   355  		for i := range v.Len() {
   356  			appendValue(sub, v.Index(i).Interface())
   357  		}
   358  
   359  	case *array.FixedSizeListBuilder:
   360  		b.Append(true)
   361  		sub := b.ValueBuilder()
   362  		v := reflect.ValueOf(v)
   363  		for i := range v.Len() {
   364  			appendValue(sub, v.Index(i).Interface())
   365  		}
   366  
   367  	case *array.StructBuilder:
   368  		b.Append(true)
   369  		v := reflect.ValueOf(v)
   370  		for i := range b.NumField() {
   371  			f := b.FieldBuilder(i)
   372  			appendValue(f, v.Field(i).Interface())
   373  		}
   374  
   375  	default:
   376  		panic(fmt.Errorf("groot/rarrow: invalid Arrow builder type %T", b))
   377  	}
   378  }