github.com/fraugster/parquet-go@v0.12.0/floor/reader.go (about)

     1  package floor
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"reflect"
     9  	"time"
    10  
    11  	goparquet "github.com/fraugster/parquet-go"
    12  	"github.com/fraugster/parquet-go/floor/interfaces"
    13  	"github.com/fraugster/parquet-go/parquet"
    14  	"github.com/fraugster/parquet-go/parquetschema"
    15  )
    16  
    17  // NewReader returns a new high-level parquet file reader.
    18  func NewReader(r *goparquet.FileReader) *Reader {
    19  	return &Reader{
    20  		r: r,
    21  	}
    22  }
    23  
    24  // NewFileReader returns a new high-level parquet file reader
    25  // that directly reads from the provided file.
    26  func NewFileReader(file string) (*Reader, error) {
    27  	f, err := os.Open(file)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	r, err := goparquet.NewFileReader(f)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	return &Reader{
    38  		r: r,
    39  		f: f,
    40  	}, nil
    41  }
    42  
// Reader represents a high-level reader for parquet files.
// Rows are fetched with Next, decoded into Go values with Scan, and a
// failure that stopped Next is reported by Err.
type Reader struct {
	// r is the low-level parquet reader that rows are fetched from.
	r *goparquet.FileReader
	// f is the file owned by this Reader (set by NewFileReader);
	// it is nil when the Reader was created with NewReader.
	f io.Closer

	// data holds the row returned by the most recent NextRow call.
	data map[string]interface{}
	// err holds the error from the most recent Next call; it is reset
	// to nil when Next observes io.EOF.
	err error
	// eof is set once Next has observed io.EOF.
	eof bool
}
    52  
    53  // Close closes the reader.
    54  func (r *Reader) Close() error {
    55  	if r.f != nil {
    56  		return r.f.Close()
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  // Next reads the next object so that it is ready to be scanned.
    63  // Returns true if fetching the next object was successful, false
    64  // otherwise, e.g. in case of an error or when EOF was reached.
    65  func (r *Reader) Next() bool {
    66  	r.data, r.err = r.r.NextRow()
    67  	if r.err == io.EOF {
    68  		r.eof = true
    69  		r.err = nil
    70  		return false
    71  	}
    72  	if r.err != nil {
    73  		return false
    74  	}
    75  
    76  	return true
    77  }
    78  
    79  // Scan fills obj with the data from the record last fetched.
    80  // Returns an error if there is no data available or if the
    81  // structure of obj doesn't fit the data. obj needs to be
    82  // a pointer to an object, or alternatively implement the
    83  // Unmarshaller interface.
    84  func (r *Reader) Scan(obj interface{}) error {
    85  	if r.data == nil {
    86  		return errors.New("the Next function needs to be called before Scan can be called")
    87  	}
    88  	um, ok := obj.(interfaces.Unmarshaller)
    89  	if !ok {
    90  		um = &reflectUnmarshaller{obj: obj, schemaDef: r.r.GetSchemaDefinition()}
    91  	}
    92  
    93  	return um.UnmarshalParquet(interfaces.NewUnmarshallObject(r.data))
    94  }
    95  
// reflectUnmarshaller is the interfaces.Unmarshaller that Reader.Scan
// falls back to when the target object does not implement that
// interface itself. It fills obj using reflection.
type reflectUnmarshaller struct {
	// obj is the destination; UnmarshalParquet requires a pointer to a struct.
	obj interface{}
	// schemaDef is the schema definition of the file being read.
	schemaDef *parquetschema.SchemaDefinition
}
   100  
   101  func (um *reflectUnmarshaller) UnmarshalParquet(record interfaces.UnmarshalObject) error {
   102  	objValue := reflect.ValueOf(um.obj)
   103  
   104  	if objValue.Kind() != reflect.Ptr {
   105  		return fmt.Errorf("you need to provide an object of type *%T to unmarshal into", um.obj)
   106  	}
   107  
   108  	objValue = objValue.Elem()
   109  	if objValue.Kind() != reflect.Struct {
   110  		return fmt.Errorf("provided object of type %T is not a struct", um.obj)
   111  	}
   112  
   113  	if err := um.fillStruct(objValue, record, um.schemaDef); err != nil {
   114  		return err
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func (um *reflectUnmarshaller) fillStruct(value reflect.Value, record interfaces.UnmarshalObject, schemaDef *parquetschema.SchemaDefinition) error {
   121  	typ := value.Type()
   122  
   123  	numFields := typ.NumField()
   124  	for i := 0; i < numFields; i++ {
   125  		fieldValue := value.Field(i)
   126  
   127  		fieldName := fieldNameFunc(typ.Field(i))
   128  
   129  		fieldSchemaDef := schemaDef.SubSchema(fieldName)
   130  
   131  		if fieldSchemaDef == nil {
   132  			continue
   133  		}
   134  
   135  		fieldData := record.GetField(fieldName)
   136  		if fieldData.Error() != nil {
   137  			if elem := fieldSchemaDef.SchemaElement(); elem.GetRepetitionType() == parquet.FieldRepetitionType_REQUIRED {
   138  				return fmt.Errorf("field %s is %s but couldn't be found in data", fieldName, elem.GetRepetitionType())
   139  			}
   140  			continue
   141  		}
   142  
   143  		if err := um.fillValue(fieldValue, fieldData, fieldSchemaDef); err != nil {
   144  			return err
   145  		}
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  func (um *reflectUnmarshaller) fillTimeValue(elem *parquet.SchemaElement, value reflect.Value, data interfaces.UnmarshalElement) error {
   152  	i, err := getIntValue(data)
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	var t Time
   158  	switch {
   159  	case elem.GetLogicalType().TIME.Unit.IsSetNANOS():
   160  		t = TimeFromNanoseconds(i)
   161  	case elem.GetLogicalType().TIME.Unit.IsSetMICROS():
   162  		t = TimeFromMicroseconds(i)
   163  	case elem.GetLogicalType().TIME.Unit.IsSetMILLIS():
   164  		t = TimeFromMilliseconds(int32(i))
   165  	default:
   166  		return errors.New("invalid TIME unit")
   167  	}
   168  
   169  	if elem.GetLogicalType().TIME.GetIsAdjustedToUTC() {
   170  		t = t.UTC()
   171  	}
   172  
   173  	value.Set(reflect.ValueOf(t))
   174  	return nil
   175  }
   176  
   177  func (um *reflectUnmarshaller) fillTimestampValue(elem *parquet.SchemaElement, value reflect.Value, data interfaces.UnmarshalElement) error {
   178  	i, err := getIntValue(data)
   179  	if err != nil {
   180  		return err
   181  	}
   182  
   183  	var ts time.Time
   184  	switch {
   185  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetNANOS():
   186  		ts = time.Unix(i/1000000000, i%1000000000)
   187  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMICROS():
   188  		ts = time.Unix(i/1000000, 1000*(i%1000000))
   189  	case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMILLIS():
   190  		ts = time.Unix(i/1000, 1000000*(i%1000))
   191  	default:
   192  		return errors.New("invalid TIMESTAMP unit")
   193  	}
   194  
   195  	if elem.GetLogicalType().TIMESTAMP.GetIsAdjustedToUTC() {
   196  		ts = ts.UTC()
   197  	}
   198  
   199  	value.Set(reflect.ValueOf(ts))
   200  	return nil
   201  }
   202  
   203  func (um *reflectUnmarshaller) fillDateValue(value reflect.Value, data interfaces.UnmarshalElement) error {
   204  	i, err := getIntValue(data)
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	date := time.Unix(0, 0).UTC().Add(24 * time.Hour * time.Duration(i))
   210  	value.Set(reflect.ValueOf(date))
   211  	return nil
   212  }
   213  
   214  func (um *reflectUnmarshaller) fillValue(value reflect.Value, data interfaces.UnmarshalElement, schemaDef *parquetschema.SchemaDefinition) error {
   215  	if value.Kind() == reflect.Ptr {
   216  		value.Set(reflect.New(value.Type().Elem()))
   217  		value = value.Elem()
   218  	}
   219  
   220  	if !value.CanSet() {
   221  		return nil
   222  	}
   223  
   224  	if value.Type().ConvertibleTo(reflect.TypeOf(Time{})) {
   225  		if elem := schemaDef.SchemaElement(); elem.LogicalType != nil && elem.GetLogicalType().IsSetTIME() {
   226  			return um.fillTimeValue(elem, value, data)
   227  		}
   228  	}
   229  
   230  	if value.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
   231  		if elem := schemaDef.SchemaElement(); elem.LogicalType != nil {
   232  			switch {
   233  			case elem.GetLogicalType().IsSetDATE():
   234  				return um.fillDateValue(value, data)
   235  			case elem.GetLogicalType().IsSetTIMESTAMP():
   236  				return um.fillTimestampValue(elem, value, data)
   237  			}
   238  		} else if elem.GetType() == parquet.Type_INT96 {
   239  			i96, err := data.Int96()
   240  			if err != nil {
   241  				return err
   242  			}
   243  
   244  			value.Set(reflect.ValueOf(goparquet.Int96ToTime(i96).UTC()))
   245  			return nil
   246  		}
   247  	}
   248  
   249  	switch value.Kind() {
   250  	case reflect.Bool:
   251  		b, err := data.Bool()
   252  		if err != nil {
   253  			return err
   254  		}
   255  		value.SetBool(b)
   256  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   257  		i, err := getIntValue(data)
   258  		if err != nil {
   259  			return err
   260  		}
   261  		value.SetInt(i)
   262  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   263  		u, err := getIntValue(data)
   264  		if err != nil {
   265  			return err
   266  		}
   267  		value.SetUint(uint64(u))
   268  	case reflect.Float32, reflect.Float64:
   269  		f, err := getFloatValue(data)
   270  		if err != nil {
   271  			return err
   272  		}
   273  		value.SetFloat(f)
   274  	case reflect.Array, reflect.Slice:
   275  		if value.Type().Elem().Kind() == reflect.Uint8 {
   276  			return um.fillByteArrayOrSlice(value, data, schemaDef)
   277  		}
   278  		return um.fillArrayOrSlice(value, data, schemaDef)
   279  	case reflect.Map:
   280  		return um.fillMap(value, data, schemaDef)
   281  	case reflect.String:
   282  		s, err := data.ByteArray()
   283  		if err != nil {
   284  			return err
   285  		}
   286  		value.SetString(string(s))
   287  	case reflect.Struct:
   288  		groupData, err := data.Group()
   289  		if err != nil {
   290  			return err
   291  		}
   292  		if err := um.fillStruct(value, groupData, schemaDef); err != nil {
   293  			return err
   294  		}
   295  	default:
   296  		return fmt.Errorf("unsupported type %s", value.Type())
   297  	}
   298  
   299  	return nil
   300  }
   301  
   302  func (um *reflectUnmarshaller) fillMap(value reflect.Value, data interfaces.UnmarshalElement, schemaDef *parquetschema.SchemaDefinition) error {
   303  	if elem := schemaDef.SchemaElement(); elem.GetConvertedType() != parquet.ConvertedType_MAP {
   304  		return fmt.Errorf("filling map but schema element %s is not annotated as MAP", elem.GetName())
   305  	}
   306  
   307  	keyValueList, err := data.Map()
   308  	if err != nil {
   309  		return err
   310  	}
   311  
   312  	value.Set(reflect.MakeMap(value.Type()))
   313  
   314  	keyValueSchemaDef := schemaDef.SubSchema("key_value")
   315  	keySchemaDef := keyValueSchemaDef.SubSchema("key")
   316  	valueSchemaDef := keyValueSchemaDef.SubSchema("value")
   317  
   318  	for keyValueList.Next() {
   319  		key, err := keyValueList.Key()
   320  		if err != nil {
   321  			return err
   322  		}
   323  
   324  		valueData, err := keyValueList.Value()
   325  		if err != nil {
   326  			return err
   327  		}
   328  
   329  		keyValue := reflect.New(value.Type().Key()).Elem()
   330  		if err := um.fillValue(keyValue, key, keySchemaDef); err != nil {
   331  			return fmt.Errorf("couldn't fill key with key data: %v", err)
   332  		}
   333  
   334  		valueValue := reflect.New(value.Type().Elem()).Elem()
   335  		if err := um.fillValue(valueValue, valueData, valueSchemaDef); err != nil {
   336  			return fmt.Errorf("couldn't fill value with value data: %v", err)
   337  		}
   338  
   339  		value.SetMapIndex(keyValue, valueValue)
   340  	}
   341  
   342  	return nil
   343  }
   344  
   345  func (um *reflectUnmarshaller) fillByteArrayOrSlice(value reflect.Value, data interfaces.UnmarshalElement, schemaDef *parquetschema.SchemaDefinition) error {
   346  	byteSlice, err := data.ByteArray()
   347  	if err != nil {
   348  		// check to see if it's actually an INT96
   349  		int96, int96Err := data.Int96()
   350  		if int96Err != nil {
   351  			return err
   352  		}
   353  		byteSlice = int96[0:]
   354  	}
   355  	if value.Kind() == reflect.Slice {
   356  		value.Set(reflect.MakeSlice(value.Type(), len(byteSlice), len(byteSlice)))
   357  	}
   358  
   359  	for i, b := range byteSlice {
   360  		if i < value.Len() {
   361  			value.Index(i).SetUint(uint64(b))
   362  		}
   363  	}
   364  	return nil
   365  }
   366  
   367  func (um *reflectUnmarshaller) fillArrayOrSlice(value reflect.Value, data interfaces.UnmarshalElement, schemaDef *parquetschema.SchemaDefinition) error {
   368  	if elem := schemaDef.SchemaElement(); elem.GetConvertedType() != parquet.ConvertedType_LIST {
   369  		return fmt.Errorf("filling slice or array but schema element %s is not annotated as LIST", elem.GetName())
   370  	}
   371  
   372  	elemList, err := data.List()
   373  	if err != nil {
   374  		return err
   375  	}
   376  
   377  	elementList := []interfaces.UnmarshalElement{}
   378  
   379  	for elemList.Next() {
   380  		elemValue, err := elemList.Value()
   381  		if err != nil {
   382  			return err
   383  		}
   384  
   385  		elementList = append(elementList, elemValue)
   386  	}
   387  
   388  	if value.Kind() == reflect.Slice {
   389  		value.Set(reflect.MakeSlice(value.Type(), len(elementList), len(elementList)))
   390  	}
   391  
   392  	elemSchemaDef := schemaDef.SubSchema("list").SubSchema("element")
   393  	if elemSchemaDef == nil {
   394  		elemSchemaDef = schemaDef.SubSchema("bag").SubSchema("array_element")
   395  		if elemSchemaDef == nil {
   396  			return fmt.Errorf("element %s is annotated as LIST but group structure seems invalid", schemaDef.SchemaElement().GetName())
   397  		}
   398  	}
   399  
   400  	for idx, elem := range elementList {
   401  		if idx < value.Len() {
   402  			if err := um.fillValue(value.Index(idx), elem, elemSchemaDef); err != nil {
   403  				return err
   404  			}
   405  		}
   406  	}
   407  
   408  	return nil
   409  }
   410  
   411  func getIntValue(data interfaces.UnmarshalElement) (int64, error) {
   412  	i32, err := data.Int32()
   413  	if err == nil {
   414  		return int64(i32), nil
   415  	}
   416  
   417  	i64, err := data.Int64()
   418  	if err == nil {
   419  		return i64, nil
   420  	}
   421  	return 0, err
   422  }
   423  
   424  func getFloatValue(data interfaces.UnmarshalElement) (float64, error) {
   425  	f32, err := data.Float32()
   426  	if err == nil {
   427  		return float64(f32), nil
   428  	}
   429  
   430  	f64, err := data.Float64()
   431  	if err == nil {
   432  		return f64, nil
   433  	}
   434  
   435  	return 0, err
   436  }
   437  
   438  // Err returns an error in case Next returned false due to an error.
   439  // If Next returned false due to EOF, Err returns nil.
   440  func (r *Reader) Err() error {
   441  	return r.err
   442  }
   443  
   444  // GetSchemaDefinition returns the schema definition of the parquet
   445  // file.
   446  func (r *Reader) GetSchemaDefinition() *parquetschema.SchemaDefinition {
   447  	return r.r.GetSchemaDefinition()
   448  }