vitess.io/vitess@v0.16.2/go/tools/asthelpergen/rewrite_gen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "fmt" 21 "go/types" 22 23 "github.com/dave/jennifer/jen" 24 ) 25 26 const ( 27 rewriteName = "rewrite" 28 ) 29 30 type rewriteGen struct { 31 ifaceName string 32 file *jen.File 33 } 34 35 var _ generator = (*rewriteGen)(nil) 36 37 func newRewriterGen(pkgname string, ifaceName string) *rewriteGen { 38 file := jen.NewFile(pkgname) 39 file.HeaderComment(licenseFileHeader) 40 file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") 41 42 return &rewriteGen{ 43 ifaceName: ifaceName, 44 file: file, 45 } 46 } 47 48 func (r *rewriteGen) genFile() (string, *jen.File) { 49 return "ast_rewrite.go", r.file 50 } 51 52 func (r *rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { 53 if !shouldAdd(t, spi.iface()) { 54 return nil 55 } 56 /* 57 func VisitAST(in AST) (bool, error) { 58 if in == nil { 59 return false, nil 60 } 61 switch a := inA.(type) { 62 case *SubImpl: 63 return VisitSubImpl(a, b) 64 default: 65 return false, nil 66 } 67 } 68 */ 69 stmts := []jen.Code{ 70 jen.If(jen.Id("node == nil").Block(returnTrue())), 71 } 72 73 var cases []jen.Code 74 _ = spi.findImplementations(iface, func(t types.Type) error { 75 if _, ok := t.Underlying().(*types.Interface); ok { 76 return nil 77 } 78 typeString := types.TypeString(t, noQualifier) 79 funcName := rewriteName + printableTypeName(t) 80 spi.addType(t) 81 caseBlock := jen.Case(jen.Id(typeString)).Block( 82 jen.Return(jen.Id("a").Dot(funcName).Call(jen.Id("parent, node, replacer"))), 83 ) 84 cases = append(cases, caseBlock) 85 return nil 86 }) 87 88 cases = append(cases, 89 jen.Default().Block( 90 jen.Comment("this should never happen"), 91 returnTrue(), 92 )) 93 94 stmts = append(stmts, jen.Switch(jen.Id("node := node.(type)").Block( 95 cases..., 96 ))) 97 98 r.rewriteFunc(t, stmts) 99 return nil 100 } 101 102 func (r *rewriteGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 103 if !shouldAdd(t, spi.iface()) { 104 return nil 105 } 106 fields := r.rewriteAllStructFields(t, strct, spi, true) 107 108 stmts := []jen.Code{executePre()} 109 stmts = append(stmts, fields...) 110 stmts = append(stmts, executePost(len(fields) > 0)) 111 stmts = append(stmts, returnTrue()) 112 113 r.rewriteFunc(t, stmts) 114 115 return nil 116 } 117 118 func (r *rewriteGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 119 if !shouldAdd(t, spi.iface()) { 120 return nil 121 } 122 123 /* 124 if node == nil { return nil } 125 */ 126 stmts := []jen.Code{jen.If(jen.Id("node == nil").Block(returnTrue()))} 127 128 /* 129 if !pre(&cur) { 130 return nil 131 } 132 */ 133 stmts = append(stmts, executePre()) 134 fields := r.rewriteAllStructFields(t, strct, spi, false) 135 stmts = append(stmts, fields...) 136 stmts = append(stmts, executePost(len(fields) > 0)) 137 stmts = append(stmts, returnTrue()) 138 139 r.rewriteFunc(t, stmts) 140 141 return nil 142 } 143 144 func (r *rewriteGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 145 if !shouldAdd(t, spi.iface()) { 146 return nil 147 } 148 149 /* 150 */ 151 152 stmts := []jen.Code{ 153 jen.Comment("ptrToBasicMethod"), 154 } 155 r.rewriteFunc(t, stmts) 156 157 return nil 158 } 159 160 func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { 161 if !shouldAdd(t, spi.iface()) { 162 return nil 163 } 164 165 /* 166 if node == nil { 167 return nil 168 } 169 cur := Cursor{ 170 node: node, 171 parent: parent, 172 replacer: replacer, 173 } 174 if !pre(&cur) { 175 return nil 176 } 177 */ 178 stmts := []jen.Code{ 179 jen.If(jen.Id("node == nil").Block(returnTrue())), 180 } 181 182 typeString := types.TypeString(t, noQualifier) 183 184 preStmts := setupCursor() 185 preStmts = append(preStmts, 186 jen.Id("kontinue").Op(":=").Id("!a.pre(&a.cur)"), 187 jen.If(jen.Id("a.cur.revisit").Block( 188 jen.Id("node").Op("=").Id("a.cur.node.("+typeString+")"), 189 jen.Id("a.cur.revisit").Op("=").False(), 190 jen.Return(jen.Id("a.rewrite"+typeString+"(parent, node, replacer)")), 191 )), 192 jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))), 193 ) 194 195 stmts = append(stmts, jen.If(jen.Id("a.pre!= nil").Block(preStmts...))) 196 197 haveChildren := false 198 if shouldAdd(slice.Elem(), spi.iface()) { 199 /* 200 for i, el := range node { 201 if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { 202 parent.(LeafSlice)[i] = newNode.(*Leaf) 203 }, pre, post); err != nil { 204 return err 205 } 206 } 207 */ 208 haveChildren = true 209 stmts = append(stmts, 210 jen.For(jen.Id("x, el").Op(":=").Id("range node")). 211 Block(r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false))) 212 } 213 214 stmts = append(stmts, executePost(haveChildren)) 215 stmts = append(stmts, returnTrue()) 216 217 r.rewriteFunc(t, stmts) 218 return nil 219 } 220 221 func setupCursor() []jen.Code { 222 return []jen.Code{ 223 jen.Id("a.cur.replacer = replacer"), 224 jen.Id("a.cur.parent = parent"), 225 jen.Id("a.cur.node = node"), 226 } 227 } 228 func executePre() jen.Code { 229 curStmts := setupCursor() 230 curStmts = append(curStmts, jen.If(jen.Id("!a.pre(&a.cur)")).Block(returnTrue())) 231 return jen.If(jen.Id("a.pre!= nil").Block(curStmts...)) 232 } 233 234 func executePost(seenChildren bool) jen.Code { 235 var curStmts []jen.Code 236 if seenChildren { 237 // if we have visited children, we have to write to the cursor fields 238 curStmts = setupCursor() 239 } else { 240 curStmts = append(curStmts, 241 jen.If(jen.Id("a.pre == nil")).Block(setupCursor()...)) 242 } 243 244 curStmts = append(curStmts, jen.If(jen.Id("!a.post(&a.cur)")).Block(returnFalse())) 245 246 return jen.If(jen.Id("a.post != nil")).Block(curStmts...) 247 } 248 249 func (r *rewriteGen) basicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 250 if !shouldAdd(t, spi.iface()) { 251 return nil 252 } 253 254 stmts := []jen.Code{executePre(), executePost(false), returnTrue()} 255 r.rewriteFunc(t, stmts) 256 return nil 257 } 258 259 func (r *rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code) { 260 261 /* 262 func (a *application) rewriteNodeType(parent AST, node NodeType, replacer replacerFunc) { 263 */ 264 265 typeString := types.TypeString(t, noQualifier) 266 funcName := fmt.Sprintf("%s%s", rewriteName, printableTypeName(t)) 267 code := jen.Func().Params( 268 jen.Id("a").Op("*").Id("application"), 269 ).Id(funcName).Params( 270 jen.Id(fmt.Sprintf("parent %s, node %s, replacer replacerFunc", r.ifaceName, typeString)), 271 ).Bool().Block(stmts...) 272 273 r.file.Add(code) 274 } 275 276 func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { 277 /* 278 if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { 279 err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") 280 }, pre, post); errF != nil { 281 return errF 282 } 283 284 */ 285 var output []jen.Code 286 for i := 0; i < strct.NumFields(); i++ { 287 field := strct.Field(i) 288 if types.Implements(field.Type(), spi.iface()) { 289 spi.addType(field.Type()) 290 output = append(output, r.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail)) 291 continue 292 } 293 slice, isSlice := field.Type().(*types.Slice) 294 if isSlice && types.Implements(slice.Elem(), spi.iface()) { 295 spi.addType(slice.Elem()) 296 id := jen.Id("x") 297 if fail { 298 id = jen.Id("_") 299 } 300 output = append(output, 301 jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())). 302 Block(r.rewriteChildSlice(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("idx")), fail))) 303 } 304 } 305 return output 306 } 307 308 func failReplacer(t types.Type, f string) *jen.Statement { 309 typeString := types.TypeString(t, noQualifier) 310 return jen.Panic(jen.Lit(fmt.Sprintf("[BUG] tried to replace '%s' on '%s'", f, typeString))) 311 } 312 313 func (r *rewriteGen) rewriteChild(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { 314 /* 315 if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { 316 parent.(*RefContainer).ASTType = newNode.(AST) 317 }, pre, post); errF != nil { 318 return errF 319 } 320 321 if errF := rewriteAST(node, el, func(newNode, parent AST) { 322 parent.(*RefSliceContainer).ASTElements[i] = newNode.(AST) 323 }, pre, post); errF != nil { 324 return errF 325 } 326 327 */ 328 funcName := rewriteName + printableTypeName(field) 329 var replaceOrFail *jen.Statement 330 if fail { 331 replaceOrFail = failReplacer(t, fieldName) 332 } else { 333 replaceOrFail = jen.Id("parent"). 334 Assert(jen.Id(types.TypeString(t, noQualifier))). 335 Add(replace). 336 Op("="). 337 Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier))) 338 339 } 340 funcBlock := jen.Func().Call(jen.Id("newNode, parent").Id(r.ifaceName)). 341 Block(replaceOrFail) 342 343 rewriteField := jen.If( 344 jen.Op("!").Id("a").Dot(funcName).Call( 345 jen.Id("node"), 346 param, 347 funcBlock).Block(returnFalse())) 348 349 return rewriteField 350 } 351 352 func (r *rewriteGen) rewriteChildSlice(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { 353 /* 354 if errF := a.rewriteAST(node, el, func(idx int) replacerFunc { 355 return func(newNode, parent AST) { 356 parent.(InterfaceSlice)[idx] = newNode.(AST) 357 } 358 }(i)); errF != nil { 359 return errF 360 } 361 362 if errF := a.rewriteAST(node, el, func(newNode, parent AST) { 363 return errr... 364 }); errF != nil { 365 return errF 366 } 367 368 */ 369 370 funcName := rewriteName + printableTypeName(field) 371 var funcBlock jen.Code 372 replacerFuncDef := jen.Func().Call(jen.Id("newNode, parent").Id(r.ifaceName)) 373 if fail { 374 funcBlock = replacerFuncDef.Block(failReplacer(t, fieldName)) 375 } else { 376 funcBlock = jen.Func().Call(jen.Id("idx int")).Id("replacerFunc"). 377 Block(jen.Return(replacerFuncDef.Block( 378 jen.Id("parent").Assert(jen.Id(types.TypeString(t, noQualifier))).Add(replace).Op("=").Id("newNode").Assert(jen.Id(types.TypeString(field, noQualifier)))), 379 )).Call(jen.Id("x")) 380 } 381 382 rewriteField := jen.If( 383 jen.Op("!").Id("a").Dot(funcName).Call( 384 jen.Id("node"), 385 param, 386 funcBlock).Block(returnFalse())) 387 388 return rewriteField 389 } 390 391 var noQualifier = func(p *types.Package) string { 392 return "" 393 } 394 395 func returnTrue() jen.Code { 396 return jen.Return(jen.True()) 397 } 398 399 func returnFalse() jen.Code { 400 return jen.Return(jen.False()) 401 }