github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/column_buffer_go18.go (about)

     1  //go:build go1.18
     2  
     3  package parquet
     4  
     5  import (
     6  	"encoding/json"
     7  	"math/bits"
     8  	"reflect"
     9  	"time"
    10  	"unsafe"
    11  
    12  	"github.com/segmentio/parquet-go/deprecated"
    13  	"github.com/segmentio/parquet-go/internal/unsafecast"
    14  	"github.com/segmentio/parquet-go/sparse"
    15  )
    16  
// writeRowsFunc is the type of functions that apply rows to a set of column
// buffers.
//
//   - columns is the array of column buffers where the rows are written.
//
//   - rows is the array of Go values to write to the column buffers.
//
//   - levels carries the repetition depth, repetition level and definition
//     level of the values, which are adjusted along the way when writing
//     optional or repeated columns.
//
// The returned error is non-nil when writing the rows failed.
type writeRowsFunc func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error
    27  
    28  // writeRowsFuncOf generates a writeRowsFunc function for the given Go type and
    29  // parquet schema. The column path indicates the column that the function is
    30  // being generated for in the parquet schema.
    31  func writeRowsFuncOf(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
    32  	if leaf, exists := schema.Lookup(path...); exists && leaf.Node.Type().LogicalType() != nil && leaf.Node.Type().LogicalType().Json != nil {
    33  		return writeRowsFuncOfJSON(t, schema, path)
    34  	}
    35  
    36  	switch t {
    37  	case reflect.TypeOf(deprecated.Int96{}):
    38  		return writeRowsFuncOfRequired(t, schema, path)
    39  	case reflect.TypeOf(time.Time{}):
    40  		return writeRowsFuncOfTime(t, schema, path)
    41  	}
    42  
    43  	switch t.Kind() {
    44  	case reflect.Bool,
    45  		reflect.Int,
    46  		reflect.Uint,
    47  		reflect.Int32,
    48  		reflect.Uint32,
    49  		reflect.Int64,
    50  		reflect.Uint64,
    51  		reflect.Float32,
    52  		reflect.Float64,
    53  		reflect.String:
    54  		return writeRowsFuncOfRequired(t, schema, path)
    55  
    56  	case reflect.Slice:
    57  		if t.Elem().Kind() == reflect.Uint8 {
    58  			return writeRowsFuncOfRequired(t, schema, path)
    59  		} else {
    60  			return writeRowsFuncOfSlice(t, schema, path)
    61  		}
    62  
    63  	case reflect.Array:
    64  		if t.Elem().Kind() == reflect.Uint8 {
    65  			return writeRowsFuncOfRequired(t, schema, path)
    66  		}
    67  
    68  	case reflect.Pointer:
    69  		return writeRowsFuncOfPointer(t, schema, path)
    70  
    71  	case reflect.Struct:
    72  		return writeRowsFuncOfStruct(t, schema, path)
    73  
    74  	case reflect.Map:
    75  		return writeRowsFuncOfMap(t, schema, path)
    76  	}
    77  
    78  	panic("cannot convert Go values of type " + typeNameOf(t) + " to parquet value")
    79  }
    80  
    81  func writeRowsFuncOfRequired(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
    82  	column := schema.mapping.lookup(path)
    83  	columnIndex := column.columnIndex
    84  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
    85  		columns[columnIndex].writeValues(rows, levels)
    86  		return nil
    87  	}
    88  }
    89  
    90  func writeRowsFuncOfOptional(t reflect.Type, schema *Schema, path columnPath, writeRows writeRowsFunc) writeRowsFunc {
    91  	nullIndex := nullIndexFuncOf(t)
    92  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
    93  		if rows.Len() == 0 {
    94  			return writeRows(columns, rows, levels)
    95  		}
    96  
    97  		nulls := acquireBitmap(rows.Len())
    98  		defer releaseBitmap(nulls)
    99  		nullIndex(nulls.bits, rows)
   100  
   101  		nullLevels := levels
   102  		levels.definitionLevel++
   103  		// In this function, we are dealing with optional values which are
   104  		// neither pointers nor slices; for example, a int32 field marked
   105  		// "optional" in its parent struct.
   106  		//
   107  		// We need to find zero values, which should be represented as nulls
   108  		// in the parquet column. In order to minimize the calls to writeRows
   109  		// and maximize throughput, we use the nullIndex and nonNullIndex
   110  		// functions, which are type-specific implementations of the algorithm.
   111  		//
   112  		// Sections of the input that are contiguous nulls or non-nulls can be
   113  		// sent to a single call to writeRows to be written to the underlying
   114  		// buffer since they share the same definition level.
   115  		//
   116  		// This optimization is defeated by inputs alternating null and non-null
   117  		// sequences of single values, we do not expect this condition to be a
   118  		// common case.
   119  		for i := 0; i < rows.Len(); {
   120  			j := 0
   121  			x := i / 64
   122  			y := i % 64
   123  
   124  			if y != 0 {
   125  				if b := nulls.bits[x] >> uint(y); b == 0 {
   126  					x++
   127  					y = 0
   128  				} else {
   129  					y += bits.TrailingZeros64(b)
   130  					goto writeNulls
   131  				}
   132  			}
   133  
   134  			for x < len(nulls.bits) && nulls.bits[x] == 0 {
   135  				x++
   136  			}
   137  
   138  			if x < len(nulls.bits) {
   139  				y = bits.TrailingZeros64(nulls.bits[x]) % 64
   140  			}
   141  
   142  		writeNulls:
   143  			if j = x*64 + y; j > rows.Len() {
   144  				j = rows.Len()
   145  			}
   146  
   147  			if i < j {
   148  				if err := writeRows(columns, rows.Slice(i, j), nullLevels); err != nil {
   149  					return err
   150  				}
   151  				i = j
   152  			}
   153  
   154  			if y != 0 {
   155  				if b := nulls.bits[x] >> uint(y); b == (1<<uint64(y))-1 {
   156  					x++
   157  					y = 0
   158  				} else {
   159  					y += bits.TrailingZeros64(^b)
   160  					goto writeNonNulls
   161  				}
   162  			}
   163  
   164  			for x < len(nulls.bits) && nulls.bits[x] == ^uint64(0) {
   165  				x++
   166  			}
   167  
   168  			if x < len(nulls.bits) {
   169  				y = bits.TrailingZeros64(^nulls.bits[x]) % 64
   170  			}
   171  
   172  		writeNonNulls:
   173  			if j = x*64 + y; j > rows.Len() {
   174  				j = rows.Len()
   175  			}
   176  
   177  			if i < j {
   178  				if err := writeRows(columns, rows.Slice(i, j), levels); err != nil {
   179  					return err
   180  				}
   181  				i = j
   182  			}
   183  		}
   184  
   185  		return nil
   186  	}
   187  }
   188  
   189  func writeRowsFuncOfPointer(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   190  	elemType := t.Elem()
   191  	elemSize := uintptr(elemType.Size())
   192  	writeRows := writeRowsFuncOf(elemType, schema, path)
   193  
   194  	if len(path) == 0 {
   195  		// This code path is taken when generating a writeRowsFunc for a pointer
   196  		// type. In this case, we do not need to increase the definition level
   197  		// since we are not deailng with an optional field but a pointer to the
   198  		// row type.
   199  		return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   200  			if rows.Len() == 0 {
   201  				return writeRows(columns, rows, levels)
   202  			}
   203  
   204  			for i := 0; i < rows.Len(); i++ {
   205  				p := *(*unsafe.Pointer)(rows.Index(i))
   206  				a := sparse.Array{}
   207  				if p != nil {
   208  					a = makeArray(p, 1, elemSize)
   209  				}
   210  				if err := writeRows(columns, a, levels); err != nil {
   211  					return err
   212  				}
   213  			}
   214  
   215  			return nil
   216  		}
   217  	}
   218  
   219  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   220  		if rows.Len() == 0 {
   221  			return writeRows(columns, rows, levels)
   222  		}
   223  
   224  		for i := 0; i < rows.Len(); i++ {
   225  			p := *(*unsafe.Pointer)(rows.Index(i))
   226  			a := sparse.Array{}
   227  			elemLevels := levels
   228  			if p != nil {
   229  				a = makeArray(p, 1, elemSize)
   230  				elemLevels.definitionLevel++
   231  			}
   232  			if err := writeRows(columns, a, elemLevels); err != nil {
   233  				return err
   234  			}
   235  		}
   236  
   237  		return nil
   238  	}
   239  }
   240  
   241  func writeRowsFuncOfSlice(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   242  	elemType := t.Elem()
   243  	elemSize := uintptr(elemType.Size())
   244  	writeRows := writeRowsFuncOf(elemType, schema, path)
   245  
   246  	// When the element is a pointer type, the writeRows function will be an
   247  	// instance returned by writeRowsFuncOfPointer, which handles incrementing
   248  	// the definition level if the pointer value is not nil.
   249  	definitionLevelIncrement := byte(0)
   250  	if elemType.Kind() != reflect.Ptr {
   251  		definitionLevelIncrement = 1
   252  	}
   253  
   254  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   255  		if rows.Len() == 0 {
   256  			return writeRows(columns, rows, levels)
   257  		}
   258  
   259  		levels.repetitionDepth++
   260  
   261  		for i := 0; i < rows.Len(); i++ {
   262  			p := (*sliceHeader)(rows.Index(i))
   263  			a := makeArray(p.base, p.len, elemSize)
   264  			b := sparse.Array{}
   265  
   266  			elemLevels := levels
   267  			if a.Len() > 0 {
   268  				b = a.Slice(0, 1)
   269  				elemLevels.definitionLevel += definitionLevelIncrement
   270  			}
   271  
   272  			if err := writeRows(columns, b, elemLevels); err != nil {
   273  				return err
   274  			}
   275  
   276  			if a.Len() > 1 {
   277  				elemLevels.repetitionLevel = elemLevels.repetitionDepth
   278  
   279  				if err := writeRows(columns, a.Slice(1, a.Len()), elemLevels); err != nil {
   280  					return err
   281  				}
   282  			}
   283  		}
   284  
   285  		return nil
   286  	}
   287  }
   288  
   289  func writeRowsFuncOfStruct(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   290  	type column struct {
   291  		offset    uintptr
   292  		writeRows writeRowsFunc
   293  	}
   294  
   295  	fields := structFieldsOf(t)
   296  	columns := make([]column, len(fields))
   297  
   298  	for i, f := range fields {
   299  		optional := false
   300  		columnPath := path.append(f.Name)
   301  		forEachStructTagOption(f, func(_ reflect.Type, option, _ string) {
   302  			switch option {
   303  			case "list":
   304  				columnPath = columnPath.append("list", "element")
   305  			case "optional":
   306  				optional = true
   307  			}
   308  		})
   309  
   310  		writeRows := writeRowsFuncOf(f.Type, schema, columnPath)
   311  		if optional {
   312  			switch f.Type.Kind() {
   313  			case reflect.Pointer, reflect.Slice:
   314  			default:
   315  				writeRows = writeRowsFuncOfOptional(f.Type, schema, columnPath, writeRows)
   316  			}
   317  		}
   318  
   319  		columns[i] = column{
   320  			offset:    f.Offset,
   321  			writeRows: writeRows,
   322  		}
   323  	}
   324  
   325  	return func(buffers []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   326  		if rows.Len() == 0 {
   327  			for _, column := range columns {
   328  				if err := column.writeRows(buffers, rows, levels); err != nil {
   329  					return err
   330  				}
   331  			}
   332  		} else {
   333  			for _, column := range columns {
   334  				if err := column.writeRows(buffers, rows.Offset(column.offset), levels); err != nil {
   335  					return err
   336  				}
   337  			}
   338  		}
   339  		return nil
   340  	}
   341  }
   342  
   343  func writeRowsFuncOfMap(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   344  	keyPath := path.append("key_value", "key")
   345  	keyType := t.Key()
   346  	keySize := uintptr(keyType.Size())
   347  	writeKeys := writeRowsFuncOf(keyType, schema, keyPath)
   348  
   349  	valuePath := path.append("key_value", "value")
   350  	valueType := t.Elem()
   351  	valueSize := uintptr(valueType.Size())
   352  	writeValues := writeRowsFuncOf(valueType, schema, valuePath)
   353  
   354  	writeKeyValues := func(columns []ColumnBuffer, keys, values sparse.Array, levels columnLevels) error {
   355  		if err := writeKeys(columns, keys, levels); err != nil {
   356  			return err
   357  		}
   358  		if err := writeValues(columns, values, levels); err != nil {
   359  			return err
   360  		}
   361  		return nil
   362  	}
   363  
   364  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   365  		if rows.Len() == 0 {
   366  			return writeKeyValues(columns, rows, rows, levels)
   367  		}
   368  
   369  		levels.repetitionDepth++
   370  		mapKey := reflect.New(keyType).Elem()
   371  		mapValue := reflect.New(valueType).Elem()
   372  
   373  		for i := 0; i < rows.Len(); i++ {
   374  			m := reflect.NewAt(t, rows.Index(i)).Elem()
   375  
   376  			if m.Len() == 0 {
   377  				empty := sparse.Array{}
   378  				if err := writeKeyValues(columns, empty, empty, levels); err != nil {
   379  					return err
   380  				}
   381  			} else {
   382  				elemLevels := levels
   383  				elemLevels.definitionLevel++
   384  
   385  				for it := m.MapRange(); it.Next(); {
   386  					mapKey.SetIterKey(it)
   387  					mapValue.SetIterValue(it)
   388  
   389  					k := makeArray(unsafecast.PointerOfValue(mapKey), 1, keySize)
   390  					v := makeArray(unsafecast.PointerOfValue(mapValue), 1, valueSize)
   391  
   392  					if err := writeKeyValues(columns, k, v, elemLevels); err != nil {
   393  						return err
   394  					}
   395  
   396  					elemLevels.repetitionLevel = elemLevels.repetitionDepth
   397  				}
   398  			}
   399  		}
   400  
   401  		return nil
   402  	}
   403  }
   404  
   405  func writeRowsFuncOfJSON(t reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   406  	// If this is a string or a byte array write directly.
   407  	switch t.Kind() {
   408  	case reflect.String:
   409  		return writeRowsFuncOfRequired(t, schema, path)
   410  	case reflect.Slice:
   411  		if t.Elem().Kind() == reflect.Uint8 {
   412  			return writeRowsFuncOfRequired(t, schema, path)
   413  		}
   414  	}
   415  
   416  	// Otherwise handle with a json.Marshal
   417  	asStrT := reflect.TypeOf(string(""))
   418  	writer := writeRowsFuncOfRequired(asStrT, schema, path)
   419  
   420  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   421  		if rows.Len() == 0 {
   422  			return writer(columns, rows, levels)
   423  		}
   424  		for i := 0; i < rows.Len(); i++ {
   425  			val := reflect.NewAt(t, rows.Index(i))
   426  			asI := val.Interface()
   427  
   428  			b, err := json.Marshal(asI)
   429  			if err != nil {
   430  				return err
   431  			}
   432  
   433  			asStr := string(b)
   434  			a := sparse.MakeStringArray([]string{asStr})
   435  			if err := writer(columns, a.UnsafeArray(), levels); err != nil {
   436  				return err
   437  			}
   438  		}
   439  		return nil
   440  	}
   441  }
   442  
   443  func writeRowsFuncOfTime(_ reflect.Type, schema *Schema, path columnPath) writeRowsFunc {
   444  	t := reflect.TypeOf(int64(0))
   445  	elemSize := uintptr(t.Size())
   446  	writeRows := writeRowsFuncOf(t, schema, path)
   447  
   448  	col, _ := schema.Lookup(path...)
   449  	unit := Nanosecond.TimeUnit()
   450  	lt := col.Node.Type().LogicalType()
   451  	if lt != nil && lt.Timestamp != nil {
   452  		unit = lt.Timestamp.Unit
   453  	}
   454  
   455  	return func(columns []ColumnBuffer, rows sparse.Array, levels columnLevels) error {
   456  		if rows.Len() == 0 {
   457  			return writeRows(columns, rows, levels)
   458  		}
   459  
   460  		times := rows.TimeArray()
   461  		for i := 0; i < times.Len(); i++ {
   462  			t := times.Index(i)
   463  			var val int64
   464  			switch {
   465  			case unit.Millis != nil:
   466  				val = t.UnixMilli()
   467  			case unit.Micros != nil:
   468  				val = t.UnixMicro()
   469  			default:
   470  				val = t.UnixNano()
   471  			}
   472  
   473  			a := makeArray(unsafecast.PointerOfValue(reflect.ValueOf(val)), 1, elemSize)
   474  			if err := writeRows(columns, a, levels); err != nil {
   475  				return err
   476  			}
   477  		}
   478  
   479  		return nil
   480  	}
   481  }