go-hep.org/x/hep@v0.38.1/groot/rarrow/rarrow.go (about) 1 // Copyright ©2019 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package rarrow handles conversion between ROOT and ARROW data models. 6 package rarrow // import "go-hep.org/x/hep/groot/rarrow" 7 8 import ( 9 "fmt" 10 "reflect" 11 "strings" 12 13 "git.sr.ht/~sbinet/go-arrow" 14 "git.sr.ht/~sbinet/go-arrow/array" 15 "git.sr.ht/~sbinet/go-arrow/memory" 16 "go-hep.org/x/hep/groot/root" 17 "go-hep.org/x/hep/groot/rtree" 18 ) 19 20 // SchemaFrom returns an Arrow schema from the provided ROOT tree. 21 func SchemaFrom(t rtree.Tree) *arrow.Schema { 22 fields := make([]arrow.Field, len(t.Branches())) 23 for i, b := range t.Branches() { 24 fields[i] = fieldFromBranch(b) 25 } 26 27 return arrow.NewSchema(fields, nil) // FIXME(sbinet): add metadata. 28 } 29 30 func fieldFromBranch(b rtree.Branch) arrow.Field { 31 fields := make([]arrow.Field, len(b.Leaves())) 32 for i, leaf := range b.Leaves() { 33 fields[i] = arrow.Field{ 34 Name: leaf.Name(), 35 Type: dataTypeFromLeaf(leaf), 36 } 37 } 38 39 if len(fields) == 1 { 40 fields[0].Name = b.Name() 41 return fields[0] 42 } 43 44 return arrow.Field{ 45 Name: b.Name(), 46 Type: arrow.StructOf(fields...), 47 } 48 } 49 50 func dataTypeFromLeaf(leaf rtree.Leaf) arrow.DataType { 51 var ( 52 unsigned = leaf.IsUnsigned() 53 kind = leaf.Kind() 54 typ = leaf.Type() 55 dt arrow.DataType 56 ) 57 58 switch kind { 59 case reflect.Bool: 60 dt = arrow.FixedWidthTypes.Boolean 61 case reflect.Int8, reflect.Uint8: 62 switch { 63 case unsigned: 64 dt = arrow.PrimitiveTypes.Uint8 65 default: 66 dt = arrow.PrimitiveTypes.Int8 67 } 68 case reflect.Int16, reflect.Uint16: 69 switch { 70 case unsigned: 71 dt = arrow.PrimitiveTypes.Uint16 72 default: 73 dt = arrow.PrimitiveTypes.Int16 74 } 75 case reflect.Int32, reflect.Uint32: 76 switch { 77 case unsigned: 78 dt = arrow.PrimitiveTypes.Uint32 79 default: 80 dt = arrow.PrimitiveTypes.Int32 81 } 82 case reflect.Int64, reflect.Uint64: 83 switch { 84 case unsigned: 85 dt = arrow.PrimitiveTypes.Uint64 86 default: 87 dt = arrow.PrimitiveTypes.Int64 88 } 89 case reflect.Float32: 90 dt = arrow.PrimitiveTypes.Float32 91 case reflect.Float64: 92 dt = arrow.PrimitiveTypes.Float64 93 case reflect.String: 94 dt = arrow.BinaryTypes.String 95 96 case reflect.Struct: 97 dt = dataTypeFromGo(typ) 98 99 case reflect.Slice: 100 dt = dataTypeFromGo(typ) 101 102 default: 103 panic(fmt.Errorf("not implemented %#v (kind=%v)", leaf, kind)) 104 } 105 106 switch { 107 case leaf.LeafCount() != nil: 108 shape := leaf.Shape() 109 switch leaf.(type) { 110 case *rtree.LeafF16, *rtree.LeafD32: 111 // workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149 112 shape = nil 113 } 114 for i := range shape { 115 dt = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dt) 116 } 117 dt = arrow.ListOf(dt) 118 case leaf.Len() > 1: 119 shape := leaf.Shape() 120 switch leaf.Kind() { 121 case reflect.String: 122 switch dims := len(shape); dims { 123 case 0, 1: 124 // interpret as a single string 125 default: 126 // FIXME(sbinet): properly handle [N]string (but ROOT doesn't support that.) 127 // see: https://root-forum.cern.ch/t/char-t-in-a-branch/5591/2 128 // etype = reflect.ArrayOf(leaf.Len(), etype) 129 panic(fmt.Errorf("groot/rtree: invalid number of dimensions (%d)", dims)) 130 } 131 default: 132 switch leaf.(type) { 133 case *rtree.LeafF16, *rtree.LeafD32: 134 // workaround for https://sft.its.cern.ch/jira/browse/ROOT-10149 135 shape = []int{leaf.Len()} 136 } 137 for i := range shape { 138 dt = arrow.FixedSizeListOf(int32(shape[len(shape)-1-i]), dt) 139 } 140 } 141 } 142 143 return dt 144 } 145 146 func dataTypeFromGo(typ reflect.Type) arrow.DataType { 147 switch typ.Kind() { 148 case reflect.Bool: 149 return arrow.FixedWidthTypes.Boolean 150 case reflect.Int8: 151 return arrow.PrimitiveTypes.Int8 152 case reflect.Int16: 153 return arrow.PrimitiveTypes.Int16 154 case reflect.Int32: 155 return arrow.PrimitiveTypes.Int32 156 case reflect.Int64: 157 return arrow.PrimitiveTypes.Int64 158 case reflect.Uint8: 159 return arrow.PrimitiveTypes.Uint8 160 case reflect.Uint16: 161 return arrow.PrimitiveTypes.Uint16 162 case reflect.Uint32: 163 return arrow.PrimitiveTypes.Uint32 164 case reflect.Uint64: 165 return arrow.PrimitiveTypes.Uint64 166 case reflect.Float32: 167 return arrow.PrimitiveTypes.Float32 168 case reflect.Float64: 169 return arrow.PrimitiveTypes.Float64 170 case reflect.Slice: 171 // special case []byte 172 if typ.Elem().Kind() == reflect.Uint8 { 173 return arrow.BinaryTypes.Binary 174 } 175 return arrow.ListOf(dataTypeFromGo(typ.Elem())) 176 case reflect.Array: 177 return arrow.FixedSizeListOf(int32(typ.Len()), dataTypeFromGo(typ.Elem())) 178 case reflect.String: 179 return arrow.BinaryTypes.String 180 181 case reflect.Struct: 182 fields := make([]arrow.Field, typ.NumField()) 183 for i := range fields { 184 f := typ.Field(i) 185 name := f.Name 186 if v, ok := f.Tag.Lookup("groot"); ok { 187 name = v 188 } 189 if idx := strings.Index(name, "["); idx > 0 { 190 name = name[:idx] 191 } 192 fields[i] = arrow.Field{ 193 Name: name, 194 Type: dataTypeFromGo(f.Type), 195 } 196 } 197 return arrow.StructOf(fields...) 198 199 default: 200 panic(fmt.Errorf("rarrow: unsupported Go type %v", typ)) 201 } 202 } 203 204 func builderFrom(mem memory.Allocator, dt arrow.DataType, size int64) array.Builder { 205 var bldr array.Builder 206 switch dt := dt.(type) { 207 case *arrow.BooleanType: 208 bldr = array.NewBooleanBuilder(mem) 209 case *arrow.Int8Type: 210 bldr = array.NewInt8Builder(mem) 211 case *arrow.Int16Type: 212 bldr = array.NewInt16Builder(mem) 213 case *arrow.Int32Type: 214 bldr = array.NewInt32Builder(mem) 215 case *arrow.Int64Type: 216 bldr = array.NewInt64Builder(mem) 217 case *arrow.Uint8Type: 218 bldr = array.NewUint8Builder(mem) 219 case *arrow.Uint16Type: 220 bldr = array.NewUint16Builder(mem) 221 case *arrow.Uint32Type: 222 bldr = array.NewUint32Builder(mem) 223 case *arrow.Uint64Type: 224 bldr = array.NewUint64Builder(mem) 225 case *arrow.Float32Type: 226 bldr = array.NewFloat32Builder(mem) 227 case *arrow.Float64Type: 228 bldr = array.NewFloat64Builder(mem) 229 case *arrow.BinaryType: 230 bldr = array.NewBinaryBuilder(mem, dt) 231 case *arrow.StringType: 232 bldr = array.NewStringBuilder(mem) 233 case *arrow.ListType: 234 bldr = array.NewListBuilder(mem, dt.Elem()) 235 case *arrow.FixedSizeListType: 236 bldr = array.NewFixedSizeListBuilder(mem, dt.Len(), dt.Elem()) 237 case *arrow.StructType: 238 bldr = array.NewStructBuilder(mem, dt) 239 default: 240 panic(fmt.Errorf("groot/rarrow: invalid Arrow type %v", dt)) 241 } 242 bldr.Reserve(int(size)) 243 return bldr 244 } 245 246 func appendData(bldr array.Builder, v rtree.ReadVar, dt arrow.DataType) { 247 switch bldr := bldr.(type) { 248 case *array.BooleanBuilder: 249 bldr.Append(*v.Value.(*bool)) 250 case *array.Int8Builder: 251 bldr.Append(*v.Value.(*int8)) 252 case *array.Int16Builder: 253 bldr.Append(*v.Value.(*int16)) 254 case *array.Int32Builder: 255 bldr.Append(*v.Value.(*int32)) 256 case *array.Int64Builder: 257 bldr.Append(*v.Value.(*int64)) 258 case *array.Uint8Builder: 259 bldr.Append(*v.Value.(*uint8)) 260 case *array.Uint16Builder: 261 bldr.Append(*v.Value.(*uint16)) 262 case *array.Uint32Builder: 263 bldr.Append(*v.Value.(*uint32)) 264 case *array.Uint64Builder: 265 bldr.Append(*v.Value.(*uint64)) 266 case *array.Float32Builder: 267 switch ptr := v.Value.(type) { 268 case *float32: 269 bldr.Append(*ptr) 270 case *root.Float16: 271 bldr.Append(float32(*ptr)) 272 } 273 case *array.Float64Builder: 274 switch ptr := v.Value.(type) { 275 case *float64: 276 bldr.Append(*ptr) 277 case *root.Double32: 278 bldr.Append(float64(*ptr)) 279 } 280 case *array.StringBuilder: 281 bldr.Append(*v.Value.(*string)) 282 283 case *array.ListBuilder: 284 sub := bldr.ValueBuilder() 285 v := reflect.ValueOf(v.Value).Elem() 286 sub.Reserve(v.Len()) 287 bldr.Append(true) 288 for i := range v.Len() { 289 appendValue(sub, v.Index(i).Interface()) 290 } 291 292 case *array.FixedSizeListBuilder: 293 sub := bldr.ValueBuilder() 294 v := reflect.ValueOf(v.Value).Elem() 295 sub.Reserve(v.Len()) 296 bldr.Append(true) 297 for i := range v.Len() { 298 appendValue(sub, v.Index(i).Interface()) 299 } 300 301 case *array.StructBuilder: 302 bldr.Append(true) 303 v := reflect.ValueOf(v.Value).Elem() 304 for i := range bldr.NumField() { 305 f := bldr.FieldBuilder(i) 306 appendValue(f, v.Field(i).Interface()) 307 } 308 309 default: 310 panic(fmt.Errorf("groot/rarrow: invalid Arrow builder type %T", bldr)) 311 } 312 } 313 314 func appendValue(bldr array.Builder, v any) { 315 switch b := bldr.(type) { 316 case *array.BooleanBuilder: 317 b.Append(v.(bool)) 318 case *array.Int8Builder: 319 b.Append(v.(int8)) 320 case *array.Int16Builder: 321 b.Append(v.(int16)) 322 case *array.Int32Builder: 323 b.Append(v.(int32)) 324 case *array.Int64Builder: 325 b.Append(v.(int64)) 326 case *array.Uint8Builder: 327 b.Append(v.(uint8)) 328 case *array.Uint16Builder: 329 b.Append(v.(uint16)) 330 case *array.Uint32Builder: 331 b.Append(v.(uint32)) 332 case *array.Uint64Builder: 333 b.Append(v.(uint64)) 334 case *array.Float32Builder: 335 switch v := v.(type) { 336 case float32: 337 b.Append(v) 338 case root.Float16: 339 b.Append(float32(v)) 340 } 341 case *array.Float64Builder: 342 switch v := v.(type) { 343 case float64: 344 b.Append(v) 345 case root.Double32: 346 b.Append(float64(v)) 347 } 348 case *array.StringBuilder: 349 b.Append(v.(string)) 350 351 case *array.ListBuilder: 352 b.Append(true) 353 sub := b.ValueBuilder() 354 v := reflect.ValueOf(v) 355 for i := range v.Len() { 356 appendValue(sub, v.Index(i).Interface()) 357 } 358 359 case *array.FixedSizeListBuilder: 360 b.Append(true) 361 sub := b.ValueBuilder() 362 v := reflect.ValueOf(v) 363 for i := range v.Len() { 364 appendValue(sub, v.Index(i).Interface()) 365 } 366 367 case *array.StructBuilder: 368 b.Append(true) 369 v := reflect.ValueOf(v) 370 for i := range b.NumField() { 371 f := b.FieldBuilder(i) 372 appendValue(f, v.Field(i).Interface()) 373 } 374 375 default: 376 panic(fmt.Errorf("groot/rarrow: invalid Arrow builder type %T", b)) 377 } 378 }