vitess.io/vitess@v0.16.2/go/tools/asthelpergen/rewrite_gen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
    19  import (
    20  	"fmt"
    21  	"go/types"
    22  
    23  	"github.com/dave/jennifer/jen"
    24  )
    25  
    26  const (
    27  	rewriteName = "rewrite"
    28  )
    29  
    30  type rewriteGen struct {
    31  	ifaceName string
    32  	file      *jen.File
    33  }
    34  
    35  var _ generator = (*rewriteGen)(nil)
    36  
    37  func newRewriterGen(pkgname string, ifaceName string) *rewriteGen {
    38  	file := jen.NewFile(pkgname)
    39  	file.HeaderComment(licenseFileHeader)
    40  	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
    41  
    42  	return &rewriteGen{
    43  		ifaceName: ifaceName,
    44  		file:      file,
    45  	}
    46  }
    47  
    48  func (r *rewriteGen) genFile() (string, *jen.File) {
    49  	return "ast_rewrite.go", r.file
    50  }
    51  
    52  func (r *rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
    53  	if !shouldAdd(t, spi.iface()) {
    54  		return nil
    55  	}
    56  	/*
    57  		func VisitAST(in AST) (bool, error) {
    58  			if in == nil {
    59  				return false, nil
    60  			}
    61  			switch a := inA.(type) {
    62  			case *SubImpl:
    63  				return VisitSubImpl(a, b)
    64  			default:
    65  				return false, nil
    66  			}
    67  		}
    68  	*/
    69  	stmts := []jen.Code{
    70  		jen.If(jen.Id("node == nil").Block(returnTrue())),
    71  	}
    72  
    73  	var cases []jen.Code
    74  	_ = spi.findImplementations(iface, func(t types.Type) error {
    75  		if _, ok := t.Underlying().(*types.Interface); ok {
    76  			return nil
    77  		}
    78  		typeString := types.TypeString(t, noQualifier)
    79  		funcName := rewriteName + printableTypeName(t)
    80  		spi.addType(t)
    81  		caseBlock := jen.Case(jen.Id(typeString)).Block(
    82  			jen.Return(jen.Id("a").Dot(funcName).Call(jen.Id("parent, node, replacer"))),
    83  		)
    84  		cases = append(cases, caseBlock)
    85  		return nil
    86  	})
    87  
    88  	cases = append(cases,
    89  		jen.Default().Block(
    90  			jen.Comment("this should never happen"),
    91  			returnTrue(),
    92  		))
    93  
    94  	stmts = append(stmts, jen.Switch(jen.Id("node := node.(type)").Block(
    95  		cases...,
    96  	)))
    97  
    98  	r.rewriteFunc(t, stmts)
    99  	return nil
   100  }
   101  
   102  func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   103  	if !shouldAdd(t, spi.iface()) {
   104  		return nil
   105  	}
   106  	fields := r.rewriteAllStructFields(t, strct, spi, true)
   107  
   108  	stmts := []jen.Code{executePre()}
   109  	stmts = append(stmts, fields...)
   110  	stmts = append(stmts, executePost(len(fields) > 0))
   111  	stmts = append(stmts, returnTrue())
   112  
   113  	r.rewriteFunc(t, stmts)
   114  
   115  	return nil
   116  }
   117  
   118  func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   119  	if !shouldAdd(t, spi.iface()) {
   120  		return nil
   121  	}
   122  
   123  	/*
   124  		if node == nil { return nil }
   125  	*/
   126  	stmts := []jen.Code{jen.If(jen.Id("node == nil").Block(returnTrue()))}
   127  
   128  	/*
   129  		if !pre(&cur) {
   130  			return nil
   131  		}
   132  	*/
   133  	stmts = append(stmts, executePre())
   134  	fields := r.rewriteAllStructFields(t, strct, spi, false)
   135  	stmts = append(stmts, fields...)
   136  	stmts = append(stmts, executePost(len(fields) > 0))
   137  	stmts = append(stmts, returnTrue())
   138  
   139  	r.rewriteFunc(t, stmts)
   140  
   141  	return nil
   142  }
   143  
   144  func (r *rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   145  	if !shouldAdd(t, spi.iface()) {
   146  		return nil
   147  	}
   148  
   149  	/*
   150  	 */
   151  
   152  	stmts := []jen.Code{
   153  		jen.Comment("ptrToBasicMethod"),
   154  	}
   155  	r.rewriteFunc(t, stmts)
   156  
   157  	return nil
   158  }
   159  
   160  func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
   161  	if !shouldAdd(t, spi.iface()) {
   162  		return nil
   163  	}
   164  
   165  	/*
   166  		if node == nil {
   167  				return nil
   168  			}
   169  			cur := Cursor{
   170  				node:     node,
   171  				parent:   parent,
   172  				replacer: replacer,
   173  			}
   174  			if !pre(&cur) {
   175  				return nil
   176  			}
   177  	*/
   178  	stmts := []jen.Code{
   179  		jen.If(jen.Id("node == nil").Block(returnTrue())),
   180  	}
   181  
   182  	typeString := types.TypeString(t, noQualifier)
   183  
   184  	preStmts := setupCursor()
   185  	preStmts = append(preStmts,
   186  		jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"),
   187  		jen.If(jen.Id("a.cur.revisit").Block(
   188  			jen.Id("node").Op("=").Id("a.cur.node.("+typeString+")"),
   189  			jen.Id("a.cur.revisit").Op("=").False(),
   190  			jen.Return(jen.Id("a.rewrite"+typeString+"(parent, node, replacer)")),
   191  		)),
   192  		jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))),
   193  	)
   194  
   195  	stmts = append(stmts, jen.If(jen.Id("a.pre!= nil").Block(preStmts...)))
   196  
   197  	haveChildren := false
   198  	if shouldAdd(slice.Elem(), spi.iface()) {
   199  		/*
   200  			for i, el := range node {
   201  						if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) {
   202  							parent.(LeafSlice)[i] = newNode.(*Leaf)
   203  						}, pre, post); err != nil {
   204  							return err
   205  						}
   206  					}
   207  		*/
   208  		haveChildren = true
   209  		stmts = append(stmts,
   210  			jen.For(jen.Id("x, el").Op(":=").Id("range node")).
   211  				Block(r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false)))
   212  	}
   213  
   214  	stmts = append(stmts, executePost(haveChildren))
   215  	stmts = append(stmts, returnTrue())
   216  
   217  	r.rewriteFunc(t, stmts)
   218  	return nil
   219  }
   220  
   221  func setupCursor() []jen.Code {
   222  	return []jen.Code{
   223  		jen.Id("a.cur.replacer = replacer"),
   224  		jen.Id("a.cur.parent = parent"),
   225  		jen.Id("a.cur.node = node"),
   226  	}
   227  }
   228  func executePre() jen.Code {
   229  	curStmts := setupCursor()
   230  	curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue()))
   231  	return jen.If(jen.Id("a.pre!= nil").Block(curStmts...))
   232  }
   233  
   234  func executePost(seenChildren bool) jen.Code {
   235  	var curStmts []jen.Code
   236  	if seenChildren {
   237  		// if we have visited children, we have to write to the cursor fields
   238  		curStmts = setupCursor()
   239  	} else {
   240  		curStmts = append(curStmts,
   241  			jen.If(jen.Id("a.pre == nil")).Block(setupCursor()...))
   242  	}
   243  
   244  	curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnFalse()))
   245  
   246  	return jen.If(jen.Id("a.post != nil")).Block(curStmts...)
   247  }
   248  
   249  func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   250  	if !shouldAdd(t, spi.iface()) {
   251  		return nil
   252  	}
   253  
   254  	stmts := []jen.Code{executePre(), executePost(false), returnTrue()}
   255  	r.rewriteFunc(t, stmts)
   256  	return nil
   257  }
   258  
   259  func (r *rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code) {
   260  
   261  	/*
   262  		func (a *application) rewriteNodeType(parent AST, node NodeType, replacer replacerFunc) {
   263  	*/
   264  
   265  	typeString := types.TypeString(t, noQualifier)
   266  	funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t))
   267  	code := jen.Func().Params(
   268  		jen.Id("a").Op("*").Id("application"),
   269  	).Id(funcName).Params(
   270  		jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc", r.ifaceName, typeString)),
   271  	).Bool().Block(stmts...)
   272  
   273  	r.file.Add(code)
   274  }
   275  
   276  func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code {
   277  	/*
   278  		if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) {
   279  			err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'")
   280  		}, pre, post); errF != nil {
   281  			return errF
   282  		}
   283  
   284  	*/
   285  	var output []jen.Code
   286  	for i := 0; i < strct.NumFields(); i++ {
   287  		field := strct.Field(i)
   288  		if types.Implements(field.Type(), spi.iface()) {
   289  			spi.addType(field.Type())
   290  			output = append(output, r.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail))
   291  			continue
   292  		}
   293  		slice, isSlice := field.Type().(*types.Slice)
   294  		if isSlice && types.Implements(slice.Elem(), spi.iface()) {
   295  			spi.addType(slice.Elem())
   296  			id := jen.Id("x")
   297  			if fail {
   298  				id = jen.Id("_")
   299  			}
   300  			output = append(output,
   301  				jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())).
   302  					Block(r.rewriteChildSlice(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("idx")), fail)))
   303  		}
   304  	}
   305  	return output
   306  }
   307  
   308  func failReplacer(t types.Type, f string) *jen.Statement {
   309  	typeString := types.TypeString(t, noQualifier)
   310  	return jen.Panic(jen.Lit(fmt.Sprintf("[BUG] tried to replace '%s' on '%s'", f, typeString)))
   311  }
   312  
   313  func (r *rewriteGen) rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code {
   314  	/*
   315  		if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) {
   316  			parent.(*RefContainer).ASTType = newNode.(AST)
   317  		}, pre, post); errF != nil {
   318  			return errF
   319  		}
   320  
   321  		if errF := rewriteAST(node, el, func(newNode, parent AST) {
   322  			parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST)
   323  		}, pre, post); errF != nil {
   324  			return errF
   325  		}
   326  
   327  	*/
   328  	funcName := rewriteName + printableTypeName(field)
   329  	var replaceOrFail *jen.Statement
   330  	if fail {
   331  		replaceOrFail = failReplacer(t, fieldName)
   332  	} else {
   333  		replaceOrFail = jen.Id("parent").
   334  			Assert(jen.Id(types.TypeString(t, noQualifier))).
   335  			Add(replace).
   336  			Op("=").
   337  			Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))
   338  
   339  	}
   340  	funcBlock := jen.Func().Call(jen.Id("newNode, parent").Id(r.ifaceName)).
   341  		Block(replaceOrFail)
   342  
   343  	rewriteField := jen.If(
   344  		jen.Op("!").Id("a").Dot(funcName).Call(
   345  			jen.Id("node"),
   346  			param,
   347  			funcBlock).Block(returnFalse()))
   348  
   349  	return rewriteField
   350  }
   351  
   352  func (r *rewriteGen) rewriteChildSlice(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code {
   353  	/*
   354  				if errF := a.rewriteAST(node, el, func(idx int) replacerFunc {
   355  				return func(newNode, parent AST) {
   356  					parent.(InterfaceSlice)[idx] = newNode.(AST)
   357  				}
   358  			}(i)); errF != nil {
   359  				return errF
   360  			}
   361  
   362  			if errF := a.rewriteAST(node, el, func(newNode, parent AST) {
   363  		return errr...
   364  		}); errF != nil {
   365  				return errF
   366  			}
   367  
   368  	*/
   369  
   370  	funcName := rewriteName + printableTypeName(field)
   371  	var funcBlock jen.Code
   372  	replacerFuncDef := jen.Func().Call(jen.Id("newNode, parent").Id(r.ifaceName))
   373  	if fail {
   374  		funcBlock = replacerFuncDef.Block(failReplacer(t, fieldName))
   375  	} else {
   376  		funcBlock = jen.Func().Call(jen.Id("idx int")).Id("replacerFunc").
   377  			Block(jen.Return(replacerFuncDef.Block(
   378  				jen.Id("parent").Assert(jen.Id(types.TypeString(t, noQualifier))).Add(replace).Op("=").Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))),
   379  			)).Call(jen.Id("x"))
   380  	}
   381  
   382  	rewriteField := jen.If(
   383  		jen.Op("!").Id("a").Dot(funcName).Call(
   384  			jen.Id("node"),
   385  			param,
   386  			funcBlock).Block(returnFalse()))
   387  
   388  	return rewriteField
   389  }
   390  
   391  var noQualifier = func(p *types.Package) string {
   392  	return ""
   393  }
   394  
   395  func returnTrue() jen.Code {
   396  	return jen.Return(jen.True())
   397  }
   398  
   399  func returnFalse() jen.Code {
   400  	return jen.Return(jen.False())
   401  }