vitess.io/vitess@v0.16.2/go/tools/asthelpergen/visit_gen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "go/types" 21 22 "github.com/dave/jennifer/jen" 23 ) 24 25 const visitName = "Visit" 26 27 type visitGen struct { 28 file *jen.File 29 } 30 31 var _ generator = (*visitGen)(nil) 32 33 func newVisitGen(pkgname string) *visitGen { 34 file := jen.NewFile(pkgname) 35 file.HeaderComment(licenseFileHeader) 36 file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") 37 38 return &visitGen{ 39 file: file, 40 } 41 } 42 43 func (v *visitGen) genFile() (string, *jen.File) { 44 return "ast_visit.go", v.file 45 } 46 47 func shouldAdd(t types.Type, i *types.Interface) bool { 48 return types.Implements(t, i) 49 } 50 51 func (v *visitGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { 52 if !shouldAdd(t, spi.iface()) { 53 return nil 54 } 55 /* 56 func VisitAST(in AST) (bool, error) { 57 if in == nil { 58 return false, nil 59 } 60 switch a := inA.(type) { 61 case *SubImpl: 62 return VisitSubImpl(a, b) 63 default: 64 return false, nil 65 } 66 } 67 */ 68 stmts := []jen.Code{ 69 jen.If(jen.Id("in == nil").Block(returnNil())), 70 } 71 72 var cases []jen.Code 73 _ = spi.findImplementations(iface, func(t types.Type) error { 74 if _, ok := t.Underlying().(*types.Interface); ok { 75 return nil 76 } 77 typeString := types.TypeString(t, noQualifier) 78 funcName := visitName + printableTypeName(t) 79 spi.addType(t) 80 caseBlock := jen.Case(jen.Id(typeString)).Block( 81 jen.Return(jen.Id(funcName).Call(jen.Id("in"), jen.Id("f"))), 82 ) 83 cases = append(cases, caseBlock) 84 return nil 85 }) 86 87 cases = append(cases, 88 jen.Default().Block( 89 jen.Comment("this should never happen"), 90 returnNil(), 91 )) 92 93 stmts = append(stmts, jen.Switch(jen.Id("in := in.(type)").Block( 94 cases..., 95 ))) 96 97 v.visitFunc(t, stmts) 98 return nil 99 } 100 101 func returnNil() jen.Code { 102 return jen.Return(jen.Nil()) 103 } 104 105 func (v *visitGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 106 if !shouldAdd(t, spi.iface()) { 107 return nil 108 } 109 110 /* 111 func VisitRefOfRefContainer(in *RefContainer, f func(node AST) (kontinue bool, err error)) (bool, error) { 112 if cont, err := f(in); err != nil || !cont { 113 return false, err 114 } 115 if k, err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil || !k { 116 return false, err 117 } 118 if k, err := VisitAST(in.ASTType, f); err != nil || !k { 119 return false, err 120 } 121 return true, nil 122 } 123 */ 124 125 stmts := visitAllStructFields(strct, spi) 126 v.visitFunc(t, stmts) 127 128 return nil 129 } 130 131 func (v *visitGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 132 if !shouldAdd(t, spi.iface()) { 133 return nil 134 } 135 136 /* 137 func VisitRefOfRefContainer(in *RefContainer, f func(node AST) (kontinue bool, err error)) (bool, error) { 138 if in == nil { 139 return true, nil 140 } 141 if cont, err := f(in); err != nil || !cont { 142 return false, err 143 } 144 if k, err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil || !k { 145 return false, err 146 } 147 if k, err := VisitAST(in.ASTType, f); err != nil || !k { 148 return false, err 149 } 150 return true, nil 151 } 152 */ 153 154 stmts := []jen.Code{ 155 jen.If(jen.Id("in == nil").Block(returnNil())), 156 } 157 stmts = append(stmts, visitAllStructFields(strct, spi)...) 158 v.visitFunc(t, stmts) 159 160 return nil 161 } 162 163 func (v *visitGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 164 if !shouldAdd(t, spi.iface()) { 165 return nil 166 } 167 168 stmts := []jen.Code{ 169 jen.Comment("ptrToBasicMethod"), 170 } 171 172 v.visitFunc(t, stmts) 173 174 return nil 175 } 176 177 func (v *visitGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { 178 if !shouldAdd(t, spi.iface()) { 179 return nil 180 } 181 182 if !shouldAdd(slice.Elem(), spi.iface()) { 183 return v.visitNoChildren(t, spi) 184 } 185 186 stmts := []jen.Code{ 187 jen.If(jen.Id("in == nil").Block(returnNil())), 188 visitIn(), 189 jen.For(jen.Id("_, el := range in")).Block( 190 visitChild(slice.Elem(), jen.Id("el")), 191 ), 192 returnNil(), 193 } 194 195 v.visitFunc(t, stmts) 196 197 return nil 198 } 199 200 func (v *visitGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { 201 if !shouldAdd(t, spi.iface()) { 202 return nil 203 } 204 205 return v.visitNoChildren(t, spi) 206 } 207 208 func (v *visitGen) visitNoChildren(t types.Type, spi generatorSPI) error { 209 stmts := []jen.Code{ 210 jen.Id("_, err := f(in)"), 211 jen.Return(jen.Err()), 212 } 213 214 v.visitFunc(t, stmts) 215 216 return nil 217 } 218 219 func visitAllStructFields(strct *types.Struct, spi generatorSPI) []jen.Code { 220 output := []jen.Code{ 221 visitIn(), 222 } 223 for i := 0; i < strct.NumFields(); i++ { 224 field := strct.Field(i) 225 if types.Implements(field.Type(), spi.iface()) { 226 spi.addType(field.Type()) 227 visitField := visitChild(field.Type(), jen.Id("in").Dot(field.Name())) 228 output = append(output, visitField) 229 continue 230 } 231 slice, isSlice := field.Type().(*types.Slice) 232 if isSlice && types.Implements(slice.Elem(), spi.iface()) { 233 spi.addType(slice.Elem()) 234 output = append(output, jen.For(jen.Id("_, el := range in."+field.Name())).Block( 235 visitChild(slice.Elem(), jen.Id("el")), 236 )) 237 } 238 } 239 output = append(output, returnNil()) 240 return output 241 } 242 243 func visitChild(t types.Type, id jen.Code) *jen.Statement { 244 funcName := visitName + printableTypeName(t) 245 visitField := jen.If( 246 jen.Id("err := ").Id(funcName).Call(id, jen.Id("f")), 247 jen.Id("err != nil "), 248 ).Block(jen.Return(jen.Err())) 249 return visitField 250 } 251 252 func visitIn() *jen.Statement { 253 return jen.If( 254 jen.Id("cont, err := ").Id("f").Call(jen.Id("in")), 255 jen.Id("err != nil || !cont"), 256 ).Block(jen.Return(jen.Err())) 257 } 258 259 func (v *visitGen) visitFunc(t types.Type, stmts []jen.Code) { 260 typeString := types.TypeString(t, noQualifier) 261 funcName := visitName + printableTypeName(t) 262 v.file.Add(jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString), jen.Id("f Visit")).Error().Block(stmts...)) 263 }