go-hep.org/x/hep@v0.38.1/groot/rarrow/tree_writer.go (about) 1 // Copyright ©2019 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rarrow 6 7 import ( 8 "fmt" 9 "reflect" 10 11 "git.sr.ht/~sbinet/go-arrow" 12 "git.sr.ht/~sbinet/go-arrow/array" 13 "git.sr.ht/~sbinet/go-arrow/arrio" 14 "go-hep.org/x/hep/groot/riofs" 15 "go-hep.org/x/hep/groot/rtree" 16 ) 17 18 // flatTreeWriter writes ARROW data as a ROOT flat-tree. 19 type flatTreeWriter struct { 20 w rtree.Writer 21 schema *arrow.Schema 22 ctx contextWriter 23 } 24 25 // NewFlatTreeWriter creates an arrio.Writer that writes ARROW data as a ROOT 26 // flat-tree under the provided dir directory. 27 func NewFlatTreeWriter(dir riofs.Directory, name string, schema *arrow.Schema, opts ...rtree.WriteOption) (*flatTreeWriter, error) { 28 var ( 29 ctx = newContextWriter(schema) 30 wvars = make([]rtree.WriteVar, 0, len(ctx.wvars)+len(ctx.count)) 31 ) 32 33 for _, wvar := range ctx.count { 34 wvars = append(wvars, wvar) 35 } 36 wvars = append(wvars, ctx.wvars...) 37 38 tree, err := rtree.NewWriter(dir, name, wvars, opts...) 39 if err != nil { 40 return nil, fmt.Errorf("rarrow: could not create flat-tree writer %q: %w", name, err) 41 } 42 return &flatTreeWriter{w: tree, schema: schema, ctx: ctx}, nil 43 } 44 45 // Close closes the underlying ROOT tree writer. 46 func (fw *flatTreeWriter) Close() error { 47 return fw.w.Close() 48 } 49 50 // Write writes the provided ARROW record to the underlying ROOT flat-tree. 51 // Write implements arrio.Writer. 52 func (fw *flatTreeWriter) Write(rec array.Record) error { 53 if src := rec.Schema(); !fw.schema.Equal(src) { 54 return fmt.Errorf("rarrow: invalid input record schema:\n - got= %v\n - want=%v", src, fw.schema) 55 } 56 57 nrows := rec.Column(0).Len() 58 for icol, col := range rec.Columns() { 59 if col.Len() != nrows { 60 return fmt.Errorf( 61 "rarrow: column %q (index=%d) has not the same number of rows than others (got=%d, want=%d)", 62 rec.ColumnName(icol), icol, col.Len(), nrows, 63 ) 64 } 65 } 66 67 for irow := range nrows { 68 for icol, col := range rec.Columns() { 69 wvar := &fw.ctx.wvars[icol] 70 err := fw.ctx.readFrom(wvar, irow, col) 71 if err != nil { 72 return fmt.Errorf( 73 "rarrow: could not read row=%d from column[%d](name=%s): %w", 74 irow, icol, rec.ColumnName(icol), err, 75 ) 76 } 77 } 78 _, err := fw.w.Write() 79 if err != nil { 80 return fmt.Errorf("rarrow: could not write row=%d to tree: %w", irow, err) 81 } 82 } 83 84 return nil 85 } 86 87 type contextWriter struct { 88 wvars []rtree.WriteVar 89 count map[string]rtree.WriteVar 90 } 91 92 func newContextWriter(schema *arrow.Schema) contextWriter { 93 ctx := contextWriter{ 94 wvars: make([]rtree.WriteVar, len(schema.Fields())), 95 count: make(map[string]rtree.WriteVar), 96 } 97 for i, field := range schema.Fields() { 98 ctx.wvars[i] = ctx.writeVarFrom(field) 99 } 100 return ctx 101 } 102 103 func (ctx *contextWriter) writeVarFrom(field arrow.Field) rtree.WriteVar { 104 switch dt := field.Type.(type) { 105 case *arrow.BooleanType: 106 return rtree.WriteVar{ 107 Name: field.Name, 108 Value: new(bool), 109 } 110 111 case *arrow.Int8Type: 112 return rtree.WriteVar{ 113 Name: field.Name, 114 Value: new(int8), 115 } 116 117 case *arrow.Int16Type: 118 return rtree.WriteVar{ 119 Name: field.Name, 120 Value: new(int16), 121 } 122 123 case *arrow.Int32Type: 124 return rtree.WriteVar{ 125 Name: field.Name, 126 Value: new(int32), 127 } 128 129 case *arrow.Int64Type: 130 return rtree.WriteVar{ 131 Name: field.Name, 132 Value: new(int64), 133 } 134 135 case *arrow.Uint8Type: 136 return rtree.WriteVar{ 137 Name: field.Name, 138 Value: new(uint8), 139 } 140 141 case *arrow.Uint16Type: 142 return rtree.WriteVar{ 143 Name: field.Name, 144 Value: new(uint16), 145 } 146 147 case *arrow.Uint32Type: 148 return rtree.WriteVar{ 149 Name: field.Name, 150 Value: new(uint32), 151 } 152 153 case *arrow.Uint64Type: 154 return rtree.WriteVar{ 155 Name: field.Name, 156 Value: new(uint64), 157 } 158 159 case *arrow.Float32Type: 160 return rtree.WriteVar{ 161 Name: field.Name, 162 Value: new(float32), 163 } 164 165 case *arrow.Float64Type: 166 return rtree.WriteVar{ 167 Name: field.Name, 168 Value: new(float64), 169 } 170 171 case *arrow.StringType: 172 return rtree.WriteVar{ 173 Name: field.Name, 174 Value: new(string), 175 } 176 case *arrow.BinaryType: 177 // FIXME(sbinet): differentiate the 2 (Binary/String) ? 178 return rtree.WriteVar{ 179 Name: field.Name, 180 Value: new(string), 181 } 182 183 case *arrow.FixedSizeListType: 184 wv := ctx.writeVarFrom(arrow.Field{Type: dt.Elem(), Name: "elem"}) 185 rt := reflect.ArrayOf(int(dt.Len()), reflect.TypeOf(wv.Value).Elem()) 186 return rtree.WriteVar{ 187 Name: field.Name, 188 Value: reflect.New(rt).Interface(), 189 } 190 191 case *arrow.FixedSizeBinaryType: 192 rt := reflect.ArrayOf(dt.ByteWidth, reflect.TypeOf(byte(0))) 193 return rtree.WriteVar{ 194 Name: field.Name, 195 Value: reflect.New(rt).Interface(), 196 } 197 198 case *arrow.ListType: 199 wv := ctx.writeVarFrom(arrow.Field{Type: dt.Elem(), Name: "elem"}) 200 rt := reflect.SliceOf(reflect.TypeOf(wv.Value).Elem()) 201 nn := "rarrow_n_" + field.Name 202 ctx.count[field.Name] = rtree.WriteVar{ 203 Name: nn, 204 Value: new(int32), 205 } 206 return rtree.WriteVar{ 207 Name: field.Name, 208 Value: reflect.New(rt).Interface(), 209 Count: nn, 210 } 211 212 // case *arrow.StructType: 213 // fields := make([]reflect.StructField, len(dt.Fields())) 214 // for i, ft := range dt.Fields() { 215 // wv := writeVarFrom(ft) 216 // fields[i] = reflect.StructField{ 217 // Name: "ROOT_" + ft.Name, 218 // Type: reflect.TypeOf(wv.Value).Elem(), 219 // Tag: reflect.StructTag(fmt.Sprintf("groot:%q", ft.Name)), 220 // } 221 // } 222 // rt := reflect.StructOf(fields) 223 // return rtree.WriteVar{ 224 // Name: field.Name, 225 // Value: reflect.New(rt).Interface(), 226 // } 227 228 default: 229 panic(fmt.Errorf("invalid ARROW data-type: %T", dt)) 230 } 231 } 232 233 func (ctx *contextWriter) readFrom(wvar *rtree.WriteVar, irow int, arr array.Interface) error { 234 ptr := wvar.Value 235 switch arr := arr.(type) { 236 case *array.Boolean: 237 *ptr.(*bool) = arr.Value(irow) 238 case *array.Int8: 239 *ptr.(*int8) = arr.Value(irow) 240 case *array.Int16: 241 *ptr.(*int16) = arr.Value(irow) 242 case *array.Int32: 243 *ptr.(*int32) = arr.Value(irow) 244 case *array.Int64: 245 *ptr.(*int64) = arr.Value(irow) 246 case *array.Uint8: 247 *ptr.(*uint8) = arr.Value(irow) 248 case *array.Uint16: 249 *ptr.(*uint16) = arr.Value(irow) 250 case *array.Uint32: 251 *ptr.(*uint32) = arr.Value(irow) 252 case *array.Uint64: 253 *ptr.(*uint64) = arr.Value(irow) 254 case *array.Float32: 255 *ptr.(*float32) = arr.Value(irow) 256 case *array.Float64: 257 *ptr.(*float64) = arr.Value(irow) 258 case *array.String: 259 *ptr.(*string) = arr.Value(irow) 260 case *array.Binary: 261 *ptr.(*string) = string(arr.Value(irow)) 262 263 case *array.FixedSizeList: 264 rv := reflect.ValueOf(ptr).Elem() 265 n := int64(rv.Len()) 266 off := int64(arr.Offset()) 267 beg := (off + int64(irow)) * n 268 end := (off + int64(irow+1)) * n 269 ra := array.NewSlice(arr.ListValues(), beg, end) 270 defer ra.Release() 271 ptr := &rtree.WriteVar{ 272 Name: "_rarrow_elem_" + wvar.Name, 273 } 274 for i := range rv.Len() { 275 ptr.Value = rv.Index(i).Addr().Interface() 276 err := ctx.readFrom(ptr, i, ra) 277 if err != nil { 278 return err 279 } 280 } 281 282 case *array.FixedSizeBinary: 283 rv := reflect.ValueOf(ptr).Elem() 284 sli := rv.Slice(0, rv.Len()).Interface().([]byte) 285 copy(sli, arr.Value(irow)) 286 287 case *array.List: 288 rv := reflect.ValueOf(ptr).Elem() 289 rc := reflect.ValueOf(ctx.count[wvar.Name].Value).Elem() 290 if !arr.IsValid(irow) { 291 rc.SetInt(0) 292 rv.SetLen(0) 293 return nil 294 } 295 296 j := irow + arr.Data().Offset() 297 beg := int64(arr.Offsets()[j]) 298 end := int64(arr.Offsets()[j+1]) 299 sli := array.NewSlice(arr.ListValues(), beg, end) 300 defer sli.Release() 301 302 sz := sli.Len() 303 rc.SetInt(int64(sz)) 304 305 if src, dst := sz, rv.Len(); src > dst { 306 rv.Set(reflect.MakeSlice(rv.Type(), src, src)) 307 } 308 rv.SetLen(sz) 309 310 ptr := &rtree.WriteVar{ 311 Name: "_rarrow_elem_" + wvar.Name, 312 } 313 for i := range sli.Len() { 314 ptr.Value = rv.Index(i).Addr().Interface() 315 err := ctx.readFrom(ptr, i, sli) 316 if err != nil { 317 return err 318 } 319 } 320 321 default: 322 panic(fmt.Errorf("invalid array type %T", arr)) 323 } 324 return nil 325 } 326 327 var ( 328 _ arrio.Writer = (*flatTreeWriter)(nil) 329 )