github.com/fraugster/parquet-go@v0.12.0/floor/writer.go (about) 1 package floor 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "os" 8 "reflect" 9 "time" 10 11 "github.com/araddon/dateparse" 12 goparquet "github.com/fraugster/parquet-go" 13 "github.com/fraugster/parquet-go/floor/interfaces" 14 "github.com/fraugster/parquet-go/parquet" 15 "github.com/fraugster/parquet-go/parquetschema" 16 ) 17 18 // NewWriter creates a new high-level writer for parquet. 19 // NOTE: We assume the schema definition is constant. 20 func NewWriter(w *goparquet.FileWriter) *Writer { 21 return &Writer{ 22 w: w, 23 schemaDef: w.GetSchemaDefinition(), 24 } 25 } 26 27 // NewFileWriter creates a nigh high-level writer for parquet 28 // that writes to a particular file. 29 // NOTE: We assume the schema definition is constant. 30 func NewFileWriter(file string, opts ...goparquet.FileWriterOption) (*Writer, error) { 31 f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 32 if err != nil { 33 return nil, err 34 } 35 36 w := goparquet.NewFileWriter(f, opts...) 37 return &Writer{ 38 w: w, 39 f: f, 40 schemaDef: w.GetSchemaDefinition(), 41 }, nil 42 } 43 44 // Writer represents a high-level writer for parquet files. 45 type Writer struct { 46 w *goparquet.FileWriter 47 f io.Closer 48 schemaDef *parquetschema.SchemaDefinition 49 } 50 51 // Write adds a new object to be written to the parquet file. If 52 // obj implements the floor.Marshaller object, then obj.(Marshaller).Marshal 53 // will be called to determine the data, otherwise reflection will be used. 54 func (w *Writer) Write(obj interface{}) error { 55 m, ok := obj.(interfaces.Marshaller) 56 if !ok { 57 m = &reflectMarshaller{obj: obj, schemaDef: w.schemaDef} 58 } 59 60 data := interfaces.NewMarshallObjectWithSchema(nil, w.schemaDef) 61 if err := m.MarshalParquet(data); err != nil { 62 return err 63 } 64 65 if err := w.w.AddData(data.GetData()); err != nil { 66 return err 67 } 68 69 return nil 70 } 71 72 type reflectMarshaller struct { 73 obj interface{} 74 schemaDef *parquetschema.SchemaDefinition 75 } 76 77 func (m *reflectMarshaller) MarshalParquet(record interfaces.MarshalObject) error { 78 return m.marshal(record, reflect.ValueOf(m.obj), m.schemaDef) 79 } 80 81 func (m *reflectMarshaller) marshal(record interfaces.MarshalObject, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 82 if value.Type().Kind() == reflect.Ptr { 83 if value.IsNil() { 84 return errors.New("object is nil") 85 } 86 value = value.Elem() 87 } 88 89 typ := value.Type() 90 91 if typ.Kind() == reflect.Struct { 92 return m.decodeStruct(record, value, schemaDef) 93 } 94 95 if typ.Kind() != reflect.Map { 96 return fmt.Errorf("object needs to be a struct, *struct or map, it's a %v instead", typ) 97 } 98 99 iter := value.MapRange() 100 for iter.Next() { 101 fieldName := iter.Key().String() 102 subSchemaDef := schemaDef.SubSchema(fieldName) 103 field := record.AddField(fieldName) 104 105 err := m.decodeValue(field, iter.Value(), subSchemaDef) 106 if err != nil { 107 return err 108 } 109 } 110 111 return nil 112 } 113 114 func (m *reflectMarshaller) decodeStruct(record interfaces.MarshalObject, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 115 if value.Type().Kind() == reflect.Ptr { 116 if value.IsNil() { 117 return errors.New("object is nil") 118 } 119 value = value.Elem() 120 } 121 122 typ := value.Type() 123 124 if typ.Kind() != reflect.Struct { 125 return fmt.Errorf("object needs to be a struct or a *struct, it's a %v instead", typ) 126 } 127 128 numFields := typ.NumField() 129 for i := 0; i < numFields; i++ { 130 fieldValue := value.Field(i) 131 132 fieldName := fieldNameFunc(typ.Field(i)) 133 134 subSchemaDef := schemaDef.SubSchema(fieldName) 135 136 field := record.AddField(fieldName) 137 138 err := m.decodeValue(field, fieldValue, subSchemaDef) 139 if err != nil { 140 return err 141 } 142 } 143 144 return nil 145 } 146 147 func (m *reflectMarshaller) decodeTimeValue(elem *parquet.SchemaElement, field interfaces.MarshalElement, value reflect.Value) error { 148 switch { 149 case elem.GetLogicalType().TIME.Unit.IsSetNANOS(): 150 field.SetInt64(value.Interface().(Time).Nanoseconds()) 151 case elem.GetLogicalType().TIME.Unit.IsSetMICROS(): 152 field.SetInt64(value.Interface().(Time).Microseconds()) 153 case elem.GetLogicalType().TIME.Unit.IsSetMILLIS(): 154 field.SetInt32(value.Interface().(Time).Milliseconds()) 155 default: 156 return errors.New("invalid TIME unit") 157 } 158 return nil 159 } 160 161 func (m *reflectMarshaller) decodeTimestampValue(elem *parquet.SchemaElement, field interfaces.MarshalElement, value reflect.Value) error { 162 var factor int64 163 switch { 164 case elem.GetLogicalType().TIMESTAMP.Unit.IsSetNANOS(): 165 factor = 1 166 case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMICROS(): 167 factor = 1000 168 case elem.GetLogicalType().TIMESTAMP.Unit.IsSetMILLIS(): 169 factor = 1000000 170 default: 171 return errors.New("invalid TIMESTAMP unit") 172 } 173 ts := value.Interface().(time.Time).UnixNano() 174 ts /= factor 175 field.SetInt64(ts) 176 return nil 177 } 178 179 func (m *reflectMarshaller) decodeValue(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 180 elem := schemaDef.SchemaElement() 181 if elem == nil { 182 return nil 183 } 184 185 if value.Kind() == reflect.Ptr || value.Kind() == reflect.Interface { 186 if value.IsNil() { 187 return nil 188 } 189 value = value.Elem() 190 } 191 192 if value.Type().ConvertibleTo(reflect.TypeOf(Time{})) { 193 if elem.LogicalType != nil && elem.GetLogicalType().IsSetTIME() { 194 return m.decodeTimeValue(elem, field, value) 195 } 196 } 197 198 if value.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { 199 if elem.LogicalType != nil { 200 switch { 201 case elem.GetLogicalType().IsSetDATE(): 202 days := int32(value.Interface().(time.Time).Sub(time.Unix(0, 0).UTC()).Hours() / 24) 203 field.SetInt32(days) 204 return nil 205 case elem.GetLogicalType().IsSetTIMESTAMP(): 206 return m.decodeTimestampValue(elem, field, value) 207 } 208 } else if elem.GetType() == parquet.Type_INT96 { 209 field.SetInt96(goparquet.TimeToInt96(value.Interface().(time.Time))) 210 return nil 211 } 212 } 213 214 if !elem.IsSetType() && !elem.IsSetConvertedType() && elem.GetNumChildren() > 0 && value.Kind() == reflect.Map { 215 group := field.Group() 216 iter := value.MapRange() 217 for iter.Next() { 218 fieldName := iter.Key().String() 219 err := m.decodeValue(group.AddField(fieldName), iter.Value(), schemaDef.SubSchema(fieldName)) 220 if err != nil { 221 return err 222 } 223 } 224 225 return nil 226 } 227 228 switch elem.GetType() { 229 case parquet.Type_INT64: 230 switch value.Kind() { 231 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 232 field.SetInt64(value.Int()) 233 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 234 field.SetInt64(int64(value.Uint())) 235 default: 236 return fmt.Errorf("unable to decode %s:%s to int64", elem.Name, value.Kind()) 237 } 238 return nil 239 case parquet.Type_INT32: 240 switch value.Kind() { 241 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 242 field.SetInt32(int32(value.Int())) 243 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 244 field.SetInt32(int32(value.Uint())) 245 default: 246 return fmt.Errorf("unable to decode %s:%s to int32", elem.Name, value.Kind()) 247 } 248 return nil 249 case parquet.Type_INT96: 250 switch value.Kind() { 251 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 252 return m.decodeUnixTime(field, value.Int()) 253 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 254 return m.decodeUnixTime(field, int64(value.Uint())) 255 case reflect.String: 256 dt, _ := dateparse.ParseAny(value.String()) 257 field.SetInt96(goparquet.TimeToInt96(dt)) 258 return nil 259 case reflect.Slice: 260 if value.IsNil() { 261 return nil 262 } 263 if value.Type().Elem().Kind() != reflect.Uint8 { 264 return fmt.Errorf("field is of type INT96 but type is %s", value.Type().String()) 265 } 266 267 if value.Len() != 12 { 268 return fmt.Errorf("field is of type INT96 but length is %d", value.Len()) 269 } 270 var dst [12]byte 271 src := value.Interface().([]byte) 272 copy(dst[:], src) 273 field.SetInt96(dst) 274 return nil 275 case reflect.Array: 276 if value.Type().Elem().Kind() != reflect.Uint8 { 277 return fmt.Errorf("field is of type INT96 but type is %s", value.Type().String()) 278 } 279 if value.Len() != 12 { 280 return fmt.Errorf("field is of type INT96 but length is %d", value.Len()) 281 } 282 var dst [12]byte 283 src := value.Interface().([12]byte) 284 copy(dst[:], src[:]) 285 field.SetInt96(dst) 286 return nil 287 } 288 } 289 290 switch value.Kind() { 291 case reflect.Bool: 292 field.SetBool(value.Bool()) 293 return nil 294 case reflect.Float32: 295 field.SetFloat32(float32(value.Float())) 296 return nil 297 case reflect.Float64: 298 field.SetFloat64(value.Float()) 299 return nil 300 case reflect.Array, reflect.Slice: 301 if value.Type().Elem().Kind() == reflect.Uint8 { 302 return m.decodeByteSliceOrArray(field, value, schemaDef) 303 } 304 return m.decodeSliceOrArray(field, value, schemaDef) 305 case reflect.Map: 306 return m.decodeMap(field, value, schemaDef) 307 case reflect.String: 308 field.SetByteArray([]byte(value.String())) 309 return nil 310 case reflect.Struct: 311 return m.decodeStruct(field.Group(), value, schemaDef) 312 default: 313 return fmt.Errorf("unsupported type %s", value.Type()) 314 } 315 } 316 317 func (m *reflectMarshaller) decodeUnixTime(field interfaces.MarshalElement, i64 int64) error { 318 // best effort parse unix timestamps. 319 // since 99% of the time these are timestamps and are <= now this is a fairly safe bet 320 digits := i64Digits(i64) 321 now := time.Now() 322 323 switch { 324 case digits <= i64Digits(now.Unix()): 325 dt := time.Unix(i64, 0) 326 field.SetInt96(goparquet.TimeToInt96(dt)) 327 case digits <= i64Digits(now.UnixNano()/1000000): 328 dt := time.Unix(0, i64*int64(time.Millisecond)) 329 field.SetInt96(goparquet.TimeToInt96(dt)) 330 case digits <= i64Digits(now.UnixNano()/1000): 331 dt := time.Unix(0, i64*int64(time.Microsecond)) 332 field.SetInt96(goparquet.TimeToInt96(dt)) 333 case digits <= i64Digits(now.UnixNano()): 334 dt := time.Unix(0, i64) 335 field.SetInt96(goparquet.TimeToInt96(dt)) 336 default: 337 return fmt.Errorf("field is of type INT96 but value is not valid %d", i64) 338 } 339 return nil 340 } 341 342 func (m *reflectMarshaller) decodeByteSliceOrArray(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 343 elem := schemaDef.SchemaElement() 344 if elem == nil { 345 return nil 346 } 347 348 if value.Kind() == reflect.Slice && value.IsNil() { 349 return nil 350 } 351 352 if elem.LogicalType != nil && elem.GetLogicalType().IsSetUUID() { 353 if value.Len() != 16 { 354 return fmt.Errorf("field is annotated as UUID but length is %d", value.Len()) 355 } 356 } 357 358 switch value.Kind() { 359 case reflect.Slice: 360 if value.IsNil() { 361 return nil 362 } 363 field.SetByteArray(value.Bytes()) 364 case reflect.Array: 365 data := reflect.MakeSlice(reflect.TypeOf([]byte{}), value.Len(), value.Len()) 366 _ = reflect.Copy(data, value) 367 field.SetByteArray(data.Bytes()) 368 } 369 return nil 370 } 371 372 func (m *reflectMarshaller) decodeSliceOrArray(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 373 elem := schemaDef.SchemaElement() 374 if elem == nil { 375 return nil 376 } 377 378 if value.Kind() == reflect.Slice && value.IsNil() { 379 return nil 380 } 381 382 if elem.GetConvertedType() != parquet.ConvertedType_LIST { 383 return fmt.Errorf("decoding slice or array but schema element %s is not annotated as LIST", elem.GetName()) 384 } 385 386 elementSchemaDef := schemaDef.SubSchema("list").SubSchema("element") 387 if elementSchemaDef == nil { 388 elementSchemaDef = schemaDef.SubSchema("bag").SubSchema("array_element") 389 if elementSchemaDef == nil { 390 return fmt.Errorf("element %s is annotated as LIST but group structure seems invalid", schemaDef.SchemaElement().GetName()) 391 } 392 } 393 394 list := field.List() 395 396 for i := 0; i < value.Len(); i++ { 397 if err := m.decodeValue(list.Add(), value.Index(i), elementSchemaDef); err != nil { 398 return err 399 } 400 } 401 402 return nil 403 } 404 405 func (m *reflectMarshaller) decodeMap(field interfaces.MarshalElement, value reflect.Value, schemaDef *parquetschema.SchemaDefinition) error { 406 if value.IsNil() { 407 return nil 408 } 409 410 if elem := schemaDef.SchemaElement(); elem.GetConvertedType() != parquet.ConvertedType_MAP { 411 return fmt.Errorf("decoding map but schema element %s is not annotated as MAP", elem.GetName()) 412 } 413 414 keyValueSchemaDef := schemaDef.SubSchema("key_value") 415 keySchemaDef := keyValueSchemaDef.SubSchema("key") 416 valueSchemaDef := keyValueSchemaDef.SubSchema("value") 417 418 mapData := field.Map() 419 420 iter := value.MapRange() 421 422 for iter.Next() { 423 kvPair := mapData.Add() 424 425 if err := m.decodeValue(kvPair.Key(), iter.Key(), keySchemaDef); err != nil { 426 return err 427 } 428 429 if err := m.decodeValue(kvPair.Value(), iter.Value(), valueSchemaDef); err != nil { 430 return err 431 } 432 } 433 434 return nil 435 } 436 437 // Close flushes outstanding data and closes the underlying 438 // parquet writer. 439 func (w *Writer) Close() error { 440 if w.f != nil { 441 defer w.f.Close() 442 } 443 444 return w.w.Close() 445 } 446 447 func i64Digits(number int64) int { 448 count := 0 449 for number != 0 { 450 number /= 10 451 count++ 452 } 453 return count 454 }