github.com/fraugster/parquet-go@v0.12.0/cmd/csv2parquet/main.go (about)

     1  package main
     2  
     3  import (
     4  	"encoding/csv"
     5  	"encoding/json"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"os"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"unicode/utf8"
    16  
    17  	goparquet "github.com/fraugster/parquet-go"
    18  	"github.com/fraugster/parquet-go/parquet"
    19  	"github.com/fraugster/parquet-go/parquetschema"
    20  )
    21  
    22  var printLog = func(string, ...interface{}) {}
    23  
    24  func main() {
    25  	inputFile := flag.String("input", "", "CSV file input")
    26  	typeHints := flag.String("typehints", "", "type hints to help derive parquet schema. A comma-separated list of type hints in the format <column_name>=<parquettype>; valid parquet types: "+strings.Join(validTypeList(), ", "))
    27  	outputFile := flag.String("output", "", "output parquet file")
    28  	rowgroupSize := flag.Int64("rowgroup-size", 100*1024*1024, "row group size in bytes; if value is 0, then the row group size is unbounded")
    29  	compressionCodec := flag.String("compression", "snappy", "compression algorithm; allowed values: "+strings.Join(validCompressionCodecs(), ", "))
    30  	delimiter := flag.String("delimiter", ",", "CSV field delimiter")
    31  	creator := flag.String("created-by", "csv2parquet", "value to set for CreatedBy field of parquet file")
    32  	verbose := flag.Bool("v", false, "enable verbose logging")
    33  	flag.Parse()
    34  
    35  	if *inputFile == "" {
    36  		log.Fatalf("Empty input file parameter")
    37  	}
    38  
    39  	if *outputFile == "" {
    40  		log.Fatalf("Empty output file parameter")
    41  	}
    42  
    43  	codec, err := lookupCompressionCodec(*compressionCodec)
    44  	if err != nil {
    45  		log.Fatalf("Invalid compression codec %q: %v", *compressionCodec, err)
    46  	}
    47  
    48  	var delimiterRune rune
    49  
    50  	if *delimiter != "" {
    51  		delimiterRune, _ = utf8.DecodeRuneInString(*delimiter)
    52  		if delimiterRune == '\r' || delimiterRune == '\n' || delimiterRune == '\uFFFD' {
    53  			log.Fatalf("Invalid CSV field separator %q", *delimiter)
    54  		}
    55  	}
    56  
    57  	if *verbose {
    58  		printLog = log.Printf
    59  	}
    60  
    61  	types, err := parseTypeHints(*typeHints)
    62  	if err != nil {
    63  		log.Fatalf("Parsing type hints failed: %v", err)
    64  	}
    65  
    66  	printLog("Opening %s...", *inputFile)
    67  
    68  	f, err := os.Open(*inputFile)
    69  	if err != nil {
    70  		log.Fatalf("Couldn't open input file: %v", err)
    71  	}
    72  
    73  	csvReader := csv.NewReader(f)
    74  
    75  	if *delimiter != "" {
    76  		csvReader.Comma = delimiterRune
    77  	}
    78  
    79  	records, err := csvReader.ReadAll()
    80  	if err != nil {
    81  		log.Fatalf("Reading CSV content failed: %v", err)
    82  	}
    83  
    84  	f.Close()
    85  
    86  	header := records[0]
    87  	records = records[1:]
    88  
    89  	printLog("Finished reading %s, got %d records", *inputFile, len(records))
    90  
    91  	of, err := os.OpenFile(*outputFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
    92  	if err != nil {
    93  		log.Fatalf("Couldn't open output file: %v", err)
    94  	}
    95  	defer of.Close()
    96  
    97  	if err := writeParquetData(of, header, types, records, *creator, codec, *rowgroupSize); err != nil {
    98  		log.Fatalf("Couldn't write parquet data: %v", err)
    99  	}
   100  
   101  	printLog("Finished generating output file %s", *outputFile)
   102  }
   103  
   104  func writeParquetData(of io.Writer, header []string, types map[string]string, records [][]string, creator string, codec parquet.CompressionCodec, rowgroupSize int64) error {
   105  	schema, fieldHandlers, err := deriveSchema(header, types)
   106  	if err != nil {
   107  		return fmt.Errorf("generating schema failed: %w", err)
   108  	}
   109  
   110  	printLog("Derived parquet schema: %s", schema.String())
   111  
   112  	writerOptions := []goparquet.FileWriterOption{
   113  		goparquet.WithCreator(creator),
   114  		goparquet.WithSchemaDefinition(schema),
   115  		goparquet.WithCompressionCodec(codec),
   116  	}
   117  
   118  	if rowgroupSize > 0 {
   119  		writerOptions = append(writerOptions, goparquet.WithMaxRowGroupSize(rowgroupSize))
   120  	}
   121  
   122  	pqWriter := goparquet.NewFileWriter(of, writerOptions...)
   123  
   124  	for recordIndex, record := range records {
   125  		data := make(map[string]interface{})
   126  
   127  		if len(record) < len(header) {
   128  			return fmt.Errorf("input record %d only contains %d fields instead of the expected %d", recordIndex+1, len(record), len(header))
   129  		}
   130  
   131  		for idx, fieldName := range header {
   132  			handler := fieldHandlers[idx]
   133  
   134  			v, err := handler(record[idx])
   135  			if err != nil {
   136  				return fmt.Errorf("in input record %d, couldn't convert value %q to type %s: %w", recordIndex+1, record[idx], types[fieldName], err)
   137  			}
   138  			data[fieldName] = v
   139  		}
   140  		if err := pqWriter.AddData(data); err != nil {
   141  			return fmt.Errorf("in input record %d, adding data failed: %w", recordIndex+1, err)
   142  		}
   143  	}
   144  
   145  	if err := pqWriter.Close(); err != nil {
   146  		return fmt.Errorf("closing parquet writer failed: %w", err)
   147  	}
   148  
   149  	return nil
   150  }
   151  
   152  type fieldHandler func(string) (interface{}, error)
   153  
   154  func deriveSchema(header []string, types map[string]string) (schema *parquetschema.SchemaDefinition, fieldHandlers []fieldHandler, err error) {
   155  	schema = &parquetschema.SchemaDefinition{
   156  		RootColumn: &parquetschema.ColumnDefinition{
   157  			SchemaElement: &parquet.SchemaElement{
   158  				Name: "msg",
   159  			},
   160  		},
   161  	}
   162  
   163  	fieldHandlers = make([]fieldHandler, 0, len(header))
   164  
   165  	for _, field := range header {
   166  		typ := types[field]
   167  		if typ == "" {
   168  			typ = "string"
   169  			types[field] = typ
   170  		}
   171  
   172  		col, handler, err := createColumn(field, typ)
   173  		if err != nil {
   174  			return nil, nil, fmt.Errorf("couldn't create column for field %s: %v", field, err)
   175  		}
   176  
   177  		fieldHandlers = append(fieldHandlers, handler)
   178  		schema.RootColumn.Children = append(schema.RootColumn.Children, col)
   179  	}
   180  
   181  	if err := schema.Validate(); err != nil {
   182  		return schema, nil, fmt.Errorf("validation of generated schema failed: %w", err)
   183  	}
   184  
   185  	return schema, fieldHandlers, nil
   186  }
   187  
   188  func createColumn(field, typ string) (col *parquetschema.ColumnDefinition, fieldHandler func(string) (interface{}, error), rr error) {
   189  	col = &parquetschema.ColumnDefinition{
   190  		SchemaElement: &parquet.SchemaElement{},
   191  	}
   192  	col.SchemaElement.RepetitionType = parquet.FieldRepetitionTypePtr(parquet.FieldRepetitionType_OPTIONAL)
   193  	col.SchemaElement.Name = field
   194  
   195  	switch typ {
   196  	case "string":
   197  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_BYTE_ARRAY)
   198  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   199  		col.SchemaElement.LogicalType.STRING = &parquet.StringType{}
   200  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_UTF8)
   201  		fieldHandler = byteArrayHandler
   202  	case "byte_array":
   203  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_BYTE_ARRAY)
   204  		fieldHandler = byteArrayHandler
   205  	case "boolean":
   206  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_BOOLEAN)
   207  		fieldHandler = booleanHandler
   208  	case "int8":
   209  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   210  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   211  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 8, IsSigned: true}
   212  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_INT_8)
   213  		fieldHandler = intHandler(8)
   214  	case "uint8":
   215  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   216  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   217  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 8, IsSigned: false}
   218  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_UINT_8)
   219  		fieldHandler = uintHandler(8)
   220  	case "int16":
   221  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   222  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   223  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 16, IsSigned: true}
   224  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_INT_16)
   225  		fieldHandler = intHandler(16)
   226  	case "uint16":
   227  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   228  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   229  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 16, IsSigned: false}
   230  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_UINT_16)
   231  		fieldHandler = uintHandler(16)
   232  	case "int32":
   233  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   234  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   235  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 32, IsSigned: true}
   236  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_INT_32)
   237  		fieldHandler = intHandler(32)
   238  	case "uint32":
   239  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT32)
   240  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   241  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 32, IsSigned: false}
   242  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_UINT_32)
   243  		fieldHandler = uintHandler(32)
   244  	case "int64":
   245  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT64)
   246  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   247  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 64, IsSigned: true}
   248  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_INT_64)
   249  		fieldHandler = intHandler(64)
   250  	case "uint64":
   251  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT64)
   252  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   253  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 64, IsSigned: false}
   254  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_UINT_64)
   255  		fieldHandler = uintHandler(64)
   256  	case "float":
   257  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_FLOAT)
   258  		fieldHandler = floatHandler
   259  	case "double":
   260  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_DOUBLE)
   261  		fieldHandler = doubleHandler
   262  	case "int":
   263  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_INT64)
   264  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   265  		col.SchemaElement.LogicalType.INTEGER = &parquet.IntType{BitWidth: 64, IsSigned: true}
   266  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_INT_64)
   267  		fieldHandler = intHandler(64)
   268  	case "json":
   269  		col.SchemaElement.Type = parquet.TypePtr(parquet.Type_BYTE_ARRAY)
   270  		col.SchemaElement.LogicalType = parquet.NewLogicalType()
   271  		col.SchemaElement.LogicalType.JSON = &parquet.JsonType{}
   272  		col.SchemaElement.ConvertedType = parquet.ConvertedTypePtr(parquet.ConvertedType_JSON)
   273  		fieldHandler = jsonHandler
   274  	default:
   275  		return nil, nil, fmt.Errorf("unsupported type %q", typ)
   276  	}
   277  
   278  	fieldHandler = optionalHandler(fieldHandler) // TODO: if we make repetition type configurable, change this to use correct handler.
   279  
   280  	return col, fieldHandler, nil
   281  }
   282  
   283  func parseTypeHints(s string) (map[string]string, error) {
   284  	typeMap := make(map[string]string)
   285  
   286  	if s == "" {
   287  		return typeMap, nil
   288  	}
   289  
   290  	hintsList := strings.Split(s, ",")
   291  	for _, hint := range hintsList {
   292  		hint = strings.TrimSpace(hint)
   293  
   294  		hintFields := strings.Split(hint, "=")
   295  		if len(hintFields) != 2 {
   296  			return nil, fmt.Errorf("invalid type hint %q", hint)
   297  		}
   298  
   299  		fieldName := strings.TrimSpace(hintFields[0])
   300  		fieldType := strings.TrimSpace(hintFields[1])
   301  
   302  		if !isValidType(fieldType) {
   303  			return nil, fmt.Errorf("invalid parquet type %q", fieldType)
   304  		}
   305  
   306  		typeMap[fieldName] = fieldType
   307  	}
   308  
   309  	return typeMap, nil
   310  }
   311  
   312  var validTypes = map[string]bool{
   313  	"boolean":    true,
   314  	"int8":       true,
   315  	"uint8":      true,
   316  	"int16":      true,
   317  	"uint16":     true,
   318  	"int32":      true,
   319  	"uint32":     true,
   320  	"int64":      true,
   321  	"uint64":     true,
   322  	"float":      true,
   323  	"double":     true,
   324  	"byte_array": true,
   325  	"string":     true,
   326  	"int":        true,
   327  	"json":       true,
   328  	// TODO: support more data types
   329  }
   330  
   331  func validTypeList() []string {
   332  	l := make([]string, 0, len(validTypes))
   333  	for k := range validTypes {
   334  		l = append(l, k)
   335  	}
   336  	sort.Strings(l)
   337  	return l
   338  }
   339  
   340  func validCompressionCodecs() []string {
   341  	registeredCodecs := goparquet.GetRegisteredBlockCompressors()
   342  
   343  	l := make([]string, 0, len(registeredCodecs))
   344  	for k := range registeredCodecs {
   345  		l = append(l, strings.ToLower(k.String()))
   346  	}
   347  	sort.Strings(l)
   348  	return l
   349  }
   350  
   351  func lookupCompressionCodec(codec string) (parquet.CompressionCodec, error) {
   352  	registeredCodecs := goparquet.GetRegisteredBlockCompressors()
   353  
   354  	for c := range registeredCodecs {
   355  		if strings.ToLower(c.String()) == codec {
   356  			return c, nil
   357  		}
   358  	}
   359  
   360  	return parquet.CompressionCodec_UNCOMPRESSED, errors.New("unsupported compression codec")
   361  }
   362  
   363  func isValidType(t string) bool {
   364  	return validTypes[t]
   365  }
   366  
   367  func byteArrayHandler(s string) (interface{}, error) {
   368  	return []byte(s), nil
   369  }
   370  
   371  func booleanHandler(s string) (interface{}, error) {
   372  	return strconv.ParseBool(s)
   373  }
   374  
   375  func uintHandler(bitSize int) func(string) (interface{}, error) {
   376  	return func(s string) (interface{}, error) {
   377  		i, err := strconv.ParseUint(s, 10, bitSize)
   378  		if err != nil {
   379  			return nil, err
   380  		}
   381  		switch bitSize {
   382  		case 8, 16, 32:
   383  			return uint32(i), nil
   384  		case 64:
   385  			return i, nil
   386  		default:
   387  			return nil, fmt.Errorf("invalid bit size %d", bitSize)
   388  		}
   389  	}
   390  }
   391  
   392  func intHandler(bitSize int) func(string) (interface{}, error) {
   393  	return func(s string) (interface{}, error) {
   394  		i, err := strconv.ParseInt(s, 10, bitSize)
   395  		if err != nil {
   396  			return nil, err
   397  		}
   398  		switch bitSize {
   399  		case 8, 16, 32:
   400  			return int32(i), nil
   401  		case 64:
   402  			return i, nil
   403  		default:
   404  			return nil, fmt.Errorf("invalid bit size %d", bitSize)
   405  		}
   406  	}
   407  }
   408  
   409  func floatHandler(s string) (interface{}, error) {
   410  	f, err := strconv.ParseFloat(s, 32)
   411  	return float32(f), err
   412  }
   413  
   414  func doubleHandler(s string) (interface{}, error) {
   415  	f, err := strconv.ParseFloat(s, 64)
   416  	return f, err
   417  }
   418  
   419  func jsonHandler(s string) (interface{}, error) {
   420  	data := []byte(s)
   421  	var obj interface{}
   422  	if err := json.Unmarshal(data, &obj); err != nil {
   423  		return nil, err
   424  	}
   425  	return data, nil
   426  }
   427  
   428  func optionalHandler(next fieldHandler) fieldHandler {
   429  	return func(s string) (interface{}, error) {
   430  		if s == "" {
   431  			return nil, nil
   432  		}
   433  		return next(s)
   434  	}
   435  }