github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/goyacc/format_yacc.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package main
    15  
import (
	"bufio"
	"fmt"
	gofmt "go/format"
	"go/token"
	"os"
	"reflect"
	"regexp"
	"strings"

	"github.com/cznic/strutil"
	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/parser/format"
	parser "modernc.org/parser/yacc"
)
    30  
    31  func Format(inputFilename string, goldenFilename string) (err error) {
    32  	spec, err := parseFileToSpec(inputFilename)
    33  	if err != nil {
    34  		return err
    35  	}
    36  
    37  	yFmt := &OutputFormatter{}
    38  	if err = yFmt.Setup(goldenFilename); err != nil {
    39  		return err
    40  	}
    41  	defer func() {
    42  		teardownErr := yFmt.Teardown()
    43  		if err == nil {
    44  			err = teardownErr
    45  		}
    46  	}()
    47  
    48  	if err = printDefinitions(yFmt, spec.Defs); err != nil {
    49  		return err
    50  	}
    51  
    52  	return printRules(yFmt, spec.Rules)
    53  }
    54  
    55  func parseFileToSpec(inputFilename string) (*parser.Specification, error) {
    56  	src, err := os.ReadFile(inputFilename)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	return parser.Parse(token.NewFileSet(), inputFilename, src)
    61  }
    62  
    63  // Definition represents data reduced by productions:
    64  //
    65  //	Definition:
    66  //	        START IDENTIFIER
    67  //	|       UNION                      // Case 1
    68  //	|       LCURL RCURL                // Case 2
    69  //	|       ReservedWord Tag NameList  // Case 3
    70  //	|       ReservedWord Tag           // Case 4
    71  //	|       ERROR_VERBOSE              // Case 5
    72  const (
    73  	StartIdentifierCase = iota
    74  	UnionDefinitionCase
    75  	LCURLRCURLCase
    76  	ReservedWordTagNameListCase
    77  	ReservedWordTagCase
    78  )
    79  
    80  func printDefinitions(formatter format.Formatter, definitions []*parser.Definition) error {
    81  	for _, def := range definitions {
    82  		var err error
    83  		switch def.Case {
    84  		case StartIdentifierCase:
    85  			err = handleStart(formatter, def)
    86  		case UnionDefinitionCase:
    87  			err = handleUnion(formatter, def)
    88  		case LCURLRCURLCase:
    89  			err = handleProlog(formatter, def)
    90  		case ReservedWordTagNameListCase, ReservedWordTagCase:
    91  			err = handleReservedWordTagNameList(formatter, def)
    92  		}
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  	_, err := formatter.Format("\n%%%%")
    98  	return err
    99  }
   100  
   101  func handleStart(f format.Formatter, definition *parser.Definition) error {
   102  	if err := Ensure(definition).
   103  		and(definition.Token2).
   104  		and(definition.Token2).NotNil(); err != nil {
   105  		return err
   106  	}
   107  	cmt1 := strings.Join(definition.Token.Comments, "\n")
   108  	cmt2 := strings.Join(definition.Token2.Comments, "\n")
   109  	_, err := f.Format("\n%s%s\t%s%s\n", cmt1, definition.Token.Val, cmt2, definition.Token2.Val)
   110  	return err
   111  }
   112  
   113  func handleUnion(f format.Formatter, definition *parser.Definition) error {
   114  	if err := Ensure(definition).
   115  		and(definition.Value).NotNil(); err != nil {
   116  		return err
   117  	}
   118  	if len(definition.Value) != 0 {
   119  		_, err := f.Format("%%union%i%s%u\n\n", definition.Value)
   120  		if err != nil {
   121  			return err
   122  		}
   123  	}
   124  	return nil
   125  }
   126  
   127  func handleProlog(f format.Formatter, definition *parser.Definition) error {
   128  	if err := Ensure(definition).
   129  		and(definition.Value).NotNil(); err != nil {
   130  		return err
   131  	}
   132  	_, err := f.Format("%%{%s%%}\n\n", definition.Value)
   133  	return err
   134  }
   135  
   136  func handleReservedWordTagNameList(f format.Formatter, def *parser.Definition) error {
   137  	if err := Ensure(def).
   138  		and(def.ReservedWord).
   139  		and(def.ReservedWord.Token).NotNil(); err != nil {
   140  		return err
   141  	}
   142  	comment := getTokenComment(def.ReservedWord.Token, divNewLineStringLayout)
   143  	directive := def.ReservedWord.Token.Val
   144  
   145  	hasTag := def.Tag != nil
   146  	var wordAfterDirective string
   147  	if hasTag {
   148  		wordAfterDirective = joinTag(def.Tag)
   149  	} else {
   150  		wordAfterDirective = joinNames(def.Nlist)
   151  	}
   152  
   153  	if _, err := f.Format("%s%s%s%i", comment, directive, wordAfterDirective); err != nil {
   154  		return err
   155  	}
   156  	if hasTag {
   157  		if _, err := f.Format("\n"); err != nil {
   158  			return err
   159  		}
   160  		if err := printNameListVertical(f, def.Nlist); err != nil {
   161  			return err
   162  		}
   163  	}
   164  	_, err := f.Format("%u\n")
   165  	return err
   166  }
   167  
   168  func joinTag(tag *parser.Tag) string {
   169  	var sb strings.Builder
   170  	sb.WriteString("\t")
   171  	if tag.Token != nil {
   172  		sb.WriteString(tag.Token.Val)
   173  	}
   174  	if tag.Token2 != nil {
   175  		sb.WriteString(tag.Token2.Val)
   176  	}
   177  	if tag.Token3 != nil {
   178  		sb.WriteString(tag.Token3.Val)
   179  	}
   180  	return sb.String()
   181  }
   182  
   183  type stringLayout int8
   184  
   185  const (
   186  	spanStringLayout stringLayout = iota
   187  	divStringLayout
   188  	divNewLineStringLayout
   189  )
   190  
   191  func getTokenComment(token *parser.Token, layout stringLayout) string {
   192  	if len(token.Comments) == 0 {
   193  		return ""
   194  	}
   195  	var splitter, beforeComment string
   196  	switch layout {
   197  	case spanStringLayout:
   198  		splitter, beforeComment = " ", ""
   199  	case divStringLayout:
   200  		splitter, beforeComment = "\n", ""
   201  	case divNewLineStringLayout:
   202  		splitter, beforeComment = "\n", "\n"
   203  	default:
   204  		panic(errors.Errorf("unsupported stringLayout: %v", layout))
   205  	}
   206  
   207  	var sb strings.Builder
   208  	sb.WriteString(beforeComment)
   209  	for _, comment := range token.Comments {
   210  		sb.WriteString(comment)
   211  		sb.WriteString(splitter)
   212  	}
   213  	return sb.String()
   214  }
   215  
   216  func printNameListVertical(f format.Formatter, names NameArr) (err error) {
   217  	rest := names
   218  	for len(rest) != 0 {
   219  		var processing NameArr
   220  		processing, rest = rest[:1], rest[1:]
   221  
   222  		var noComments NameArr
   223  		noComments, rest = rest.span(noComment)
   224  		processing = append(processing, noComments...)
   225  
   226  		maxCharLength := processing.findMaxLength()
   227  		for _, name := range processing {
   228  			if err := printSingleName(f, name, maxCharLength); err != nil {
   229  				return err
   230  			}
   231  		}
   232  	}
   233  	return nil
   234  }
   235  
   236  func joinNames(names NameArr) string {
   237  	var sb strings.Builder
   238  	for _, name := range names {
   239  		sb.WriteString(" ")
   240  		sb.WriteString(getTokenComment(name.Token, spanStringLayout))
   241  		sb.WriteString(name.Token.Val)
   242  	}
   243  	return sb.String()
   244  }
   245  
   246  func printSingleName(f format.Formatter, name *parser.Name, maxCharLength int) error {
   247  	cmt := getTokenComment(name.Token, divNewLineStringLayout)
   248  	if _, err := f.Format(escapePercent(cmt)); err != nil {
   249  		return err
   250  	}
   251  	strLit := name.LiteralStringOpt
   252  	if strLit != nil && strLit.Token != nil {
   253  		_, err := f.Format("%-*s %s\n", maxCharLength, name.Token.Val, strLit.Token.Val)
   254  		return err
   255  	}
   256  	_, err := f.Format("%s\n", name.Token.Val)
   257  	return err
   258  }
   259  
   260  type NameArr []*parser.Name
   261  
   262  func (ns NameArr) span(pred func(*parser.Name) bool) (first NameArr, second NameArr) {
   263  	first = ns.takeWhile(pred)
   264  	second = ns[len(first):]
   265  	return first, second
   266  }
   267  
   268  func (ns NameArr) takeWhile(pred func(*parser.Name) bool) NameArr {
   269  	for i, def := range ns {
   270  		if pred(def) {
   271  			continue
   272  		}
   273  		return ns[:i]
   274  	}
   275  	return ns
   276  }
   277  
   278  func (ns NameArr) findMaxLength() int {
   279  	maxLen := -1
   280  	for _, s := range ns {
   281  		if len(s.Token.Val) > maxLen {
   282  			maxLen = len(s.Token.Val)
   283  		}
   284  	}
   285  	return maxLen
   286  }
   287  
   288  func hasComments(n *parser.Name) bool {
   289  	return len(n.Token.Comments) != 0
   290  }
   291  
   292  func noComment(n *parser.Name) bool {
   293  	return !hasComments(n)
   294  }
   295  
   296  func containsActionInRule(rule *parser.Rule) bool {
   297  	for _, b := range rule.Body {
   298  		if _, ok := b.(*parser.Action); ok {
   299  			return true
   300  		}
   301  	}
   302  	return false
   303  }
   304  
   305  type RuleArr []*parser.Rule
   306  
   307  func printRules(f format.Formatter, rules RuleArr) (err error) {
   308  	var lastRuleName string
   309  	for _, rule := range rules {
   310  		if rule.Name.Val == lastRuleName {
   311  			cmt := getTokenComment(rule.Token, divStringLayout)
   312  			_, err = f.Format("\n%s|\t%i", cmt)
   313  		} else {
   314  			cmt := getTokenComment(rule.Name, divStringLayout)
   315  			_, err = f.Format("\n\n%s%s:%i\n", cmt, rule.Name.Val)
   316  		}
   317  		if err != nil {
   318  			return err
   319  		}
   320  		lastRuleName = rule.Name.Val
   321  
   322  		if err = printRuleBody(f, rule); err != nil {
   323  			return err
   324  		}
   325  		if _, err = f.Format("%u"); err != nil {
   326  			return err
   327  		}
   328  	}
   329  	_, err = f.Format("\n%%%%\n")
   330  	return err
   331  }
   332  
   333  type ruleItemType int8
   334  
   335  const (
   336  	identRuleItemType      ruleItemType = 1
   337  	actionRuleItemType     ruleItemType = 2
   338  	strLiteralRuleItemType ruleItemType = 3
   339  )
   340  
   341  func printRuleBody(f format.Formatter, rule *parser.Rule) error {
   342  	firstRuleItem, counter := rule.RuleItemList, 0
   343  	for ri := rule.RuleItemList; ri != nil; ri = ri.RuleItemList {
   344  		switch ruleItemType(ri.Case) {
   345  		case identRuleItemType, strLiteralRuleItemType:
   346  			term := fmt.Sprintf(" %s", ri.Token.Val)
   347  			if ri == firstRuleItem {
   348  				term = term[1:]
   349  			}
   350  			cmt := getTokenComment(ri.Token, divStringLayout)
   351  
   352  			if _, err := f.Format(escapePercent(cmt)); err != nil {
   353  				return err
   354  			}
   355  			if _, err := f.Format("%s", term); err != nil {
   356  				return err
   357  			}
   358  		case actionRuleItemType:
   359  			isFirstRuleItem := ri == firstRuleItem
   360  			if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil {
   361  				return err
   362  			}
   363  			if err := handleAction(f, rule, ri.Action, isFirstRuleItem); err != nil {
   364  				return err
   365  			}
   366  		}
   367  		counter++
   368  	}
   369  	if err := checkInconsistencyInYaccParser(f, rule, counter); err != nil {
   370  		return err
   371  	}
   372  	if !containsActionInRule(rule) {
   373  		if err := handlePrecedence(f, rule.Precedence, counter == 0); err != nil {
   374  			return err
   375  		}
   376  	}
   377  	return nil
   378  }
   379  
   380  func handleAction(f format.Formatter, rule *parser.Rule, action *parser.Action, isFirstItem bool) error {
   381  	if !isFirstItem || rule.Precedence != nil {
   382  		if _, err := f.Format("\n"); err != nil {
   383  			return err
   384  		}
   385  	}
   386  
   387  	cmt := getTokenComment(action.Token, divStringLayout)
   388  	if _, err := f.Format(escapePercent(cmt)); err != nil {
   389  		return err
   390  	}
   391  
   392  	goSnippet, err := formatGoSnippet(action.Values)
   393  	goSnippet = escapePercent(goSnippet)
   394  	if err != nil {
   395  		return err
   396  	}
   397  	snippet := "{}"
   398  	if len(goSnippet) != 0 {
   399  		snippet = fmt.Sprintf("{%%i\n%s%%u\n}", goSnippet)
   400  	}
   401  	_, err = f.Format(snippet)
   402  	return err
   403  }
   404  
   405  func handlePrecedence(f format.Formatter, p *parser.Precedence, isFirstItem bool) error {
   406  	if p == nil {
   407  		return nil
   408  	}
   409  	if err := Ensure(p.Token).
   410  		and(p.Token2).NotNil(); err != nil {
   411  		return err
   412  	}
   413  	cmt := getTokenComment(p.Token, spanStringLayout)
   414  	if !isFirstItem {
   415  		if _, err := f.Format(" "); err != nil {
   416  			return err
   417  		}
   418  	}
   419  	_, err := f.Format("%s%s %s", cmt, p.Token.Val, p.Token2.Val)
   420  	return err
   421  }
   422  
   423  func formatGoSnippet(actVal []*parser.ActionValue) (string, error) {
   424  	tran := &SpecialActionValTransformer{
   425  		store: map[string]string{},
   426  	}
   427  	goSnippet := collectGoSnippet(tran, actVal)
   428  	formatted, err := gofmt.Source([]byte(goSnippet))
   429  	if err != nil {
   430  		return "", err
   431  	}
   432  	formattedSnippet := tran.restore(string(formatted))
   433  	return strings.TrimSpace(formattedSnippet), nil
   434  }
   435  
   436  func collectGoSnippet(tran *SpecialActionValTransformer, actionValArr []*parser.ActionValue) string {
   437  	var sb strings.Builder
   438  	for _, value := range actionValArr {
   439  		trimTab := removeLineBeginBlanks(value.Src)
   440  		sb.WriteString(tran.transform(trimTab))
   441  	}
   442  	snipWithPar := strings.TrimSpace(sb.String())
   443  	if strings.HasPrefix(snipWithPar, "{") && strings.HasSuffix(snipWithPar, "}") {
   444  		return snipWithPar[1 : len(snipWithPar)-1]
   445  	}
   446  	return ""
   447  }
   448  
   449  var lineBeginBlankRegex = regexp.MustCompile("(?m)^[\t ]+")
   450  
   451  func removeLineBeginBlanks(src string) string {
   452  	return lineBeginBlankRegex.ReplaceAllString(src, "")
   453  }
   454  
   455  type SpecialActionValTransformer struct {
   456  	store map[string]string
   457  }
   458  
   459  const yaccFmtVar = "_yaccfmt_var_"
   460  
   461  var yaccFmtVarRegex = regexp.MustCompile("_yaccfmt_var_[0-9]{1,5}")
   462  
   463  func (s *SpecialActionValTransformer) transform(val string) string {
   464  	if strings.HasPrefix(val, "$") {
   465  		generated := fmt.Sprintf("%s%d", yaccFmtVar, len(s.store))
   466  		s.store[generated] = val
   467  		return generated
   468  	}
   469  	return val
   470  }
   471  
   472  func (s *SpecialActionValTransformer) restore(src string) string {
   473  	return yaccFmtVarRegex.ReplaceAllStringFunc(src, func(matched string) string {
   474  		origin, ok := s.store[matched]
   475  		if !ok {
   476  			panic(errors.Errorf("mismatch in SpecialActionValTransformer"))
   477  		}
   478  		return origin
   479  	})
   480  }
   481  
   482  type OutputFormatter struct {
   483  	file      *os.File
   484  	out       *bufio.Writer
   485  	formatter strutil.Formatter
   486  }
   487  
   488  func (y *OutputFormatter) Setup(filename string) (err error) {
   489  	if y.file, err = os.Create(filename); err != nil {
   490  		return
   491  	}
   492  	y.out = bufio.NewWriter(y.file)
   493  	y.formatter = strutil.IndentFormatter(y.out, "\t")
   494  	return
   495  }
   496  
   497  func (y *OutputFormatter) Teardown() error {
   498  	if y.out != nil {
   499  		if err := y.out.Flush(); err != nil {
   500  			return err
   501  		}
   502  	}
   503  	if y.file != nil {
   504  		if err := y.file.Close(); err != nil {
   505  			return err
   506  		}
   507  	}
   508  	return nil
   509  }
   510  
   511  func (y *OutputFormatter) Format(format string, args ...interface{}) (int, error) {
   512  	return y.formatter.Format(format, args...)
   513  }
   514  
   515  func (y *OutputFormatter) Write(bytes []byte) (int, error) {
   516  	return y.formatter.Write(bytes)
   517  }
   518  
   519  type NotNilAssert struct {
   520  	idx int
   521  	err error
   522  }
   523  
   524  func (n *NotNilAssert) and(target interface{}) *NotNilAssert {
   525  	if n.err != nil {
   526  		return n
   527  	}
   528  	if target == nil {
   529  		n.err = errors.Errorf("encounter nil, index: %d", n.idx)
   530  	}
   531  	n.idx++
   532  	return n
   533  }
   534  
   535  func (n *NotNilAssert) NotNil() error {
   536  	return n.err
   537  }
   538  
   539  func Ensure(target interface{}) *NotNilAssert {
   540  	return (&NotNilAssert{}).and(target)
   541  }
   542  
   543  func escapePercent(src string) string {
   544  	return strings.ReplaceAll(src, "%", "%%")
   545  }
   546  
   547  func checkInconsistencyInYaccParser(f format.Formatter, rule *parser.Rule, counter int) error {
   548  	if counter == len(rule.Body) {
   549  		return nil
   550  	}
   551  	// pickup rule item in ruleBody
   552  	for i := counter; i < len(rule.Body); i++ {
   553  		body := rule.Body[i]
   554  		switch b := body.(type) {
   555  		case string, int:
   556  			if bInt, ok := b.(int); ok {
   557  				b = fmt.Sprintf("'%c'", bInt)
   558  			}
   559  			term := fmt.Sprintf(" %s", b)
   560  			if i == 0 {
   561  				term = term[1:]
   562  			}
   563  			_, err := f.Format("%s", term)
   564  			return err
   565  		case *parser.Action:
   566  			isFirstRuleItem := i == 0
   567  			if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil {
   568  				return err
   569  			}
   570  			if err := handleAction(f, rule, b, isFirstRuleItem); err != nil {
   571  				return err
   572  			}
   573  		}
   574  	}
   575  	return nil
   576  }