github.com/fraugster/parquet-go@v0.12.0/floor/writer.go (about)

     1  package floor
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"reflect"
     9  	"time"
    10  
    11  	"github.com/araddon/dateparse"
    12  	goparquet "github.com/fraugster/parquet-go"
    13  	"github.com/fraugster/parquet-go/floor/interfaces"
    14  	"github.com/fraugster/parquet-go/parquet"
    15  	"github.com/fraugster/parquet-go/parquetschema"
    16  )
    17  
    18  // NewWriter creates a new high-level writer for parquet.
    19  // NOTE: We assume the schema definition is constant.
    20  func NewWriter(w *goparquet.FileWriter) *Writer {
    21  	return &Writer{
    22  		w:         w,
    23  		schemaDef: w.GetSchemaDefinition(),
    24  	}
    25  }
    26  
    27  // NewFileWriter creates a nigh high-level writer for parquet
    28  // that writes to a particular file.
    29  // NOTE: We assume the schema definition is constant.
    30  func NewFileWriter(file string, opts ...goparquet.FileWriterOption) (*Writer, error) {
    31  	f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	w := goparquet.NewFileWriter(f, opts...)
    37  	return &Writer{
    38  		w:         w,
    39  		f:         f,
    40  		schemaDef: w.GetSchemaDefinition(),
    41  	}, nil
    42  }
    43  
// Writer represents a high-level writer for parquet files. It marshals Go
// values into records and hands them to an underlying low-level parquet
// file writer.
type Writer struct {
	w         *goparquet.FileWriter           // low-level parquet writer that receives the marshalled records
	f         io.Closer                       // file opened by NewFileWriter, closed in Close; nil when created via NewWriter
	schemaDef *parquetschema.SchemaDefinition // schema definition read from w at construction time
}
    50  
    51  // Write adds a new object to be written to the parquet file. If
    52  // obj implements the floor.Marshaller object, then obj.(Marshaller).Marshal
    53  // will be called to determine the data, otherwise reflection will be used.
    54  func (w *Writer) Write(obj interface{}) error {
    55  	m, ok := obj.(interfaces.Marshaller)
    56  	if !ok {
    57  		m = &reflectMarshaller{obj: obj, schemaDef: w.schemaDef}
    58  	}
    59  
    60  	data := interfaces.NewMarshallObjectWithSchema(nil, w.schemaDef)
    61  	if err := m.MarshalParquet(data); err != nil {
    62  		return err
    63  	}
    64  
    65  	if err := w.w.AddData(data.GetData()); err != nil {
    66  		return err
    67  	}
    68  
    69  	return nil
    70  }
    71  
// reflectMarshaller is the marshaller Writer.Write falls back to when the
// object to write does not implement interfaces.Marshaller. It walks obj
// using reflection.
type reflectMarshaller struct {
	obj       interface{}                     // value to marshal; marshal accepts a struct, a pointer to a struct, or a map
	schemaDef *parquetschema.SchemaDefinition // schema used to look up the sub-schema of each field
}
    76  
    77  func (m *reflectMarshaller) MarshalParquet(record interfaces.MarshalObject) error {
    78  	return m.marshal(record, reflect.ValueOf(m.obj), m.schemaDef)
    79  }
    80  
    81  func (m *reflectMarshaller) marshal(record interfaces.MarshalObject, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
    82  	if value.Type().Kind() == reflect.Ptr {
    83  		if value.IsNil() {
    84  			return errors.New("object is nil")
    85  		}
    86  		value = value.Elem()
    87  	}
    88  
    89  	typ := value.Type()
    90  
    91  	if typ.Kind() == reflect.Struct {
    92  		return m.decodeStruct(record, value, schemaDef)
    93  	}
    94  
    95  	if typ.Kind() != reflect.Map {
    96  		return fmt.Errorf("object needs to be a struct, *struct or map, it's a %v instead", typ)
    97  	}
    98  
    99  	iter := value.MapRange()
   100  	for iter.Next() {
   101  		fieldName := iter.Key().String()
   102  		subSchemaDef := schemaDef.SubSchema(fieldName)
   103  		field := record.AddField(fieldName)
   104  
   105  		err := m.decodeValue(field, iter.Value(), subSchemaDef)
   106  		if err != nil {
   107  			return err
   108  		}
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  func (m *reflectMarshaller) decodeStruct(record interfaces.MarshalObject, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
   115  	if value.Type().Kind() == reflect.Ptr {
   116  		if value.IsNil() {
   117  			return errors.New("object is nil")
   118  		}
   119  		value = value.Elem()
   120  	}
   121  
   122  	typ := value.Type()
   123  
   124  	if typ.Kind() != reflect.Struct {
   125  		return fmt.Errorf("object needs to be a struct or a *struct, it's a %v instead", typ)
   126  	}
   127  
   128  	numFields := typ.NumField()
   129  	for i := 0; i < numFields; i++ {
   130  		fieldValue := value.Field(i)
   131  
   132  		fieldName := fieldNameFunc(typ.Field(i))
   133  
   134  		subSchemaDef := schemaDef.SubSchema(fieldName)
   135  
   136  		field := record.AddField(fieldName)
   137  
   138  		err := m.decodeValue(field, fieldValue, subSchemaDef)
   139  		if err != nil {
   140  			return err
   141  		}
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (m *reflectMarshaller) decodeTimeValue(elem *parquet.SchemaElement, field interfaces.MarshalElement, value reflect.Value) error {
   148  	switch {
   149  	case elem.GetLogicalType().TIME.Unit.IsSetNANOS():
   150  		field.SetInt64(value.Interface().(Time).Nanoseconds())
   151  	case elem.GetLogicalType().TIME.Unit.IsSetMICROS():
   152  		field.SetInt64(value.Interface().(Time).Microseconds())
   153  	case elem.GetLogicalType().TIME.Unit.IsSetMILLIS():
   154  		field.SetInt32(value.Interface().(Time).Milliseconds())
   155  	default:
   156  		return errors.New("invalid TIME unit")
   157  	}
   158  	return nil
   159  }
   160  
   161  func (m *reflectMarshaller) decodeTimestampValue(elem *parquet.SchemaElement, field interfaces.MarshalElement, value reflect.Value) error {
   162  	var factor int64
   163  	switch {
   164  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetNANOS():
   165  		factor = 1
   166  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMICROS():
   167  		factor = 1000
   168  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMILLIS():
   169  		factor = 1000000
   170  	default:
   171  		return errors.New("invalid TIMESTAMP unit")
   172  	}
   173  	ts := value.Interface().(time.Time).UnixNano()
   174  	ts /= factor
   175  	field.SetInt64(ts)
   176  	return nil
   177  }
   178  
   179  func (m *reflectMarshaller) decodeValue(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
   180  	elem := schemaDef.SchemaElement()
   181  	if elem == nil {
   182  		return nil
   183  	}
   184  
   185  	if value.Kind() == reflect.Ptr || value.Kind() == reflect.Interface {
   186  		if value.IsNil() {
   187  			return nil
   188  		}
   189  		value = value.Elem()
   190  	}
   191  
   192  	if value.Type().ConvertibleTo(reflect.TypeOf(Time{})) {
   193  		if elem.LogicalType != nil && elem.GetLogicalType().IsSetTIME() {
   194  			return m.decodeTimeValue(elem, field, value)
   195  		}
   196  	}
   197  
   198  	if value.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
   199  		if elem.LogicalType != nil {
   200  			switch {
   201  			case elem.GetLogicalType().IsSetDATE():
   202  				days := int32(value.Interface().(time.Time).Sub(time.Unix(0, 0).UTC()).Hours() / 24)
   203  				field.SetInt32(days)
   204  				return nil
   205  			case elem.GetLogicalType().IsSetTIMESTAMP():
   206  				return m.decodeTimestampValue(elem, field, value)
   207  			}
   208  		} else if elem.GetType() == parquet.Type_INT96 {
   209  			field.SetInt96(goparquet.TimeToInt96(value.Interface().(time.Time)))
   210  			return nil
   211  		}
   212  	}
   213  
   214  	if !elem.IsSetType() && !elem.IsSetConvertedType() && elem.GetNumChildren() > 0 && value.Kind() == reflect.Map {
   215  		group := field.Group()
   216  		iter := value.MapRange()
   217  		for iter.Next() {
   218  			fieldName := iter.Key().String()
   219  			err := m.decodeValue(group.AddField(fieldName), iter.Value(), schemaDef.SubSchema(fieldName))
   220  			if err != nil {
   221  				return err
   222  			}
   223  		}
   224  
   225  		return nil
   226  	}
   227  
   228  	switch elem.GetType() {
   229  	case parquet.Type_INT64:
   230  		switch value.Kind() {
   231  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   232  			field.SetInt64(value.Int())
   233  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   234  			field.SetInt64(int64(value.Uint()))
   235  		default:
   236  			return fmt.Errorf("unable to decode %s:%s to int64", elem.Name, value.Kind())
   237  		}
   238  		return nil
   239  	case parquet.Type_INT32:
   240  		switch value.Kind() {
   241  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   242  			field.SetInt32(int32(value.Int()))
   243  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   244  			field.SetInt32(int32(value.Uint()))
   245  		default:
   246  			return fmt.Errorf("unable to decode %s:%s to int32", elem.Name, value.Kind())
   247  		}
   248  		return nil
   249  	case parquet.Type_INT96:
   250  		switch value.Kind() {
   251  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   252  			return m.decodeUnixTime(field, value.Int())
   253  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   254  			return m.decodeUnixTime(field, int64(value.Uint()))
   255  		case reflect.String:
   256  			dt, _ := dateparse.ParseAny(value.String())
   257  			field.SetInt96(goparquet.TimeToInt96(dt))
   258  			return nil
   259  		case reflect.Slice:
   260  			if value.IsNil() {
   261  				return nil
   262  			}
   263  			if value.Type().Elem().Kind() != reflect.Uint8 {
   264  				return fmt.Errorf("field is of type INT96 but type is %s", value.Type().String())
   265  			}
   266  
   267  			if value.Len() != 12 {
   268  				return fmt.Errorf("field is of type INT96 but length is %d", value.Len())
   269  			}
   270  			var dst [12]byte
   271  			src := value.Interface().([]byte)
   272  			copy(dst[:], src)
   273  			field.SetInt96(dst)
   274  			return nil
   275  		case reflect.Array:
   276  			if value.Type().Elem().Kind() != reflect.Uint8 {
   277  				return fmt.Errorf("field is of type INT96 but type is %s", value.Type().String())
   278  			}
   279  			if value.Len() != 12 {
   280  				return fmt.Errorf("field is of type INT96 but length is %d", value.Len())
   281  			}
   282  			var dst [12]byte
   283  			src := value.Interface().([12]byte)
   284  			copy(dst[:], src[:])
   285  			field.SetInt96(dst)
   286  			return nil
   287  		}
   288  	}
   289  
   290  	switch value.Kind() {
   291  	case reflect.Bool:
   292  		field.SetBool(value.Bool())
   293  		return nil
   294  	case reflect.Float32:
   295  		field.SetFloat32(float32(value.Float()))
   296  		return nil
   297  	case reflect.Float64:
   298  		field.SetFloat64(value.Float())
   299  		return nil
   300  	case reflect.Array, reflect.Slice:
   301  		if value.Type().Elem().Kind() == reflect.Uint8 {
   302  			return m.decodeByteSliceOrArray(field, value, schemaDef)
   303  		}
   304  		return m.decodeSliceOrArray(field, value, schemaDef)
   305  	case reflect.Map:
   306  		return m.decodeMap(field, value, schemaDef)
   307  	case reflect.String:
   308  		field.SetByteArray([]byte(value.String()))
   309  		return nil
   310  	case reflect.Struct:
   311  		return m.decodeStruct(field.Group(), value, schemaDef)
   312  	default:
   313  		return fmt.Errorf("unsupported type %s", value.Type())
   314  	}
   315  }
   316  
   317  func (m *reflectMarshaller) decodeUnixTime(field interfaces.MarshalElement, i64 int64) error {
   318  	// best effort parse unix timestamps.
   319  	// since 99% of the time these are timestamps and are <= now this is a fairly safe bet
   320  	digits := i64Digits(i64)
   321  	now := time.Now()
   322  
   323  	switch {
   324  	case digits <= i64Digits(now.Unix()):
   325  		dt := time.Unix(i64, 0)
   326  		field.SetInt96(goparquet.TimeToInt96(dt))
   327  	case digits <= i64Digits(now.UnixNano()/1000000):
   328  		dt := time.Unix(0, i64*int64(time.Millisecond))
   329  		field.SetInt96(goparquet.TimeToInt96(dt))
   330  	case digits <= i64Digits(now.UnixNano()/1000):
   331  		dt := time.Unix(0, i64*int64(time.Microsecond))
   332  		field.SetInt96(goparquet.TimeToInt96(dt))
   333  	case digits <= i64Digits(now.UnixNano()):
   334  		dt := time.Unix(0, i64)
   335  		field.SetInt96(goparquet.TimeToInt96(dt))
   336  	default:
   337  		return fmt.Errorf("field is of type INT96 but value is not valid %d", i64)
   338  	}
   339  	return nil
   340  }
   341  
   342  func (m *reflectMarshaller) decodeByteSliceOrArray(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
   343  	elem := schemaDef.SchemaElement()
   344  	if elem == nil {
   345  		return nil
   346  	}
   347  
   348  	if value.Kind() == reflect.Slice && value.IsNil() {
   349  		return nil
   350  	}
   351  
   352  	if elem.LogicalType != nil && elem.GetLogicalType().IsSetUUID() {
   353  		if value.Len() != 16 {
   354  			return fmt.Errorf("field is annotated as UUID but length is %d", value.Len())
   355  		}
   356  	}
   357  
   358  	switch value.Kind() {
   359  	case reflect.Slice:
   360  		if value.IsNil() {
   361  			return nil
   362  		}
   363  		field.SetByteArray(value.Bytes())
   364  	case reflect.Array:
   365  		data := reflect.MakeSlice(reflect.TypeOf([]byte{}), value.Len(), value.Len())
   366  		_ = reflect.Copy(data, value)
   367  		field.SetByteArray(data.Bytes())
   368  	}
   369  	return nil
   370  }
   371  
   372  func (m *reflectMarshaller) decodeSliceOrArray(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
   373  	elem := schemaDef.SchemaElement()
   374  	if elem == nil {
   375  		return nil
   376  	}
   377  
   378  	if value.Kind() == reflect.Slice && value.IsNil() {
   379  		return nil
   380  	}
   381  
   382  	if elem.GetConvertedType() != parquet.ConvertedType_LIST {
   383  		return fmt.Errorf("decoding slice or array but schema element %s is not annotated as LIST", elem.GetName())
   384  	}
   385  
   386  	elementSchemaDef := schemaDef.SubSchema("list").SubSchema("element")
   387  	if elementSchemaDef == nil {
   388  		elementSchemaDef = schemaDef.SubSchema("bag").SubSchema("array_element")
   389  		if elementSchemaDef == nil {
   390  			return fmt.Errorf("element %s is annotated as LIST but group structure seems invalid", schemaDef.SchemaElement().GetName())
   391  		}
   392  	}
   393  
   394  	list := field.List()
   395  
   396  	for i := 0; i < value.Len(); i++ {
   397  		if err := m.decodeValue(list.Add(), value.Index(i), elementSchemaDef); err != nil {
   398  			return err
   399  		}
   400  	}
   401  
   402  	return nil
   403  }
   404  
   405  func (m *reflectMarshaller) decodeMap(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error {
   406  	if value.IsNil() {
   407  		return nil
   408  	}
   409  
   410  	if elem := schemaDef.SchemaElement(); elem.GetConvertedType() != parquet.ConvertedType_MAP {
   411  		return fmt.Errorf("decoding map but schema element %s is not annotated as MAP", elem.GetName())
   412  	}
   413  
   414  	keyValueSchemaDef := schemaDef.SubSchema("key_value")
   415  	keySchemaDef := keyValueSchemaDef.SubSchema("key")
   416  	valueSchemaDef := keyValueSchemaDef.SubSchema("value")
   417  
   418  	mapData := field.Map()
   419  
   420  	iter := value.MapRange()
   421  
   422  	for iter.Next() {
   423  		kvPair := mapData.Add()
   424  
   425  		if err := m.decodeValue(kvPair.Key(), iter.Key(), keySchemaDef); err != nil {
   426  			return err
   427  		}
   428  
   429  		if err := m.decodeValue(kvPair.Value(), iter.Value(), valueSchemaDef); err != nil {
   430  			return err
   431  		}
   432  	}
   433  
   434  	return nil
   435  }
   436  
   437  // Close flushes outstanding data and closes the underlying
   438  // parquet writer.
   439  func (w *Writer) Close() error {
   440  	if w.f != nil {
   441  		defer w.f.Close()
   442  	}
   443  
   444  	return w.w.Close()
   445  }
   446  
   447  func i64Digits(number int64) int {
   448  	count := 0
   449  	for number != 0 {
   450  		number /= 10
   451  		count++
   452  	}
   453  	return count
   454  }