go-hep.org/x/hep@v0.38.1/groot/rarrow/tree_writer.go (about)

     1  // Copyright ©2019 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rarrow
     6  
import (
	"fmt"
	"reflect"
	"sort"

	"git.sr.ht/~sbinet/go-arrow"
	"git.sr.ht/~sbinet/go-arrow/array"
	"git.sr.ht/~sbinet/go-arrow/arrio"
	"go-hep.org/x/hep/groot/riofs"
	"go-hep.org/x/hep/groot/rtree"
)
    17  
    18  // flatTreeWriter writes ARROW data as a ROOT flat-tree.
    19  type flatTreeWriter struct {
    20  	w      rtree.Writer
    21  	schema *arrow.Schema
    22  	ctx    contextWriter
    23  }
    24  
    25  // NewFlatTreeWriter creates an arrio.Writer that writes ARROW data as a ROOT
    26  // flat-tree under the provided dir directory.
    27  func NewFlatTreeWriter(dir riofs.Directory, name string, schema *arrow.Schema, opts ...rtree.WriteOption) (*flatTreeWriter, error) {
    28  	var (
    29  		ctx   = newContextWriter(schema)
    30  		wvars = make([]rtree.WriteVar, 0, len(ctx.wvars)+len(ctx.count))
    31  	)
    32  
    33  	for _, wvar := range ctx.count {
    34  		wvars = append(wvars, wvar)
    35  	}
    36  	wvars = append(wvars, ctx.wvars...)
    37  
    38  	tree, err := rtree.NewWriter(dir, name, wvars, opts...)
    39  	if err != nil {
    40  		return nil, fmt.Errorf("rarrow: could not create flat-tree writer %q: %w", name, err)
    41  	}
    42  	return &flatTreeWriter{w: tree, schema: schema, ctx: ctx}, nil
    43  }
    44  
    45  // Close closes the underlying ROOT tree writer.
    46  func (fw *flatTreeWriter) Close() error {
    47  	return fw.w.Close()
    48  }
    49  
    50  // Write writes the provided ARROW record to the underlying ROOT flat-tree.
    51  // Write implements arrio.Writer.
    52  func (fw *flatTreeWriter) Write(rec array.Record) error {
    53  	if src := rec.Schema(); !fw.schema.Equal(src) {
    54  		return fmt.Errorf("rarrow: invalid input record schema:\n - got= %v\n - want=%v", src, fw.schema)
    55  	}
    56  
    57  	nrows := rec.Column(0).Len()
    58  	for icol, col := range rec.Columns() {
    59  		if col.Len() != nrows {
    60  			return fmt.Errorf(
    61  				"rarrow: column %q (index=%d) has not the same number of rows than others (got=%d, want=%d)",
    62  				rec.ColumnName(icol), icol, col.Len(), nrows,
    63  			)
    64  		}
    65  	}
    66  
    67  	for irow := range nrows {
    68  		for icol, col := range rec.Columns() {
    69  			wvar := &fw.ctx.wvars[icol]
    70  			err := fw.ctx.readFrom(wvar, irow, col)
    71  			if err != nil {
    72  				return fmt.Errorf(
    73  					"rarrow: could not read row=%d from column[%d](name=%s): %w",
    74  					irow, icol, rec.ColumnName(icol), err,
    75  				)
    76  			}
    77  		}
    78  		_, err := fw.w.Write()
    79  		if err != nil {
    80  			return fmt.Errorf("rarrow: could not write row=%d to tree: %w", irow, err)
    81  		}
    82  	}
    83  
    84  	return nil
    85  }
    86  
    87  type contextWriter struct {
    88  	wvars []rtree.WriteVar
    89  	count map[string]rtree.WriteVar
    90  }
    91  
    92  func newContextWriter(schema *arrow.Schema) contextWriter {
    93  	ctx := contextWriter{
    94  		wvars: make([]rtree.WriteVar, len(schema.Fields())),
    95  		count: make(map[string]rtree.WriteVar),
    96  	}
    97  	for i, field := range schema.Fields() {
    98  		ctx.wvars[i] = ctx.writeVarFrom(field)
    99  	}
   100  	return ctx
   101  }
   102  
   103  func (ctx *contextWriter) writeVarFrom(field arrow.Field) rtree.WriteVar {
   104  	switch dt := field.Type.(type) {
   105  	case *arrow.BooleanType:
   106  		return rtree.WriteVar{
   107  			Name:  field.Name,
   108  			Value: new(bool),
   109  		}
   110  
   111  	case *arrow.Int8Type:
   112  		return rtree.WriteVar{
   113  			Name:  field.Name,
   114  			Value: new(int8),
   115  		}
   116  
   117  	case *arrow.Int16Type:
   118  		return rtree.WriteVar{
   119  			Name:  field.Name,
   120  			Value: new(int16),
   121  		}
   122  
   123  	case *arrow.Int32Type:
   124  		return rtree.WriteVar{
   125  			Name:  field.Name,
   126  			Value: new(int32),
   127  		}
   128  
   129  	case *arrow.Int64Type:
   130  		return rtree.WriteVar{
   131  			Name:  field.Name,
   132  			Value: new(int64),
   133  		}
   134  
   135  	case *arrow.Uint8Type:
   136  		return rtree.WriteVar{
   137  			Name:  field.Name,
   138  			Value: new(uint8),
   139  		}
   140  
   141  	case *arrow.Uint16Type:
   142  		return rtree.WriteVar{
   143  			Name:  field.Name,
   144  			Value: new(uint16),
   145  		}
   146  
   147  	case *arrow.Uint32Type:
   148  		return rtree.WriteVar{
   149  			Name:  field.Name,
   150  			Value: new(uint32),
   151  		}
   152  
   153  	case *arrow.Uint64Type:
   154  		return rtree.WriteVar{
   155  			Name:  field.Name,
   156  			Value: new(uint64),
   157  		}
   158  
   159  	case *arrow.Float32Type:
   160  		return rtree.WriteVar{
   161  			Name:  field.Name,
   162  			Value: new(float32),
   163  		}
   164  
   165  	case *arrow.Float64Type:
   166  		return rtree.WriteVar{
   167  			Name:  field.Name,
   168  			Value: new(float64),
   169  		}
   170  
   171  	case *arrow.StringType:
   172  		return rtree.WriteVar{
   173  			Name:  field.Name,
   174  			Value: new(string),
   175  		}
   176  	case *arrow.BinaryType:
   177  		// FIXME(sbinet): differentiate the 2 (Binary/String) ?
   178  		return rtree.WriteVar{
   179  			Name:  field.Name,
   180  			Value: new(string),
   181  		}
   182  
   183  	case *arrow.FixedSizeListType:
   184  		wv := ctx.writeVarFrom(arrow.Field{Type: dt.Elem(), Name: "elem"})
   185  		rt := reflect.ArrayOf(int(dt.Len()), reflect.TypeOf(wv.Value).Elem())
   186  		return rtree.WriteVar{
   187  			Name:  field.Name,
   188  			Value: reflect.New(rt).Interface(),
   189  		}
   190  
   191  	case *arrow.FixedSizeBinaryType:
   192  		rt := reflect.ArrayOf(dt.ByteWidth, reflect.TypeOf(byte(0)))
   193  		return rtree.WriteVar{
   194  			Name:  field.Name,
   195  			Value: reflect.New(rt).Interface(),
   196  		}
   197  
   198  	case *arrow.ListType:
   199  		wv := ctx.writeVarFrom(arrow.Field{Type: dt.Elem(), Name: "elem"})
   200  		rt := reflect.SliceOf(reflect.TypeOf(wv.Value).Elem())
   201  		nn := "rarrow_n_" + field.Name
   202  		ctx.count[field.Name] = rtree.WriteVar{
   203  			Name:  nn,
   204  			Value: new(int32),
   205  		}
   206  		return rtree.WriteVar{
   207  			Name:  field.Name,
   208  			Value: reflect.New(rt).Interface(),
   209  			Count: nn,
   210  		}
   211  
   212  		//	case *arrow.StructType:
   213  		//		fields := make([]reflect.StructField, len(dt.Fields()))
   214  		//		for i, ft := range dt.Fields() {
   215  		//			wv := writeVarFrom(ft)
   216  		//			fields[i] = reflect.StructField{
   217  		//				Name: "ROOT_" + ft.Name,
   218  		//				Type: reflect.TypeOf(wv.Value).Elem(),
   219  		//				Tag:  reflect.StructTag(fmt.Sprintf("groot:%q", ft.Name)),
   220  		//			}
   221  		//		}
   222  		//		rt := reflect.StructOf(fields)
   223  		//		return rtree.WriteVar{
   224  		//			Name:  field.Name,
   225  		//			Value: reflect.New(rt).Interface(),
   226  		//		}
   227  
   228  	default:
   229  		panic(fmt.Errorf("invalid ARROW data-type: %T", dt))
   230  	}
   231  }
   232  
   233  func (ctx *contextWriter) readFrom(wvar *rtree.WriteVar, irow int, arr array.Interface) error {
   234  	ptr := wvar.Value
   235  	switch arr := arr.(type) {
   236  	case *array.Boolean:
   237  		*ptr.(*bool) = arr.Value(irow)
   238  	case *array.Int8:
   239  		*ptr.(*int8) = arr.Value(irow)
   240  	case *array.Int16:
   241  		*ptr.(*int16) = arr.Value(irow)
   242  	case *array.Int32:
   243  		*ptr.(*int32) = arr.Value(irow)
   244  	case *array.Int64:
   245  		*ptr.(*int64) = arr.Value(irow)
   246  	case *array.Uint8:
   247  		*ptr.(*uint8) = arr.Value(irow)
   248  	case *array.Uint16:
   249  		*ptr.(*uint16) = arr.Value(irow)
   250  	case *array.Uint32:
   251  		*ptr.(*uint32) = arr.Value(irow)
   252  	case *array.Uint64:
   253  		*ptr.(*uint64) = arr.Value(irow)
   254  	case *array.Float32:
   255  		*ptr.(*float32) = arr.Value(irow)
   256  	case *array.Float64:
   257  		*ptr.(*float64) = arr.Value(irow)
   258  	case *array.String:
   259  		*ptr.(*string) = arr.Value(irow)
   260  	case *array.Binary:
   261  		*ptr.(*string) = string(arr.Value(irow))
   262  
   263  	case *array.FixedSizeList:
   264  		rv := reflect.ValueOf(ptr).Elem()
   265  		n := int64(rv.Len())
   266  		off := int64(arr.Offset())
   267  		beg := (off + int64(irow)) * n
   268  		end := (off + int64(irow+1)) * n
   269  		ra := array.NewSlice(arr.ListValues(), beg, end)
   270  		defer ra.Release()
   271  		ptr := &rtree.WriteVar{
   272  			Name: "_rarrow_elem_" + wvar.Name,
   273  		}
   274  		for i := range rv.Len() {
   275  			ptr.Value = rv.Index(i).Addr().Interface()
   276  			err := ctx.readFrom(ptr, i, ra)
   277  			if err != nil {
   278  				return err
   279  			}
   280  		}
   281  
   282  	case *array.FixedSizeBinary:
   283  		rv := reflect.ValueOf(ptr).Elem()
   284  		sli := rv.Slice(0, rv.Len()).Interface().([]byte)
   285  		copy(sli, arr.Value(irow))
   286  
   287  	case *array.List:
   288  		rv := reflect.ValueOf(ptr).Elem()
   289  		rc := reflect.ValueOf(ctx.count[wvar.Name].Value).Elem()
   290  		if !arr.IsValid(irow) {
   291  			rc.SetInt(0)
   292  			rv.SetLen(0)
   293  			return nil
   294  		}
   295  
   296  		j := irow + arr.Data().Offset()
   297  		beg := int64(arr.Offsets()[j])
   298  		end := int64(arr.Offsets()[j+1])
   299  		sli := array.NewSlice(arr.ListValues(), beg, end)
   300  		defer sli.Release()
   301  
   302  		sz := sli.Len()
   303  		rc.SetInt(int64(sz))
   304  
   305  		if src, dst := sz, rv.Len(); src > dst {
   306  			rv.Set(reflect.MakeSlice(rv.Type(), src, src))
   307  		}
   308  		rv.SetLen(sz)
   309  
   310  		ptr := &rtree.WriteVar{
   311  			Name: "_rarrow_elem_" + wvar.Name,
   312  		}
   313  		for i := range sli.Len() {
   314  			ptr.Value = rv.Index(i).Addr().Interface()
   315  			err := ctx.readFrom(ptr, i, sli)
   316  			if err != nil {
   317  				return err
   318  			}
   319  		}
   320  
   321  	default:
   322  		panic(fmt.Errorf("invalid array type %T", arr))
   323  	}
   324  	return nil
   325  }
   326  
   327  var (
   328  	_ arrio.Writer = (*flatTreeWriter)(nil)
   329  )