github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/goyacc/format_yacc.go (about) 1 // Copyright 2019 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package main 15 16 import ( 17 "bufio" 18 "fmt" 19 gofmt "go/format" 20 "go/token" 21 "os" 22 "regexp" 23 "strings" 24 25 "github.com/cznic/strutil" 26 "github.com/pingcap/errors" 27 "github.com/pingcap/tidb/parser/format" 28 parser "modernc.org/parser/yacc" 29 ) 30 31 func Format(inputFilename string, goldenFilename string) (err error) { 32 spec, err := parseFileToSpec(inputFilename) 33 if err != nil { 34 return err 35 } 36 37 yFmt := &OutputFormatter{} 38 if err = yFmt.Setup(goldenFilename); err != nil { 39 return err 40 } 41 defer func() { 42 teardownErr := yFmt.Teardown() 43 if err == nil { 44 err = teardownErr 45 } 46 }() 47 48 if err = printDefinitions(yFmt, spec.Defs); err != nil { 49 return err 50 } 51 52 return printRules(yFmt, spec.Rules) 53 } 54 55 func parseFileToSpec(inputFilename string) (*parser.Specification, error) { 56 src, err := os.ReadFile(inputFilename) 57 if err != nil { 58 return nil, err 59 } 60 return parser.Parse(token.NewFileSet(), inputFilename, src) 61 } 62 63 // Definition represents data reduced by productions: 64 // 65 // Definition: 66 // START IDENTIFIER 67 // | UNION // Case 1 68 // | LCURL RCURL // Case 2 69 // | ReservedWord Tag NameList // Case 3 70 // | ReservedWord Tag // Case 4 71 // | ERROR_VERBOSE // Case 5 72 const ( 73 StartIdentifierCase = iota 74 UnionDefinitionCase 75 LCURLRCURLCase 76 ReservedWordTagNameListCase 77 ReservedWordTagCase 78 ) 79 80 func printDefinitions(formatter format.Formatter, definitions []*parser.Definition) error { 81 for _, def := range definitions { 82 var err error 83 switch def.Case { 84 case StartIdentifierCase: 85 err = handleStart(formatter, def) 86 case UnionDefinitionCase: 87 err = handleUnion(formatter, def) 88 case LCURLRCURLCase: 89 err = handleProlog(formatter, def) 90 case ReservedWordTagNameListCase, ReservedWordTagCase: 91 err = handleReservedWordTagNameList(formatter, def) 92 } 93 if err != nil { 94 return err 95 } 96 } 97 _, err := formatter.Format("\n%%%%") 98 return err 99 } 100 101 func handleStart(f format.Formatter, definition *parser.Definition) error { 102 if err := Ensure(definition). 103 and(definition.Token2). 104 and(definition.Token2).NotNil(); err != nil { 105 return err 106 } 107 cmt1 := strings.Join(definition.Token.Comments, "\n") 108 cmt2 := strings.Join(definition.Token2.Comments, "\n") 109 _, err := f.Format("\n%s%s\t%s%s\n", cmt1, definition.Token.Val, cmt2, definition.Token2.Val) 110 return err 111 } 112 113 func handleUnion(f format.Formatter, definition *parser.Definition) error { 114 if err := Ensure(definition). 115 and(definition.Value).NotNil(); err != nil { 116 return err 117 } 118 if len(definition.Value) != 0 { 119 _, err := f.Format("%%union%i%s%u\n\n", definition.Value) 120 if err != nil { 121 return err 122 } 123 } 124 return nil 125 } 126 127 func handleProlog(f format.Formatter, definition *parser.Definition) error { 128 if err := Ensure(definition). 129 and(definition.Value).NotNil(); err != nil { 130 return err 131 } 132 _, err := f.Format("%%{%s%%}\n\n", definition.Value) 133 return err 134 } 135 136 func handleReservedWordTagNameList(f format.Formatter, def *parser.Definition) error { 137 if err := Ensure(def). 138 and(def.ReservedWord). 139 and(def.ReservedWord.Token).NotNil(); err != nil { 140 return err 141 } 142 comment := getTokenComment(def.ReservedWord.Token, divNewLineStringLayout) 143 directive := def.ReservedWord.Token.Val 144 145 hasTag := def.Tag != nil 146 var wordAfterDirective string 147 if hasTag { 148 wordAfterDirective = joinTag(def.Tag) 149 } else { 150 wordAfterDirective = joinNames(def.Nlist) 151 } 152 153 if _, err := f.Format("%s%s%s%i", comment, directive, wordAfterDirective); err != nil { 154 return err 155 } 156 if hasTag { 157 if _, err := f.Format("\n"); err != nil { 158 return err 159 } 160 if err := printNameListVertical(f, def.Nlist); err != nil { 161 return err 162 } 163 } 164 _, err := f.Format("%u\n") 165 return err 166 } 167 168 func joinTag(tag *parser.Tag) string { 169 var sb strings.Builder 170 sb.WriteString("\t") 171 if tag.Token != nil { 172 sb.WriteString(tag.Token.Val) 173 } 174 if tag.Token2 != nil { 175 sb.WriteString(tag.Token2.Val) 176 } 177 if tag.Token3 != nil { 178 sb.WriteString(tag.Token3.Val) 179 } 180 return sb.String() 181 } 182 183 type stringLayout int8 184 185 const ( 186 spanStringLayout stringLayout = iota 187 divStringLayout 188 divNewLineStringLayout 189 ) 190 191 func getTokenComment(token *parser.Token, layout stringLayout) string { 192 if len(token.Comments) == 0 { 193 return "" 194 } 195 var splitter, beforeComment string 196 switch layout { 197 case spanStringLayout: 198 splitter, beforeComment = " ", "" 199 case divStringLayout: 200 splitter, beforeComment = "\n", "" 201 case divNewLineStringLayout: 202 splitter, beforeComment = "\n", "\n" 203 default: 204 panic(errors.Errorf("unsupported stringLayout: %v", layout)) 205 } 206 207 var sb strings.Builder 208 sb.WriteString(beforeComment) 209 for _, comment := range token.Comments { 210 sb.WriteString(comment) 211 sb.WriteString(splitter) 212 } 213 return sb.String() 214 } 215 216 func printNameListVertical(f format.Formatter, names NameArr) (err error) { 217 rest := names 218 for len(rest) != 0 { 219 var processing NameArr 220 processing, rest = rest[:1], rest[1:] 221 222 var noComments NameArr 223 noComments, rest = rest.span(noComment) 224 processing = append(processing, noComments...) 225 226 maxCharLength := processing.findMaxLength() 227 for _, name := range processing { 228 if err := printSingleName(f, name, maxCharLength); err != nil { 229 return err 230 } 231 } 232 } 233 return nil 234 } 235 236 func joinNames(names NameArr) string { 237 var sb strings.Builder 238 for _, name := range names { 239 sb.WriteString(" ") 240 sb.WriteString(getTokenComment(name.Token, spanStringLayout)) 241 sb.WriteString(name.Token.Val) 242 } 243 return sb.String() 244 } 245 246 func printSingleName(f format.Formatter, name *parser.Name, maxCharLength int) error { 247 cmt := getTokenComment(name.Token, divNewLineStringLayout) 248 if _, err := f.Format(escapePercent(cmt)); err != nil { 249 return err 250 } 251 strLit := name.LiteralStringOpt 252 if strLit != nil && strLit.Token != nil { 253 _, err := f.Format("%-*s %s\n", maxCharLength, name.Token.Val, strLit.Token.Val) 254 return err 255 } 256 _, err := f.Format("%s\n", name.Token.Val) 257 return err 258 } 259 260 type NameArr []*parser.Name 261 262 func (ns NameArr) span(pred func(*parser.Name) bool) (first NameArr, second NameArr) { 263 first = ns.takeWhile(pred) 264 second = ns[len(first):] 265 return first, second 266 } 267 268 func (ns NameArr) takeWhile(pred func(*parser.Name) bool) NameArr { 269 for i, def := range ns { 270 if pred(def) { 271 continue 272 } 273 return ns[:i] 274 } 275 return ns 276 } 277 278 func (ns NameArr) findMaxLength() int { 279 maxLen := -1 280 for _, s := range ns { 281 if len(s.Token.Val) > maxLen { 282 maxLen = len(s.Token.Val) 283 } 284 } 285 return maxLen 286 } 287 288 func hasComments(n *parser.Name) bool { 289 return len(n.Token.Comments) != 0 290 } 291 292 func noComment(n *parser.Name) bool { 293 return !hasComments(n) 294 } 295 296 func containsActionInRule(rule *parser.Rule) bool { 297 for _, b := range rule.Body { 298 if _, ok := b.(*parser.Action); ok { 299 return true 300 } 301 } 302 return false 303 } 304 305 type RuleArr []*parser.Rule 306 307 func printRules(f format.Formatter, rules RuleArr) (err error) { 308 var lastRuleName string 309 for _, rule := range rules { 310 if rule.Name.Val == lastRuleName { 311 cmt := getTokenComment(rule.Token, divStringLayout) 312 _, err = f.Format("\n%s|\t%i", cmt) 313 } else { 314 cmt := getTokenComment(rule.Name, divStringLayout) 315 _, err = f.Format("\n\n%s%s:%i\n", cmt, rule.Name.Val) 316 } 317 if err != nil { 318 return err 319 } 320 lastRuleName = rule.Name.Val 321 322 if err = printRuleBody(f, rule); err != nil { 323 return err 324 } 325 if _, err = f.Format("%u"); err != nil { 326 return err 327 } 328 } 329 _, err = f.Format("\n%%%%\n") 330 return err 331 } 332 333 type ruleItemType int8 334 335 const ( 336 identRuleItemType ruleItemType = 1 337 actionRuleItemType ruleItemType = 2 338 strLiteralRuleItemType ruleItemType = 3 339 ) 340 341 func printRuleBody(f format.Formatter, rule *parser.Rule) error { 342 firstRuleItem, counter := rule.RuleItemList, 0 343 for ri := rule.RuleItemList; ri != nil; ri = ri.RuleItemList { 344 switch ruleItemType(ri.Case) { 345 case identRuleItemType, strLiteralRuleItemType: 346 term := fmt.Sprintf(" %s", ri.Token.Val) 347 if ri == firstRuleItem { 348 term = term[1:] 349 } 350 cmt := getTokenComment(ri.Token, divStringLayout) 351 352 if _, err := f.Format(escapePercent(cmt)); err != nil { 353 return err 354 } 355 if _, err := f.Format("%s", term); err != nil { 356 return err 357 } 358 case actionRuleItemType: 359 isFirstRuleItem := ri == firstRuleItem 360 if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil { 361 return err 362 } 363 if err := handleAction(f, rule, ri.Action, isFirstRuleItem); err != nil { 364 return err 365 } 366 } 367 counter++ 368 } 369 if err := checkInconsistencyInYaccParser(f, rule, counter); err != nil { 370 return err 371 } 372 if !containsActionInRule(rule) { 373 if err := handlePrecedence(f, rule.Precedence, counter == 0); err != nil { 374 return err 375 } 376 } 377 return nil 378 } 379 380 func handleAction(f format.Formatter, rule *parser.Rule, action *parser.Action, isFirstItem bool) error { 381 if !isFirstItem || rule.Precedence != nil { 382 if _, err := f.Format("\n"); err != nil { 383 return err 384 } 385 } 386 387 cmt := getTokenComment(action.Token, divStringLayout) 388 if _, err := f.Format(escapePercent(cmt)); err != nil { 389 return err 390 } 391 392 goSnippet, err := formatGoSnippet(action.Values) 393 goSnippet = escapePercent(goSnippet) 394 if err != nil { 395 return err 396 } 397 snippet := "{}" 398 if len(goSnippet) != 0 { 399 snippet = fmt.Sprintf("{%%i\n%s%%u\n}", goSnippet) 400 } 401 _, err = f.Format(snippet) 402 return err 403 } 404 405 func handlePrecedence(f format.Formatter, p *parser.Precedence, isFirstItem bool) error { 406 if p == nil { 407 return nil 408 } 409 if err := Ensure(p.Token). 410 and(p.Token2).NotNil(); err != nil { 411 return err 412 } 413 cmt := getTokenComment(p.Token, spanStringLayout) 414 if !isFirstItem { 415 if _, err := f.Format(" "); err != nil { 416 return err 417 } 418 } 419 _, err := f.Format("%s%s %s", cmt, p.Token.Val, p.Token2.Val) 420 return err 421 } 422 423 func formatGoSnippet(actVal []*parser.ActionValue) (string, error) { 424 tran := &SpecialActionValTransformer{ 425 store: map[string]string{}, 426 } 427 goSnippet := collectGoSnippet(tran, actVal) 428 formatted, err := gofmt.Source([]byte(goSnippet)) 429 if err != nil { 430 return "", err 431 } 432 formattedSnippet := tran.restore(string(formatted)) 433 return strings.TrimSpace(formattedSnippet), nil 434 } 435 436 func collectGoSnippet(tran *SpecialActionValTransformer, actionValArr []*parser.ActionValue) string { 437 var sb strings.Builder 438 for _, value := range actionValArr { 439 trimTab := removeLineBeginBlanks(value.Src) 440 sb.WriteString(tran.transform(trimTab)) 441 } 442 snipWithPar := strings.TrimSpace(sb.String()) 443 if strings.HasPrefix(snipWithPar, "{") && strings.HasSuffix(snipWithPar, "}") { 444 return snipWithPar[1 : len(snipWithPar)-1] 445 } 446 return "" 447 } 448 449 var lineBeginBlankRegex = regexp.MustCompile("(?m)^[\t ]+") 450 451 func removeLineBeginBlanks(src string) string { 452 return lineBeginBlankRegex.ReplaceAllString(src, "") 453 } 454 455 type SpecialActionValTransformer struct { 456 store map[string]string 457 } 458 459 const yaccFmtVar = "_yaccfmt_var_" 460 461 var yaccFmtVarRegex = regexp.MustCompile("_yaccfmt_var_[0-9]{1,5}") 462 463 func (s *SpecialActionValTransformer) transform(val string) string { 464 if strings.HasPrefix(val, "$") { 465 generated := fmt.Sprintf("%s%d", yaccFmtVar, len(s.store)) 466 s.store[generated] = val 467 return generated 468 } 469 return val 470 } 471 472 func (s *SpecialActionValTransformer) restore(src string) string { 473 return yaccFmtVarRegex.ReplaceAllStringFunc(src, func(matched string) string { 474 origin, ok := s.store[matched] 475 if !ok { 476 panic(errors.Errorf("mismatch in SpecialActionValTransformer")) 477 } 478 return origin 479 }) 480 } 481 482 type OutputFormatter struct { 483 file *os.File 484 out *bufio.Writer 485 formatter strutil.Formatter 486 } 487 488 func (y *OutputFormatter) Setup(filename string) (err error) { 489 if y.file, err = os.Create(filename); err != nil { 490 return 491 } 492 y.out = bufio.NewWriter(y.file) 493 y.formatter = strutil.IndentFormatter(y.out, "\t") 494 return 495 } 496 497 func (y *OutputFormatter) Teardown() error { 498 if y.out != nil { 499 if err := y.out.Flush(); err != nil { 500 return err 501 } 502 } 503 if y.file != nil { 504 if err := y.file.Close(); err != nil { 505 return err 506 } 507 } 508 return nil 509 } 510 511 func (y *OutputFormatter) Format(format string, args ...interface{}) (int, error) { 512 return y.formatter.Format(format, args...) 513 } 514 515 func (y *OutputFormatter) Write(bytes []byte) (int, error) { 516 return y.formatter.Write(bytes) 517 } 518 519 type NotNilAssert struct { 520 idx int 521 err error 522 } 523 524 func (n *NotNilAssert) and(target interface{}) *NotNilAssert { 525 if n.err != nil { 526 return n 527 } 528 if target == nil { 529 n.err = errors.Errorf("encounter nil, index: %d", n.idx) 530 } 531 n.idx++ 532 return n 533 } 534 535 func (n *NotNilAssert) NotNil() error { 536 return n.err 537 } 538 539 func Ensure(target interface{}) *NotNilAssert { 540 return (&NotNilAssert{}).and(target) 541 } 542 543 func escapePercent(src string) string { 544 return strings.ReplaceAll(src, "%", "%%") 545 } 546 547 func checkInconsistencyInYaccParser(f format.Formatter, rule *parser.Rule, counter int) error { 548 if counter == len(rule.Body) { 549 return nil 550 } 551 // pickup rule item in ruleBody 552 for i := counter; i < len(rule.Body); i++ { 553 body := rule.Body[i] 554 switch b := body.(type) { 555 case string, int: 556 if bInt, ok := b.(int); ok { 557 b = fmt.Sprintf("'%c'", bInt) 558 } 559 term := fmt.Sprintf(" %s", b) 560 if i == 0 { 561 term = term[1:] 562 } 563 _, err := f.Format("%s", term) 564 return err 565 case *parser.Action: 566 isFirstRuleItem := i == 0 567 if err := handlePrecedence(f, rule.Precedence, isFirstRuleItem); err != nil { 568 return err 569 } 570 if err := handleAction(f, rule, b, isFirstRuleItem); err != nil { 571 return err 572 } 573 } 574 } 575 return nil 576 }