github.com/fraugster/parquet-go@v0.12.0/parquetschema/schema_parser.go (about)

     1  package parquetschema
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math"
     7  	"runtime"
     8  	"strconv"
     9  	"strings"
    10  	"unicode"
    11  	"unicode/utf8"
    12  
    13  	"github.com/fraugster/parquet-go/parquet"
    14  )
    15  
// item is a single token produced by the schema lexer and consumed by the
// schema parser.
type item struct {
	typ  itemType // kind of token
	pos  pos      // byte offset of the token's first character in the input
	val  string   // the token's text exactly as it appears in the input
	line int      // line on which the token starts (lines are counted from 1)
}

// pos is a byte offset into the lexer's input.
type pos int
    24  
    25  func (i item) String() string {
    26  	switch {
    27  	case i.typ == itemEOF:
    28  		return "EOF"
    29  	case i.typ == itemError:
    30  		return i.val
    31  	case len(i.val) > 10:
    32  		return fmt.Sprintf("%.10q...", i.val)
    33  	}
    34  	return fmt.Sprintf("%q", i.val)
    35  }
    36  
    37  type itemType int
    38  
    39  const (
    40  	itemError itemType = iota
    41  	itemEOF
    42  
    43  	itemLeftParen
    44  	itemRightParen
    45  	itemLeftBrace
    46  	itemRightBrace
    47  	itemEqual
    48  	itemSemicolon
    49  	itemComma
    50  	itemNumber
    51  	itemIdentifier
    52  	itemKeyword
    53  	itemMessage
    54  	itemRepeated
    55  	itemOptional
    56  	itemRequired
    57  	itemGroup
    58  )
    59  
    60  func (i itemType) String() string {
    61  	typeNames := map[itemType]string{
    62  		itemError:      "error",
    63  		itemEOF:        "EOF",
    64  		itemLeftParen:  "(",
    65  		itemRightParen: ")",
    66  		itemLeftBrace:  "{",
    67  		itemRightBrace: "}",
    68  		itemEqual:      "=",
    69  		itemSemicolon:  ";",
    70  		itemComma:      ",",
    71  		itemNumber:     "number",
    72  		itemIdentifier: "identifier",
    73  		itemKeyword:    "<keyword>",
    74  		itemMessage:    "message",
    75  		itemRepeated:   "repeated",
    76  		itemOptional:   "optional",
    77  		itemRequired:   "required",
    78  		itemGroup:      "group",
    79  	}
    80  
    81  	n, ok := typeNames[i]
    82  	if !ok {
    83  		return fmt.Sprintf("<type:%d>", int(i))
    84  	}
    85  	return n
    86  }
    87  
    88  var key = map[string]itemType{
    89  	"message":  itemMessage,
    90  	"repeated": itemRepeated,
    91  	"optional": itemOptional,
    92  	"required": itemRequired,
    93  	"group":    itemGroup,
    94  }
    95  
    96  const eof = -1
    97  
    98  type stateFn func(*schemaLexer) stateFn
    99  
   100  type schemaLexer struct {
   101  	input     string
   102  	pos       pos
   103  	start     pos
   104  	width     pos
   105  	items     chan item
   106  	line      int
   107  	startLine int
   108  }
   109  
   110  func (l *schemaLexer) next() rune {
   111  	if int(l.pos) >= len(l.input) {
   112  		l.width = 0
   113  		return eof
   114  	}
   115  
   116  	r, w := utf8.DecodeRuneInString(l.input[l.pos:])
   117  	l.width = pos(w)
   118  	l.pos += l.width
   119  	if r == '\n' {
   120  		l.line++
   121  	}
   122  	return r
   123  }
   124  
   125  func (l *schemaLexer) peek() rune {
   126  	r := l.next()
   127  	l.backup()
   128  	return r
   129  }
   130  
   131  func (l *schemaLexer) backup() {
   132  	l.pos -= l.width
   133  	if l.width == 1 && l.input[l.pos] == '\n' {
   134  		l.line--
   135  	}
   136  }
   137  
   138  func (l *schemaLexer) ignore() {
   139  	l.start = l.pos
   140  	l.startLine = l.line
   141  }
   142  
   143  func (l *schemaLexer) emit(t itemType) {
   144  	l.items <- item{t, l.start, l.input[l.start:l.pos], l.startLine}
   145  	l.start = l.pos
   146  	l.startLine = l.line
   147  }
   148  
   149  func (l *schemaLexer) acceptRun(valid string) {
   150  	for strings.ContainsRune(valid, l.next()) {
   151  	}
   152  	l.backup()
   153  }
   154  
   155  func (l *schemaLexer) nextItem() item {
   156  	return <-l.items
   157  }
   158  
   159  func (l *schemaLexer) drain() {
   160  	for range l.items {
   161  	}
   162  }
   163  
   164  func lex(input string) *schemaLexer {
   165  	l := &schemaLexer{
   166  		input:     input,
   167  		items:     make(chan item),
   168  		line:      1,
   169  		startLine: 1,
   170  	}
   171  
   172  	go l.run()
   173  	return l
   174  }
   175  
   176  func (l *schemaLexer) run() {
   177  	for state := lexText; state != nil; {
   178  		state = state(l)
   179  	}
   180  	close(l.items)
   181  }
   182  
   183  func lexText(l *schemaLexer) stateFn {
   184  	switch r := l.next(); {
   185  	case r == eof:
   186  		l.emit(itemEOF)
   187  		return nil
   188  	case isSpace(r):
   189  		return lexSpace
   190  	case r == '(':
   191  		l.emit(itemLeftParen)
   192  	case r == ')':
   193  		l.emit(itemRightParen)
   194  	case r == '{':
   195  		l.emit(itemLeftBrace)
   196  	case r == '}':
   197  		l.emit(itemRightBrace)
   198  	case isDigit(r):
   199  		return lexNumber
   200  	case r == '=':
   201  		l.emit(itemEqual)
   202  	case r == ';':
   203  		l.emit(itemSemicolon)
   204  	case r == ',':
   205  		l.emit(itemComma)
   206  	default:
   207  		return lexIdentifier
   208  	}
   209  	return lexText
   210  }
   211  
   212  func isSpace(r rune) bool {
   213  	return r == ' ' || r == '\t' || r == '\n' || r == '\r'
   214  }
   215  
   216  func isDigit(r rune) bool {
   217  	return unicode.IsDigit(r)
   218  }
   219  
   220  func isSchemaDelim(r rune) bool {
   221  	return r == ' ' || r == ';' || r == '{' || r == '}' || r == '(' || r == ')' || r == '=' || r == ','
   222  }
   223  
   224  func lexSpace(l *schemaLexer) stateFn {
   225  	for isSpace(l.peek()) {
   226  		l.next()
   227  	}
   228  	l.ignore()
   229  	return lexText
   230  }
   231  
   232  func lexNumber(l *schemaLexer) stateFn {
   233  	l.acceptRun("0123456789")
   234  	l.emit(itemNumber)
   235  	return lexText
   236  }
   237  
   238  func lexIdentifier(l *schemaLexer) stateFn {
   239  loop:
   240  	for {
   241  		switch r := l.next(); {
   242  		case !isSchemaDelim(r): // the = is there to accept it as part of the identifiers being read within type annotations.
   243  			// absorb.
   244  		default:
   245  			l.backup()
   246  			word := l.input[l.start:l.pos]
   247  			switch {
   248  			case key[word] > itemKeyword:
   249  				l.emit(key[word])
   250  			default:
   251  				l.emit(itemIdentifier)
   252  			}
   253  			break loop
   254  		}
   255  	}
   256  	return lexText
   257  }
   258  
// schemaParser is a recursive-descent parser for the textual schema format.
// Its methods report errors by panicking via errorf; parse recovers such
// panics and returns them as errors.
type schemaParser struct {
	l     *schemaLexer      // token source
	token item              // current token, advanced by next
	root  *ColumnDefinition // root of the column tree built by parseMessage
}
   264  
   265  func newSchemaParser(text string) *schemaParser {
   266  	return &schemaParser{
   267  		l:    lex(text),
   268  		root: &ColumnDefinition{SchemaElement: &parquet.SchemaElement{}},
   269  	}
   270  }
   271  
   272  func (p *schemaParser) parse() (err error) {
   273  	defer p.recover(&err)
   274  
   275  	p.parseMessage()
   276  
   277  	p.next()
   278  	p.expect(itemEOF)
   279  
   280  	p.validate(p.root, false)
   281  
   282  	return nil
   283  }
   284  
   285  func (p *schemaParser) recover(errp *error) {
   286  	if e := recover(); e != nil {
   287  		if _, ok := e.(runtime.Error); ok {
   288  			panic(e)
   289  		}
   290  		p.l.drain()
   291  		*errp = e.(error)
   292  	}
   293  }
   294  
   295  func (p *schemaParser) errorf(msg string, args ...interface{}) {
   296  	msg = fmt.Sprintf("line %d: %s", p.token.line, msg)
   297  	panic(fmt.Errorf(msg, args...))
   298  }
   299  
   300  func (p *schemaParser) expect(typ itemType) {
   301  	if typ == itemIdentifier && p.token.typ > itemKeyword {
   302  		return
   303  	}
   304  
   305  	if p.token.typ != typ {
   306  		p.errorf("expected %s, got %s instead", typ, p.token)
   307  	}
   308  }
   309  
   310  func (p *schemaParser) next() {
   311  	p.token = p.l.nextItem()
   312  }
   313  
   314  func (p *schemaParser) parseMessage() {
   315  	p.next()
   316  	p.expect(itemMessage)
   317  
   318  	p.next()
   319  	p.expect(itemIdentifier)
   320  
   321  	p.root.SchemaElement.Name = p.token.val
   322  
   323  	p.next()
   324  	p.expect(itemLeftBrace)
   325  
   326  	p.root.Children = p.parseMessageBody()
   327  	for _, c := range p.root.Children {
   328  		recursiveFix(c)
   329  	}
   330  
   331  	p.expect(itemRightBrace)
   332  }
   333  
   334  func recursiveFix(col *ColumnDefinition) {
   335  	if nc := int32(len(col.Children)); nc > 0 {
   336  		col.SchemaElement.NumChildren = &nc
   337  	}
   338  
   339  	for i := range col.Children {
   340  		recursiveFix(col.Children[i])
   341  	}
   342  }
   343  
   344  func (p *schemaParser) parseMessageBody() []*ColumnDefinition {
   345  	var cols []*ColumnDefinition
   346  	p.expect(itemLeftBrace)
   347  	for {
   348  		p.next()
   349  		if p.token.typ == itemRightBrace {
   350  			return cols
   351  		}
   352  
   353  		cols = append(cols, p.parseColumnDefinition())
   354  	}
   355  }
   356  
   357  func (p *schemaParser) parseColumnDefinition() *ColumnDefinition {
   358  	col := &ColumnDefinition{
   359  		SchemaElement: &parquet.SchemaElement{},
   360  	}
   361  
   362  	switch p.token.typ {
   363  	case itemRepeated:
   364  		col.SchemaElement.RepetitionType = parquet.FieldRepetitionTypePtr(parquet.FieldRepetitionType_REPEATED)
   365  	case itemOptional:
   366  		col.SchemaElement.RepetitionType = parquet.FieldRepetitionTypePtr(parquet.FieldRepetitionType_OPTIONAL)
   367  	case itemRequired:
   368  		col.SchemaElement.RepetitionType = parquet.FieldRepetitionTypePtr(parquet.FieldRepetitionType_REQUIRED)
   369  	default:
   370  		p.errorf("invalid field repetition type %q", p.token.val)
   371  	}
   372  
   373  	p.next()
   374  
   375  	if p.token.typ == itemGroup {
   376  		p.next()
   377  		p.expect(itemIdentifier)
   378  		col.SchemaElement.Name = p.token.val
   379  
   380  		p.next()
   381  		if p.token.typ == itemLeftParen {
   382  			col.SchemaElement.ConvertedType = p.parseConvertedType()
   383  			p.next()
   384  		}
   385  
   386  		col.Children = p.parseMessageBody()
   387  
   388  		p.expect(itemRightBrace)
   389  	} else {
   390  		col.SchemaElement.Type = p.getTokenType()
   391  
   392  		if col.SchemaElement.GetType() == parquet.Type_FIXED_LEN_BYTE_ARRAY {
   393  			p.next()
   394  			p.expect(itemLeftParen)
   395  			p.next()
   396  			p.expect(itemNumber)
   397  
   398  			i, err := strconv.ParseUint(p.token.val, 10, 32)
   399  			if err != nil {
   400  				p.errorf("invalid fixed_len_byte_array length %q: %v", p.token.val, err)
   401  			}
   402  
   403  			byteArraySize := int32(i)
   404  
   405  			col.SchemaElement.TypeLength = &byteArraySize
   406  
   407  			p.next()
   408  			p.expect(itemRightParen)
   409  		}
   410  
   411  		p.next()
   412  		p.expect(itemIdentifier)
   413  		col.SchemaElement.Name = p.token.val
   414  
   415  		p.next()
   416  		if p.token.typ == itemLeftParen {
   417  			col.SchemaElement.LogicalType, col.SchemaElement.ConvertedType = p.parseLogicalOrConvertedType()
   418  			if col.SchemaElement.LogicalType != nil && col.SchemaElement.LogicalType.IsSetDECIMAL() {
   419  				col.SchemaElement.Scale = &col.SchemaElement.LogicalType.DECIMAL.Scale
   420  				col.SchemaElement.Precision = &col.SchemaElement.LogicalType.DECIMAL.Precision
   421  			}
   422  			p.next()
   423  		}
   424  
   425  		if p.token.typ == itemEqual {
   426  			col.SchemaElement.FieldID = p.parseFieldID()
   427  			p.next()
   428  		}
   429  
   430  		p.expect(itemSemicolon)
   431  	}
   432  
   433  	return col
   434  }
   435  
   436  func (p *schemaParser) isValidType(typ string) {
   437  	validTypes := []string{"binary", "float", "double", "boolean", "int32", "int64", "int96", "fixed_len_byte_array"}
   438  	for _, vt := range validTypes {
   439  		if vt == typ {
   440  			return
   441  		}
   442  	}
   443  	p.errorf("invalid type %q", typ)
   444  }
   445  
   446  func (p *schemaParser) getTokenType() *parquet.Type {
   447  	p.isValidType(p.token.val)
   448  
   449  	switch p.token.val {
   450  	case "binary":
   451  		return parquet.TypePtr(parquet.Type_BYTE_ARRAY)
   452  	case "float":
   453  		return parquet.TypePtr(parquet.Type_FLOAT)
   454  	case "double":
   455  		return parquet.TypePtr(parquet.Type_DOUBLE)
   456  	case "boolean":
   457  		return parquet.TypePtr(parquet.Type_BOOLEAN)
   458  	case "int32":
   459  		return parquet.TypePtr(parquet.Type_INT32)
   460  	case "int64":
   461  		return parquet.TypePtr(parquet.Type_INT64)
   462  	case "int96":
   463  		return parquet.TypePtr(parquet.Type_INT96)
   464  	case "fixed_len_byte_array":
   465  		return parquet.TypePtr(parquet.Type_FIXED_LEN_BYTE_ARRAY)
   466  	default:
   467  		p.errorf("unsupported type %q", p.token.val)
   468  		return nil
   469  	}
   470  }
   471  
   472  func (p *schemaParser) parseLogicalOrConvertedType() (*parquet.LogicalType, *parquet.ConvertedType) {
   473  	p.expect(itemLeftParen)
   474  	p.next()
   475  	p.expect(itemIdentifier)
   476  
   477  	typStr := p.token.val
   478  
   479  	lt := parquet.NewLogicalType()
   480  	var ct *parquet.ConvertedType
   481  
   482  	switch strings.ToUpper(typStr) {
   483  	case "STRING":
   484  		lt.STRING = parquet.NewStringType()
   485  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_UTF8)
   486  		p.next()
   487  	case "DATE":
   488  		lt.DATE = parquet.NewDateType()
   489  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_DATE)
   490  		p.next()
   491  	case "TIMESTAMP":
   492  		ct = p.parseTimestampLogicalType(lt)
   493  		p.next()
   494  	case "TIME":
   495  		ct = p.parseTimeLogicalType(lt)
   496  		p.next()
   497  	case "INT":
   498  		ct = p.parseIntLogicalType(lt)
   499  		p.next()
   500  	case "UUID":
   501  		lt.UUID = parquet.NewUUIDType()
   502  		p.next()
   503  	case "ENUM":
   504  		lt.ENUM = parquet.NewEnumType()
   505  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_ENUM)
   506  		p.next()
   507  	case "JSON":
   508  		lt.JSON = parquet.NewJsonType()
   509  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_JSON)
   510  		p.next()
   511  	case "BSON":
   512  		lt.BSON = parquet.NewBsonType()
   513  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_BSON)
   514  		p.next()
   515  	case "DECIMAL":
   516  		ct = p.parseDecimalLogicalType(lt)
   517  		// n.b. no p.next is necessary because parseDecimalLogicalType may have already seen the ) if the list of scale and precision were not there, i.e. if it was a converted type.
   518  	default:
   519  		convertedType, err := parquet.ConvertedTypeFromString(strings.ToUpper(typStr))
   520  		if err != nil {
   521  			p.errorf("unsupported logical type or converted type %q", typStr)
   522  		}
   523  		lt = nil
   524  		ct = &convertedType
   525  		p.next()
   526  	}
   527  
   528  	p.expect(itemRightParen)
   529  
   530  	return lt, ct
   531  }
   532  
   533  func (p *schemaParser) parseTimestampLogicalType(lt *parquet.LogicalType) (ct *parquet.ConvertedType) {
   534  	lt.TIMESTAMP = parquet.NewTimestampType()
   535  	p.next()
   536  	p.expect(itemLeftParen)
   537  
   538  	p.next()
   539  	p.expect(itemIdentifier)
   540  
   541  	lt.TIMESTAMP.Unit = parquet.NewTimeUnit()
   542  	switch p.token.val {
   543  	case "MILLIS":
   544  		lt.TIMESTAMP.Unit.MILLIS = parquet.NewMilliSeconds()
   545  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_TIMESTAMP_MILLIS)
   546  	case "MICROS":
   547  		lt.TIMESTAMP.Unit.MICROS = parquet.NewMicroSeconds()
   548  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_TIMESTAMP_MICROS)
   549  	case "NANOS":
   550  		lt.TIMESTAMP.Unit.NANOS = parquet.NewNanoSeconds()
   551  	default:
   552  		p.errorf("unknown unit annotation %q for TIMESTAMP", p.token.val)
   553  	}
   554  
   555  	p.next()
   556  	p.expect(itemComma)
   557  
   558  	p.next()
   559  	p.expect(itemIdentifier)
   560  
   561  	switch p.token.val {
   562  	case "true", "false":
   563  		lt.TIMESTAMP.IsAdjustedToUTC, _ = strconv.ParseBool(p.token.val)
   564  	default:
   565  		p.errorf("invalid isAdjustedToUTC annotation %q for TIMESTAMP", p.token.val)
   566  	}
   567  
   568  	p.next()
   569  	p.expect(itemRightParen)
   570  
   571  	return ct
   572  }
   573  
   574  func (p *schemaParser) parseTimeLogicalType(lt *parquet.LogicalType) (ct *parquet.ConvertedType) {
   575  	lt.TIME = parquet.NewTimeType()
   576  	p.next()
   577  	p.expect(itemLeftParen)
   578  
   579  	p.next()
   580  	p.expect(itemIdentifier)
   581  
   582  	lt.TIME.Unit = parquet.NewTimeUnit()
   583  	switch p.token.val {
   584  	case "MILLIS":
   585  		lt.TIME.Unit.MILLIS = parquet.NewMilliSeconds()
   586  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_TIME_MILLIS)
   587  	case "MICROS":
   588  		lt.TIME.Unit.MICROS = parquet.NewMicroSeconds()
   589  		ct = parquet.ConvertedTypePtr(parquet.ConvertedType_TIME_MICROS)
   590  	case "NANOS":
   591  		lt.TIME.Unit.NANOS = parquet.NewNanoSeconds()
   592  	default:
   593  		p.errorf("unknown unit annotation %q for TIME", p.token.val)
   594  	}
   595  
   596  	p.next()
   597  	p.expect(itemComma)
   598  
   599  	p.next()
   600  	p.expect(itemIdentifier)
   601  
   602  	switch p.token.val {
   603  	case "true", "false":
   604  		lt.TIME.IsAdjustedToUTC, _ = strconv.ParseBool(p.token.val)
   605  	default:
   606  		p.errorf("invalid isAdjustedToUTC annotation %q for TIME", p.token.val)
   607  	}
   608  
   609  	p.next()
   610  	p.expect(itemRightParen)
   611  
   612  	return ct
   613  }
   614  
   615  func (p *schemaParser) parseIntLogicalType(lt *parquet.LogicalType) *parquet.ConvertedType {
   616  	lt.INTEGER = parquet.NewIntType()
   617  	p.next()
   618  	p.expect(itemLeftParen)
   619  
   620  	p.next()
   621  	p.expect(itemNumber)
   622  
   623  	bitWidth, _ := strconv.ParseInt(p.token.val, 10, 64)
   624  	if bitWidth != 8 && bitWidth != 16 && bitWidth != 32 && bitWidth != 64 {
   625  		p.errorf("INT: unsupported bitwidth %d", bitWidth)
   626  	}
   627  
   628  	lt.INTEGER.BitWidth = int8(bitWidth)
   629  
   630  	p.next()
   631  	p.expect(itemComma)
   632  
   633  	p.next()
   634  	p.expect(itemIdentifier)
   635  	switch p.token.val {
   636  	case "true", "false":
   637  		lt.INTEGER.IsSigned, _ = strconv.ParseBool(p.token.val)
   638  	default:
   639  		p.errorf("invalid isSigned annotation %q for INT", p.token.val)
   640  	}
   641  
   642  	p.next()
   643  	p.expect(itemRightParen)
   644  
   645  	convertedTypeStr := fmt.Sprintf("INT_%d", bitWidth)
   646  	if !lt.INTEGER.IsSigned {
   647  		convertedTypeStr = "U" + convertedTypeStr
   648  	}
   649  
   650  	convertedType, err := parquet.ConvertedTypeFromString(convertedTypeStr)
   651  	if err != nil {
   652  		p.errorf("couldn't convert INT(%d, %t) annotation to converted type %s: %v", bitWidth, lt.INTEGER.IsSigned, convertedTypeStr, err)
   653  	}
   654  	return parquet.ConvertedTypePtr(convertedType)
   655  }
   656  
   657  func (p *schemaParser) parseDecimalLogicalType(lt *parquet.LogicalType) *parquet.ConvertedType {
   658  	ct := parquet.ConvertedTypePtr(parquet.ConvertedType_DECIMAL)
   659  	p.next()
   660  
   661  	if p.token.typ == itemRightParen { // if the next token is ), skip parsing precision and scale because we only got a converted type.
   662  		return ct
   663  	}
   664  
   665  	lt.DECIMAL = parquet.NewDecimalType()
   666  
   667  	p.expect(itemLeftParen)
   668  
   669  	p.next()
   670  	p.expect(itemNumber)
   671  
   672  	prec, _ := strconv.ParseInt(p.token.val, 10, 64)
   673  	lt.DECIMAL.Precision = int32(prec)
   674  
   675  	p.next()
   676  	p.expect(itemComma)
   677  
   678  	p.next()
   679  	p.expect(itemNumber)
   680  
   681  	scale, _ := strconv.ParseInt(p.token.val, 10, 64)
   682  	lt.DECIMAL.Scale = int32(scale)
   683  
   684  	p.next()
   685  	p.expect(itemRightParen)
   686  
   687  	p.next() // here, we're pre-loading the next token for the caller.
   688  	return ct
   689  }
   690  
   691  func (p *schemaParser) parseConvertedType() *parquet.ConvertedType {
   692  	p.expect(itemLeftParen)
   693  	p.next()
   694  	p.expect(itemIdentifier)
   695  
   696  	typStr := p.token.val
   697  
   698  	convertedType, err := parquet.ConvertedTypeFromString(typStr)
   699  	if err != nil {
   700  		p.errorf("invalid converted type %q", typStr)
   701  	}
   702  
   703  	p.next()
   704  	p.expect(itemRightParen)
   705  
   706  	return parquet.ConvertedTypePtr(convertedType)
   707  }
   708  
   709  func (p *schemaParser) parseFieldID() *int32 {
   710  	p.expect(itemEqual)
   711  	p.next()
   712  	p.expect(itemNumber)
   713  
   714  	i, err := strconv.ParseInt(p.token.val, 10, 32)
   715  	if err != nil {
   716  		p.errorf("couldn't parse field ID %q: %v", p.token.val, err)
   717  	}
   718  
   719  	i32 := int32(i)
   720  
   721  	return &i32
   722  }
   723  
   724  func (p *schemaParser) validate(col *ColumnDefinition, strictMode bool) {
   725  	if err := col.validate(true, strictMode); err != nil {
   726  		p.errorf("%v", err)
   727  	}
   728  }
   729  
   730  // Validate conducts a validation of the schema definition. This is
   731  // useful when the schema definition has been constructed programmatically
   732  // by other means than the schema parser to ensure that it is still
   733  // valid.
   734  func (sd *SchemaDefinition) Validate() error {
   735  	if sd == nil {
   736  		return errors.New("schema definition is nil")
   737  	}
   738  
   739  	return sd.RootColumn.validate(true, false)
   740  }
   741  
   742  // ValidateStrict conducts a stricter validation of the schema definition.
   743  // This includes the validation as done by Validate, but prohibits backwards-
   744  // compatible definitions of LIST and MAP.
   745  func (sd *SchemaDefinition) ValidateStrict() error {
   746  	if sd == nil {
   747  		return errors.New("schema definition is nil")
   748  	}
   749  	return sd.RootColumn.validate(true, true)
   750  }
   751  
   752  func (col *ColumnDefinition) validateColumn(isRoot, strictMode bool) error {
   753  	if col == nil {
   754  		return errors.New("column definition is nil")
   755  	}
   756  
   757  	if col.SchemaElement == nil {
   758  		return errors.New("column has no schema element")
   759  	}
   760  
   761  	if col.SchemaElement.Name == "" {
   762  		return errors.New("column has no name")
   763  	}
   764  
   765  	if !isRoot && len(col.Children) == 0 && col.SchemaElement.Type == nil {
   766  		return fmt.Errorf("field %s has neither children nor a type", col.SchemaElement.Name)
   767  	}
   768  
   769  	if col.SchemaElement.Type != nil && len(col.Children) > 0 {
   770  		return fmt.Errorf("field %s has a type but also children", col.SchemaElement.Name)
   771  	}
   772  
   773  	return nil
   774  }
   775  
   776  func (col *ColumnDefinition) validateListLogicalType(strictMode bool) error {
   777  	if col.SchemaElement.Type != nil {
   778  		return fmt.Errorf("field %s is not a group but annotated as LIST", col.SchemaElement.Name)
   779  	}
   780  	if rep := col.SchemaElement.GetRepetitionType(); rep != parquet.FieldRepetitionType_OPTIONAL && rep != parquet.FieldRepetitionType_REQUIRED {
   781  		return fmt.Errorf("field %s is a LIST but has repetition type %s", col.SchemaElement.Name, rep)
   782  	}
   783  	if len(col.Children) != 1 {
   784  		return fmt.Errorf("field %s is a LIST but has %d children", col.SchemaElement.Name, len(col.Children))
   785  	}
   786  	if col.Children[0].SchemaElement.Name != "list" {
   787  		if strictMode {
   788  			return fmt.Errorf("field %s is a LIST but its child is not named \"list\"", col.SchemaElement.Name)
   789  		}
   790  
   791  		if col.Children[0].SchemaElement.Type != nil {
   792  			// backwards compatibility rule 1: repeated field is not a group, its type is the element type and elements are required.
   793  		} else {
   794  			repeatedGroup := col.Children[0]
   795  			switch len(repeatedGroup.Children) {
   796  			case 0:
   797  				return fmt.Errorf("field %s is a LIST but the repeated group inside it is not called \"list\" and contains no fields", col.SchemaElement.Name)
   798  			case 1:
   799  				// if col.Children[0].SchemaElement.Name == "array" or
   800  				//	col.Children[0].SchemaElement.Name == col.SchemaElement.Name+"_tuple" or
   801  				//	col.Children[0].SchemaElement.Name == "bag":
   802  				// backwards compatibility rule 3: repeated field is a group with one field and is named either array or uses the LIST-annotated
   803  				// group's name with _tuple appended then the repeated type is the element type and elements are required.
   804  				// also added "bag" because that's what we see generated on AWS Athena.
   805  				// else: backwards compatibility rule 4: the repeated field's type is the element type with the repeated field's repetition.
   806  			default:
   807  				// backwards compatbility rule 2: repeated field is a group with multiple fields, its type is the element type and elements are required.
   808  			}
   809  		}
   810  	} else {
   811  		if col.Children[0].SchemaElement.Type != nil || col.Children[0].SchemaElement.GetRepetitionType() != parquet.FieldRepetitionType_REPEATED {
   812  			return fmt.Errorf("field %s is a LIST but its child is not a repeated group", col.SchemaElement.Name)
   813  		}
   814  		if len(col.Children[0].Children) != 1 {
   815  			return fmt.Errorf("field %s.list has %d children", col.SchemaElement.Name, len(col.Children[0].Children))
   816  		}
   817  		if col.Children[0].Children[0].SchemaElement.Name != "element" {
   818  			return fmt.Errorf("%s.list has a child but it's called %q, not \"element\"", col.SchemaElement.Name, col.Children[0].Children[0].SchemaElement.Name)
   819  		}
   820  		if rep := col.Children[0].Children[0].SchemaElement.GetRepetitionType(); rep != parquet.FieldRepetitionType_OPTIONAL && rep != parquet.FieldRepetitionType_REQUIRED {
   821  			return fmt.Errorf("%s.list.element has disallowed repetition type %s", col.SchemaElement.Name, rep)
   822  		}
   823  	}
   824  
   825  	for _, c := range col.Children[0].Children {
   826  		if err := c.validate(false, strictMode); err != nil {
   827  			return err
   828  		}
   829  	}
   830  
   831  	return nil
   832  }
   833  
   834  func (col *ColumnDefinition) validateMapLogicalType(strictMode bool) error {
   835  	if col.SchemaElement.GetConvertedType() == parquet.ConvertedType_MAP_KEY_VALUE {
   836  		if strictMode {
   837  			return fmt.Errorf("field %s is incorrectly annotated as MAP_KEY_VALUE", col.SchemaElement.Name)
   838  		}
   839  	}
   840  
   841  	if col.SchemaElement.Type != nil {
   842  		return fmt.Errorf("field %s is not a group but annotated as MAP", col.SchemaElement.Name)
   843  	}
   844  	if len(col.Children) != 1 {
   845  		return fmt.Errorf("field %s is a MAP but has %d children", col.SchemaElement.Name, len(col.Children))
   846  	}
   847  	if col.Children[0].SchemaElement.Type != nil || col.Children[0].SchemaElement.GetRepetitionType() != parquet.FieldRepetitionType_REPEATED {
   848  		return fmt.Errorf("filed %s is a MAP but its child is not a repeated group", col.SchemaElement.Name)
   849  	}
   850  	if strictMode && col.Children[0].SchemaElement.Name != "key_value" {
   851  		return fmt.Errorf("field %s is a MAP but its child is not named \"key_value\"", col.SchemaElement.Name)
   852  	}
   853  
   854  	if strictMode {
   855  		foundKey := false
   856  		foundValue := false
   857  		for _, c := range col.Children[0].Children {
   858  			switch c.SchemaElement.Name {
   859  			case "key":
   860  				if c.SchemaElement.GetRepetitionType() != parquet.FieldRepetitionType_REQUIRED {
   861  					return fmt.Errorf("field %s.key_value.key is not of repetition type \"required\"", col.SchemaElement.Name)
   862  				}
   863  				foundKey = true
   864  			case "value":
   865  				foundValue = true
   866  				// nothing else to check.
   867  			default:
   868  				return fmt.Errorf("field %[1]s is a MAP so %[1]s.key_value.%[2]s is not allowed", col.SchemaElement.Name, c.SchemaElement.Name)
   869  			}
   870  		}
   871  		if !foundKey {
   872  			return fmt.Errorf("field %[1]s is missing %[1]s.key_value.key", col.SchemaElement.Name)
   873  		}
   874  		if !foundValue {
   875  			return fmt.Errorf("field %[1]s is missing %[1]s.key_value.value", col.SchemaElement.Name)
   876  		}
   877  	} else {
   878  		if len(col.Children[0].Children) != 2 {
   879  			return fmt.Errorf("field %[1]s is a MAP but %[1]s.%[2]s contains %[3]d children (expected 2)", col.SchemaElement.Name, col.Children[0].SchemaElement.Name, len(col.Children[0].Children))
   880  		}
   881  	}
   882  
   883  	for _, c := range col.Children[0].Children {
   884  		if err := c.validate(false, strictMode); err != nil {
   885  			return err
   886  		}
   887  	}
   888  
   889  	return nil
   890  }
   891  
   892  func (col *ColumnDefinition) validateTimeLogicalType() error {
   893  	t := col.SchemaElement.GetLogicalType().TIME
   894  	switch {
   895  	case t.Unit.IsSetNANOS():
   896  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
   897  			return fmt.Errorf("field %s is annotated as TIME(NANOS, %t) but is not an int64", col.SchemaElement.Name, t.IsAdjustedToUTC)
   898  		}
   899  	case t.Unit.IsSetMICROS():
   900  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
   901  			return fmt.Errorf("field %s is annotated as TIME(MICROS, %t) but is not an int64", col.SchemaElement.Name, t.IsAdjustedToUTC)
   902  		}
   903  	case t.Unit.IsSetMILLIS():
   904  		if col.SchemaElement.GetType() != parquet.Type_INT32 {
   905  			return fmt.Errorf("field %s is annotated as TIME(MILLIS, %t) but is not an int32", col.SchemaElement.Name, t.IsAdjustedToUTC)
   906  		}
   907  	}
   908  	return nil
   909  }
   910  
   911  func (col *ColumnDefinition) validateDecimalLogicalType() error {
   912  	dec := col.SchemaElement.GetLogicalType().DECIMAL
   913  	switch col.SchemaElement.GetType() {
   914  	case parquet.Type_INT32:
   915  		if dec.Precision < 1 || dec.Precision > 9 {
   916  			return fmt.Errorf("field %s is int32 and annotated as DECIMAL but precision %d is out of bounds; needs to be 1 <= precision <= 9", col.SchemaElement.Name, dec.Precision)
   917  		}
   918  	case parquet.Type_INT64:
   919  		if dec.Precision < 1 || dec.Precision > 18 {
   920  			return fmt.Errorf("field %s is int64 and annotated as DECIMAL but precision %d is out of bounds; needs to be 1 <= precision <= 18", col.SchemaElement.Name, dec.Precision)
   921  		}
   922  	case parquet.Type_FIXED_LEN_BYTE_ARRAY:
   923  		n := *col.SchemaElement.TypeLength
   924  		maxDigits := int32(math.Floor(math.Log10(math.Exp2(8*float64(n)-1) - 1)))
   925  		if dec.Precision < 1 || dec.Precision > maxDigits {
   926  			return fmt.Errorf("field %s is fixed_len_byte_array(%d) and annotated as DECIMAL but precision %d is out of bounds; needs to be 0 <= precision <= %d", col.SchemaElement.Name, n, dec.Precision, maxDigits)
   927  		}
   928  	case parquet.Type_BYTE_ARRAY:
   929  		if dec.Precision < 1 {
   930  			return fmt.Errorf("field %s is int64 and annotated as DECIMAL but precision %d is out of bounds; needs to be 1 <= precision", col.SchemaElement.Name, dec.Precision)
   931  		}
   932  	default:
   933  		return fmt.Errorf("field %s is annotated as DECIMAL but type %s is unsupported", col.SchemaElement.Name, col.SchemaElement.GetType().String())
   934  	}
   935  	return nil
   936  }
   937  
   938  func (col *ColumnDefinition) validateIntegerLogicalType() error {
   939  	bitWidth := col.SchemaElement.LogicalType.INTEGER.BitWidth
   940  	isSigned := col.SchemaElement.LogicalType.INTEGER.IsSigned
   941  	switch bitWidth {
   942  	case 8, 16, 32:
   943  		if col.SchemaElement.GetType() != parquet.Type_INT32 {
   944  			return fmt.Errorf("field %s is annotated as INT(%d, %t) but element type is %s", col.SchemaElement.Name, bitWidth, isSigned, col.SchemaElement.GetType().String())
   945  		}
   946  	case 64:
   947  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
   948  			return fmt.Errorf("field %s is annotated as INT(%d, %t) but element type is %s", col.SchemaElement.Name, bitWidth, isSigned, col.SchemaElement.GetType().String())
   949  		}
   950  	default:
   951  		return fmt.Errorf("invalid bitWidth %d", bitWidth)
   952  	}
   953  	return nil
   954  }
   955  
   956  func (col *ColumnDefinition) validate(isRoot bool, strictMode bool) error {
   957  	if err := col.validateColumn(isRoot, strictMode); err != nil {
   958  		return err
   959  	}
   960  
   961  	switch {
   962  	case (col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetLIST()) || col.SchemaElement.GetConvertedType() == parquet.ConvertedType_LIST:
   963  		if err := col.validateListLogicalType(strictMode); err != nil {
   964  			return err
   965  		}
   966  	case (col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetMAP()) || col.SchemaElement.GetConvertedType() == parquet.ConvertedType_MAP || col.SchemaElement.GetConvertedType() == parquet.ConvertedType_MAP_KEY_VALUE:
   967  		if err := col.validateMapLogicalType(strictMode); err != nil {
   968  			return err
   969  		}
   970  	case (col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetDATE()) || col.SchemaElement.GetConvertedType() == parquet.ConvertedType_DATE:
   971  		if col.SchemaElement.GetType() != parquet.Type_INT32 {
   972  			return fmt.Errorf("field %[1]s is annotated as DATE but is not an int32", col.SchemaElement.Name)
   973  		}
   974  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetTIMESTAMP():
   975  		if col.SchemaElement.GetType() != parquet.Type_INT64 && col.SchemaElement.GetType() != parquet.Type_INT96 {
   976  			return fmt.Errorf("field %s is annotated as TIMESTAMP but is not an int64/int96", col.SchemaElement.Name)
   977  		}
   978  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetTIME():
   979  		if err := col.validateTimeLogicalType(); err != nil {
   980  			return err
   981  		}
   982  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetUUID():
   983  		if col.SchemaElement.GetType() != parquet.Type_FIXED_LEN_BYTE_ARRAY || col.SchemaElement.GetTypeLength() != 16 {
   984  			return fmt.Errorf("field %s is annotated as UUID but is not a fixed_len_byte_array(16)", col.SchemaElement.Name)
   985  		}
   986  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetENUM():
   987  		if col.SchemaElement.GetType() != parquet.Type_BYTE_ARRAY {
   988  			return fmt.Errorf("field %s is annotated as ENUM but is not a binary", col.SchemaElement.Name)
   989  		}
   990  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetJSON():
   991  		if col.SchemaElement.GetType() != parquet.Type_BYTE_ARRAY {
   992  			return fmt.Errorf("field %s is annotated as JSON but is not a binary", col.SchemaElement.Name)
   993  		}
   994  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetBSON():
   995  		if col.SchemaElement.GetType() != parquet.Type_BYTE_ARRAY {
   996  			return fmt.Errorf("field %s is annotated as BSON but is not a binary", col.SchemaElement.Name)
   997  		}
   998  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetDECIMAL():
   999  		if err := col.validateDecimalLogicalType(); err != nil {
  1000  			return err
  1001  		}
  1002  	case col.SchemaElement.LogicalType != nil && col.SchemaElement.GetLogicalType().IsSetINTEGER():
  1003  		if err := col.validateIntegerLogicalType(); err != nil {
  1004  			return err
  1005  		}
  1006  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_UTF8:
  1007  		if col.SchemaElement.GetType() != parquet.Type_BYTE_ARRAY {
  1008  			return fmt.Errorf("field %s is annotated as UTF8 but element type is %s, not binary", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1009  		}
  1010  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_TIME_MILLIS:
  1011  		if col.SchemaElement.GetType() != parquet.Type_INT32 {
  1012  			return fmt.Errorf("field %s is annotated as TIME_MILLIS but element type is %s, not int32", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1013  		}
  1014  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_TIME_MICROS:
  1015  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
  1016  			return fmt.Errorf("field %s is annotated as TIME_MICROS but element type is %s, not int64", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1017  		}
  1018  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_TIMESTAMP_MILLIS:
  1019  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
  1020  			return fmt.Errorf("field %s is annotated as TIMESTAMP_MILLIS but element type is %s, not int64", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1021  		}
  1022  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_TIMESTAMP_MICROS:
  1023  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
  1024  			return fmt.Errorf("field %s is annotated as TIMESTAMP_MICROS but element type is %s, not int64", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1025  		}
  1026  	case col.SchemaElement.ConvertedType != nil &&
  1027  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_UINT_8 ||
  1028  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_UINT_16 ||
  1029  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_UINT_32 ||
  1030  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_INT_8 ||
  1031  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_INT_16 ||
  1032  		col.SchemaElement.GetConvertedType() == parquet.ConvertedType_INT_32:
  1033  		if col.SchemaElement.GetType() != parquet.Type_INT32 {
  1034  			return fmt.Errorf("field %s is annotated as %s but element type is %s, not int32", col.SchemaElement.Name, col.SchemaElement.GetConvertedType().String(), col.SchemaElement.GetType().String())
  1035  		}
  1036  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_UINT_64 || col.SchemaElement.GetConvertedType() == parquet.ConvertedType_INT_64:
  1037  		if col.SchemaElement.GetType() != parquet.Type_INT64 {
  1038  			return fmt.Errorf("field %s is annotated as %s but element type is %s, not int64", col.SchemaElement.Name, col.SchemaElement.GetConvertedType().String(), col.SchemaElement.GetType().String())
  1039  		}
  1040  	case col.SchemaElement.ConvertedType != nil && col.SchemaElement.GetConvertedType() == parquet.ConvertedType_INTERVAL:
  1041  		if col.SchemaElement.GetType() != parquet.Type_FIXED_LEN_BYTE_ARRAY || col.SchemaElement.GetTypeLength() != 12 {
  1042  			return fmt.Errorf("field %s is annotated as INTERVAL but element type is %s, not fixed_len_byte_array(12)", col.SchemaElement.Name, col.SchemaElement.GetType().String())
  1043  		}
  1044  	default:
  1045  		for _, c := range col.Children {
  1046  			if err := c.validate(false, strictMode); err != nil {
  1047  				return err
  1048  			}
  1049  		}
  1050  	}
  1051  
  1052  	return nil
  1053  }