github.com/XiaoMi/Gaea@v1.2.5/parser/ast/util_test.go (about)

     1  // Copyright 2017 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast_test
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  
    20  	. "github.com/pingcap/check"
    21  
    22  	"github.com/XiaoMi/Gaea/parser"
    23  	. "github.com/XiaoMi/Gaea/parser/ast"
    24  	. "github.com/XiaoMi/Gaea/parser/format"
    25  	driver "github.com/XiaoMi/Gaea/parser/tidb-types/parser_driver"
    26  )
    27  
    28  var _ = Suite(&testCacheableSuite{})
    29  
    30  type testCacheableSuite struct {
    31  }
    32  
    33  func (s *testCacheableSuite) TestCacheable(c *C) {
    34  	// test non-SelectStmt
    35  	var stmt Node = &DeleteStmt{}
    36  	c.Assert(IsReadOnly(stmt), IsFalse)
    37  
    38  	stmt = &InsertStmt{}
    39  	c.Assert(IsReadOnly(stmt), IsFalse)
    40  
    41  	stmt = &UpdateStmt{}
    42  	c.Assert(IsReadOnly(stmt), IsFalse)
    43  
    44  	stmt = &ExplainStmt{}
    45  	c.Assert(IsReadOnly(stmt), IsTrue)
    46  
    47  	stmt = &ExplainStmt{}
    48  	c.Assert(IsReadOnly(stmt), IsTrue)
    49  
    50  	stmt = &DoStmt{}
    51  	c.Assert(IsReadOnly(stmt), IsTrue)
    52  }
    53  
    54  // CleanNodeText set the text of node and all child node empty.
    55  // For test only.
    56  func CleanNodeText(node Node) {
    57  	var cleaner nodeTextCleaner
    58  	node.Accept(&cleaner)
    59  }
    60  
    61  // nodeTextCleaner clean the text of a node and it's child node.
    62  // For test only.
    63  type nodeTextCleaner struct {
    64  }
    65  
    66  // Enter implements Visitor interface.
    67  func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) {
    68  	in.SetText("")
    69  	switch node := in.(type) {
    70  	case *Constraint:
    71  		if node.Option != nil {
    72  			if node.Option.KeyBlockSize == 0x0 && node.Option.Tp == 0 && node.Option.Comment == "" {
    73  				node.Option = nil
    74  			}
    75  		}
    76  	case *FuncCallExpr:
    77  		node.FnName.O = strings.ToLower(node.FnName.O)
    78  		switch node.FnName.L {
    79  		case "convert":
    80  			node.Args[1].(*driver.ValueExpr).Datum.SetBytes(nil)
    81  		}
    82  	case *AggregateFuncExpr:
    83  		node.F = strings.ToLower(node.F)
    84  	case *FieldList:
    85  		for _, f := range node.Fields {
    86  			f.Offset = 0
    87  		}
    88  	case *AlterTableSpec:
    89  		for _, opt := range node.Options {
    90  			opt.StrValue = strings.ToLower(opt.StrValue)
    91  		}
    92  	}
    93  	return in, false
    94  }
    95  
    96  // Leave implements Visitor interface.
    97  func (checker *nodeTextCleaner) Leave(in Node) (out Node, ok bool) {
    98  	return in, true
    99  }
   100  
   101  type NodeRestoreTestCase struct {
   102  	sourceSQL string
   103  	expectSQL string
   104  }
   105  
   106  func RunNodeRestoreTest(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) {
   107  	parser := parser.New()
   108  	parser.EnableWindowFunc(true)
   109  	for _, testCase := range nodeTestCases {
   110  		sourceSQL := fmt.Sprintf(template, testCase.sourceSQL)
   111  		expectSQL := fmt.Sprintf(template, testCase.expectSQL)
   112  		stmt, err := parser.ParseOneStmt(sourceSQL, "", "")
   113  		comment := Commentf("source %#v", testCase)
   114  		c.Assert(err, IsNil, comment)
   115  		var sb strings.Builder
   116  		err = extractNodeFunc(stmt).Restore(NewRestoreCtx(DefaultRestoreFlags, &sb))
   117  		c.Assert(err, IsNil, comment)
   118  		restoreSQL := fmt.Sprintf(template, sb.String())
   119  		comment = Commentf("source %#v; restore %v", testCase, restoreSQL)
   120  		c.Assert(restoreSQL, Equals, expectSQL, comment)
   121  		stmt2, err := parser.ParseOneStmt(restoreSQL, "", "")
   122  		c.Assert(err, IsNil, comment)
   123  		CleanNodeText(stmt)
   124  		CleanNodeText(stmt2)
   125  		c.Assert(stmt2, DeepEquals, stmt, comment)
   126  	}
   127  }