vitess.io/vitess@v0.16.2/go/tools/asthelpergen/copy_on_rewrite_gen.go (about) 1 /* 2 Copyright 2023 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "go/types" 21 22 "github.com/dave/jennifer/jen" 23 ) 24 25 type cowGen struct { 26 file *jen.File 27 baseType string 28 } 29 30 var _ generator = (*cowGen)(nil) 31 32 func newCOWGen(pkgname string, nt *types.Named) *cowGen { 33 file := jen.NewFile(pkgname) 34 file.HeaderComment(licenseFileHeader) 35 file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") 36 37 return &cowGen{ 38 file: file, 39 baseType: nt.Obj().Id(), 40 } 41 } 42 43 func (c *cowGen) addFunc(code *jen.Statement) { 44 c.file.Add(code) 45 } 46 47 func (c *cowGen) genFile() (string, *jen.File) { 48 return "ast_copy_on_rewrite.go", c.file 49 } 50 51 const cowName = "copyOnRewrite" 52 53 // readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list 54 func (c *cowGen) readValueOfType(t types.Type, expr jen.Code, spi generatorSPI) jen.Code { 55 switch t.Underlying().(type) { 56 case *types.Interface: 57 if types.TypeString(t, noQualifier) == "any" { 58 // these fields have to be taken care of manually 59 return expr 60 } 61 } 62 spi.addType(t) 63 return jen.Id("c").Dot(cowName + printableTypeName(t)).Call(expr) 64 } 65 66 func (c *cowGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { 67 if !types.Implements(t, spi.iface()) { 68 return nil 69 } 70 71 typeString := types.TypeString(t, noQualifier) 72 73 changedVarName := "changed" 74 fieldVar := "res" 75 elemTyp := types.TypeString(slice.Elem(), noQualifier) 76 77 name := printableTypeName(t) 78 funcName := cowName + name 79 var visitElements *jen.Statement 80 81 if types.Implements(slice.Elem(), spi.iface()) { 82 visitElements = ifPreNotNilOrReturnsTrue().Block( 83 jen.Id(fieldVar).Op(":=").Id("make").Params(jen.Id(typeString), jen.Id("len").Params(jen.Id("n"))), // _Foo := make([]Typ, len(n)) 84 jen.For(jen.List(jen.Id("x"), jen.Id("el")).Op(":=").Id("range n")).Block( 85 c.visitFieldOrElement("this", "change", slice.Elem(), jen.Id("el"), spi), 86 // jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(types.TypeString(elemTyp, noQualifier))), 87 jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(elemTyp)), 88 jen.If(jen.Id("change")).Block( 89 jen.Id(changedVarName).Op("=").True(), 90 ), 91 ), 92 jen.If(jen.Id("changed")).Block( 93 jen.Id("out").Op("=").Id("res"), 94 ), 95 ) 96 } else { 97 visitElements = jen.If(jen.Id("c.pre != nil")).Block( 98 jen.Id("c.pre(n, parent)"), 99 ) 100 } 101 102 block := c.funcDecl(funcName, typeString).Block( 103 ifNilReturnNilAndFalse("n"), 104 jen.Id("out").Op("=").Id("n"), 105 visitElements, 106 ifPostNotNilVisit("out"), 107 jen.Return(), 108 ) 109 c.addFunc(block) 110 return nil 111 } 112 113 func (c *cowGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { 114 if !types.Implements(t, spi.iface()) { 115 return nil 116 } 117 118 typeString := types.TypeString(t, noQualifier) 119 typeName := printableTypeName(t) 120 121 var stmts []jen.Code 122 stmts = append(stmts, 123 jen.If(jen.Id("c").Dot("cursor").Dot("stop")).Block(jen.Return(jen.Id("n"), jen.False())), 124 ifNotNil("c.pre", jen.Id("c.pre").Params(jen.Id("n"), jen.Id("parent"))), 125 ifNotNil("c.post", jen.List(jen.Id("out"), jen.Id("changed")).Op("=").Id("c.postVisit").Params(jen.Id("n"), jen.Id("parent"), jen.Id("changed"))). 126 Else().Block(jen.Id("out = n")), 127 jen.Return(), 128 ) 129 funcName := cowName + typeName 130 funcDecl := c.funcDecl(funcName, typeString).Block(stmts...) 131 c.addFunc(funcDecl) 132 return nil 133 } 134 135 func (c *cowGen) copySliceElement(t types.Type, elType types.Type, spi generatorSPI) jen.Code { 136 if !isNamed(t) && isBasic(elType) { 137 // copy(res, n) 138 return jen.Id("copy").Call(jen.Id("res"), jen.Id("n")) 139 } 140 141 // for i := range n { 142 // res[i] = CloneAST(x) 143 // } 144 spi.addType(elType) 145 146 return jen.For(jen.List(jen.Id("i"), jen.Id("x"))).Op(":=").Range().Id("n").Block( 147 jen.Id("res").Index(jen.Id("i")).Op("=").Add(c.readValueOfType(elType, jen.Id("x"), spi)), 148 ) 149 } 150 151 func ifNotNil(id string, stmts ...jen.Code) *jen.Statement { 152 return jen.If(jen.Id(id).Op("!=").Nil()).Block(stmts...) 153 } 154 155 func ifNilReturnNilAndFalse(id string) *jen.Statement { 156 return jen.If(jen.Id(id).Op("==").Nil().Op("||").Id("c").Dot("cursor").Dot("stop")).Block(jen.Return(jen.Id("n"), jen.False())) 157 } 158 159 func ifPreNotNilOrReturnsTrue() *jen.Statement { 160 // if c.pre == nil || c.pre(n, parent) { 161 return jen.If( 162 jen.Id("c").Dot("pre").Op("==").Nil().Op("||").Id("c").Dot("pre").Params( 163 jen.Id("n"), 164 jen.Id("parent"), 165 )) 166 167 } 168 169 func (c *cowGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { 170 if !types.Implements(t, spi.iface()) { 171 return nil 172 } 173 174 // func (c cow) cowAST(in AST) (AST, bool) { 175 // if in == nil { 176 // return nil, false 177 // } 178 // 179 // if c.old == in { 180 // return c.new, true 181 // } 182 // switch in := in.(type) { 183 // case *RefContainer: 184 // return c.CowRefOfRefContainer(in) 185 // } 186 // // this should never happen 187 // return nil 188 // } 189 190 typeString := types.TypeString(t, noQualifier) 191 typeName := printableTypeName(t) 192 193 stmts := []jen.Code{ifNilReturnNilAndFalse("n")} 194 195 var cases []jen.Code 196 _ = findImplementations(spi.scope(), iface, func(t types.Type) error { 197 if _, ok := t.Underlying().(*types.Interface); ok { 198 return nil 199 } 200 spi.addType(t) 201 typeString := types.TypeString(t, noQualifier) 202 203 // case Type: return CloneType(in) 204 block := jen.Case(jen.Id(typeString)).Block(jen.Return(c.readValueOfType(t, jen.List(jen.Id("n"), jen.Id("parent")), spi))) 205 cases = append(cases, block) 206 207 return nil 208 }) 209 210 cases = append(cases, 211 jen.Default().Block( 212 jen.Comment("this should never happen"), 213 jen.Return(jen.Nil(), jen.False()), 214 )) 215 216 // switch n := node.(type) { 217 stmts = append(stmts, jen.Switch(jen.Id("n").Op(":=").Id("n").Assert(jen.Id("type")).Block( 218 cases..., 219 ))) 220 221 funcName := cowName + typeName 222 funcDecl := c.funcDecl(funcName, typeString).Block(stmts...) 223 c.addFunc(funcDecl) 224 return nil 225 } 226 227 func (c *cowGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 228 if !types.Implements(t, spi.iface()) { 229 return nil 230 } 231 232 ptr := t.Underlying().(*types.Pointer) 233 return c.ptrToOtherMethod(t, ptr, spi) 234 } 235 236 func (c *cowGen) ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error { 237 if !types.Implements(t, spi.iface()) { 238 return nil 239 } 240 241 receiveType := types.TypeString(t, noQualifier) 242 243 funcName := cowName + printableTypeName(t) 244 c.addFunc(c.funcDecl(funcName, receiveType).Block( 245 jen.Comment("apan was here"), 246 jen.Return(jen.Id("n"), jen.False()), 247 )) 248 return nil 249 } 250 251 // func (c cow) COWRefOfType(n *Type) (*Type, bool) 252 func (c *cowGen) funcDecl(funcName, typeName string) *jen.Statement { 253 return jen.Func().Params(jen.Id("c").Id("*cow")).Id(funcName).Call(jen.List(jen.Id("n").Id(typeName), jen.Id("parent").Id(c.baseType))).Params(jen.Id("out").Id(c.baseType), jen.Id("changed").Id("bool")) 254 } 255 256 func (c *cowGen) visitFieldOrElement(varName, changedVarName string, typ types.Type, el *jen.Statement, spi generatorSPI) *jen.Statement { 257 // _Field, changedField := c.COWType(n.<Field>, n) 258 return jen.List(jen.Id(varName), jen.Id(changedVarName)).Op(":=").Add(c.readValueOfType(typ, jen.List(el, jen.Id("n")), spi)) 259 } 260 261 func (c *cowGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 262 if !types.Implements(t, spi.iface()) { 263 return nil 264 } 265 266 c.visitStruct(t, strct, spi, nil, false) 267 return nil 268 } 269 270 func (c *cowGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 271 if !types.Implements(t, spi.iface()) { 272 return nil 273 } 274 start := ifNilReturnNilAndFalse("n") 275 276 c.visitStruct(t, strct, spi, start, true) 277 return nil 278 } 279 280 func (c *cowGen) visitStruct(t types.Type, strct *types.Struct, spi generatorSPI, start *jen.Statement, ref bool) { 281 receiveType := types.TypeString(t, noQualifier) 282 funcName := cowName + printableTypeName(t) 283 284 funcDeclaration := c.funcDecl(funcName, receiveType) 285 286 var fields []jen.Code 287 out := "out" 288 changed := "res" 289 var fieldSetters []jen.Code 290 kopy := jen.Id(changed).Op(":=") 291 if ref { 292 fieldSetters = append(fieldSetters, kopy.Op("*").Id("n")) // changed := *n 293 } else { 294 fieldSetters = append(fieldSetters, kopy.Id("n")) // changed := n 295 } 296 var changedVariables []string 297 for i := 0; i < strct.NumFields(); i++ { 298 field := strct.Field(i).Name() 299 typ := strct.Field(i).Type() 300 changedVarName := "changed" + field 301 302 fieldType := types.TypeString(typ, noQualifier) 303 fieldVar := "_" + field 304 if types.Implements(typ, spi.iface()) { 305 fields = append(fields, c.visitFieldOrElement(fieldVar, changedVarName, typ, jen.Id("n").Dot(field), spi)) 306 changedVariables = append(changedVariables, changedVarName) 307 fieldSetters = append(fieldSetters, jen.List(jen.Id(changed).Dot(field), jen.Op("_")).Op("=").Id(fieldVar).Op(".").Params(jen.Id(fieldType))) 308 } else { 309 // _Foo := make([]*Type, len(n.Foo)) 310 // var changedFoo bool 311 // for x, el := range n.Foo { 312 // c, changed := c.COWSliceOfRefOfType(el, n) 313 // if changed { 314 // changedFoo = true 315 // } 316 // _Foo[i] = c.(*Type) 317 // } 318 319 slice, isSlice := typ.(*types.Slice) 320 if isSlice && types.Implements(slice.Elem(), spi.iface()) { 321 elemTyp := slice.Elem() 322 spi.addType(elemTyp) 323 x := jen.Id("x") 324 el := jen.Id("el") 325 // changed := jen.Id("changed") 326 fields = append(fields, 327 jen.Var().Id(changedVarName).Bool(), // var changedFoo bool 328 jen.Id(fieldVar).Op(":=").Id("make").Params(jen.Id(fieldType), jen.Id("len").Params(jen.Id("n").Dot(field))), // _Foo := make([]Typ, len(n.Foo)) 329 jen.For(jen.List(x, el).Op(":=").Id("range n").Dot(field)).Block( 330 c.visitFieldOrElement("this", "changed", elemTyp, jen.Id("el"), spi), 331 jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(types.TypeString(elemTyp, noQualifier))), 332 jen.If(jen.Id("changed")).Block( 333 jen.Id(changedVarName).Op("=").True(), 334 ), 335 ), 336 ) 337 changedVariables = append(changedVariables, changedVarName) 338 fieldSetters = append(fieldSetters, jen.Id(changed).Dot(field).Op("=").Id(fieldVar)) 339 } 340 } 341 } 342 343 var cond *jen.Statement 344 for _, variable := range changedVariables { 345 if cond == nil { 346 cond = jen.Id(variable) 347 } else { 348 cond = cond.Op("||").Add(jen.Id(variable)) 349 } 350 351 } 352 353 fieldSetters = append(fieldSetters, 354 jen.Id(out).Op("=").Op("&").Id(changed), 355 ifNotNil("c.cloned", jen.Id("c.cloned").Params(jen.Id("n, out"))), 356 jen.Id("changed").Op("=").True(), 357 ) 358 ifChanged := jen.If(cond).Block(fieldSetters...) 359 360 var stmts []jen.Code 361 if start != nil { 362 stmts = append(stmts, start) 363 } 364 365 // handle all fields with CloneAble types 366 var visitChildren []jen.Code 367 visitChildren = append(visitChildren, fields...) 368 if len(fieldSetters) > 4 /*we add three statements always*/ { 369 visitChildren = append(visitChildren, ifChanged) 370 } 371 372 children := ifPreNotNilOrReturnsTrue().Block(visitChildren...) 373 stmts = append(stmts, 374 jen.Id(out).Op("=").Id("n"), 375 children, 376 ) 377 378 stmts = append( 379 stmts, 380 ifPostNotNilVisit(out), 381 jen.Return(), 382 ) 383 384 c.addFunc(funcDeclaration.Block(stmts...)) 385 } 386 387 func ifPostNotNilVisit(out string) *jen.Statement { 388 return ifNotNil("c.post", jen.List(jen.Id(out), jen.Id("changed")).Op("=").Id("c").Dot("postVisit").Params(jen.Id(out), jen.Id("parent"), jen.Id("changed"))) 389 }