vitess.io/vitess@v0.16.2/go/tools/asthelpergen/clone_gen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "fmt" 21 "go/types" 22 "log" 23 "strings" 24 25 "github.com/dave/jennifer/jen" 26 "golang.org/x/exp/slices" 27 ) 28 29 type CloneOptions struct { 30 Exclude []string 31 } 32 33 // cloneGen creates the deep clone methods for the AST. It works by discovering the types that it needs to support, 34 // starting from a root interface type. While creating the clone method for this root interface, more types that need 35 // to be cloned are discovered. This continues type by type until all necessary types have been traversed. 36 type cloneGen struct { 37 exclude []string 38 file *jen.File 39 } 40 41 var _ generator = (*cloneGen)(nil) 42 43 func newCloneGen(pkgname string, options *CloneOptions) *cloneGen { 44 file := jen.NewFile(pkgname) 45 file.HeaderComment(licenseFileHeader) 46 file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") 47 48 return &cloneGen{ 49 exclude: options.Exclude, 50 file: file, 51 } 52 } 53 54 func (c *cloneGen) addFunc(name string, code *jen.Statement) { 55 c.file.Add(jen.Comment(fmt.Sprintf("%s creates a deep clone of the input.", name))) 56 c.file.Add(code) 57 } 58 59 func (c *cloneGen) genFile() (string, *jen.File) { 60 return "ast_clone.go", c.file 61 } 62 63 const cloneName = "Clone" 64 65 // readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list 66 func (c *cloneGen) readValueOfType(t types.Type, expr jen.Code, spi generatorSPI) jen.Code { 67 switch t.Underlying().(type) { 68 case *types.Basic: 69 return expr 70 case *types.Interface: 71 if types.TypeString(t, noQualifier) == "any" { 72 // these fields have to be taken care of manually 73 return expr 74 } 75 } 76 spi.addType(t) 77 return jen.Id(cloneName + printableTypeName(t)).Call(expr) 78 } 79 80 func (c *cloneGen) structMethod(t types.Type, _ *types.Struct, spi generatorSPI) error { 81 typeString := types.TypeString(t, noQualifier) 82 funcName := cloneName + printableTypeName(t) 83 c.addFunc(funcName, 84 jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block( 85 jen.Return(jen.Op("*").Add(c.readValueOfType(types.NewPointer(t), jen.Op("&").Id("n"), spi))), 86 )) 87 return nil 88 } 89 90 func (c *cloneGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { 91 typeString := types.TypeString(t, noQualifier) 92 name := printableTypeName(t) 93 funcName := cloneName + name 94 95 c.addFunc(funcName, 96 // func (n Bytes) Clone() Bytes { 97 jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block( 98 // if n == nil { return nil } 99 ifNilReturnNil("n"), 100 // res := make(Bytes, len(n)) 101 jen.Id("res").Op(":=").Id("make").Call(jen.Id(typeString), jen.Id("len").Call(jen.Id("n"))), 102 c.copySliceElement(t, slice.Elem(), spi), 103 // return res 104 jen.Return(jen.Id("res")), 105 )) 106 return nil 107 } 108 109 func (c *cloneGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error { 110 return nil 111 } 112 113 func (c *cloneGen) copySliceElement(t types.Type, elType types.Type, spi generatorSPI) jen.Code { 114 if !isNamed(t) && isBasic(elType) { 115 // copy(res, n) 116 return jen.Id("copy").Call(jen.Id("res"), jen.Id("n")) 117 } 118 119 // for i := range n { 120 // res[i] = CloneAST(x) 121 // } 122 spi.addType(elType) 123 124 return jen.For(jen.List(jen.Id("i"), jen.Id("x"))).Op(":=").Range().Id("n").Block( 125 jen.Id("res").Index(jen.Id("i")).Op("=").Add(c.readValueOfType(elType, jen.Id("x"), spi)), 126 ) 127 } 128 129 func (c *cloneGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { 130 131 // func CloneAST(in AST) AST { 132 // if in == nil { 133 // return nil 134 // } 135 // switch in := in.(type) { 136 // case *RefContainer: 137 // return in.CloneRefOfRefContainer() 138 // } 139 // // this should never happen 140 // return nil 141 // } 142 143 typeString := types.TypeString(t, noQualifier) 144 typeName := printableTypeName(t) 145 146 stmts := []jen.Code{ifNilReturnNil("in")} 147 148 var cases []jen.Code 149 _ = findImplementations(spi.scope(), iface, func(t types.Type) error { 150 typeString := types.TypeString(t, noQualifier) 151 152 // case Type: return CloneType(in) 153 block := jen.Case(jen.Id(typeString)).Block(jen.Return(c.readValueOfType(t, jen.Id("in"), spi))) 154 switch t := t.(type) { 155 case *types.Pointer: 156 _, isIface := t.Elem().(*types.Interface) 157 if !isIface { 158 cases = append(cases, block) 159 } 160 161 case *types.Named: 162 _, isIface := t.Underlying().(*types.Interface) 163 if !isIface { 164 cases = append(cases, block) 165 } 166 167 default: 168 log.Fatalf("unexpected type encountered: %s", typeString) 169 } 170 171 return nil 172 }) 173 174 cases = append(cases, 175 jen.Default().Block( 176 jen.Comment("this should never happen"), 177 jen.Return(jen.Nil()), 178 )) 179 180 // switch n := node.(type) { 181 stmts = append(stmts, jen.Switch(jen.Id("in").Op(":=").Id("in").Assert(jen.Id("type")).Block( 182 cases..., 183 ))) 184 185 funcName := cloneName + typeName 186 funcDecl := jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString)).Id(typeString).Block(stmts...) 187 c.addFunc(funcName, funcDecl) 188 return nil 189 } 190 191 func (c *cloneGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 192 ptr := t.Underlying().(*types.Pointer) 193 return c.ptrToOtherMethod(t, ptr, spi) 194 } 195 196 func (c *cloneGen) ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error { 197 receiveType := types.TypeString(t, noQualifier) 198 199 funcName := cloneName + printableTypeName(t) 200 c.addFunc(funcName, 201 jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block( 202 ifNilReturnNil("n"), 203 jen.Id("out").Op(":=").Add(c.readValueOfType(ptr.Elem(), jen.Op("*").Id("n"), spi)), 204 jen.Return(jen.Op("&").Id("out")), 205 )) 206 return nil 207 } 208 209 func ifNilReturnNil(id string) *jen.Statement { 210 return jen.If(jen.Id(id).Op("==").Nil()).Block(jen.Return(jen.Nil())) 211 } 212 213 func isNamed(t types.Type) bool { 214 _, x := t.(*types.Named) 215 return x 216 } 217 218 func isBasic(t types.Type) bool { 219 _, x := t.Underlying().(*types.Basic) 220 return x 221 } 222 223 func (c *cloneGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 224 receiveType := types.TypeString(t, noQualifier) 225 funcName := cloneName + printableTypeName(t) 226 227 // func CloneRefOfType(n *Type) *Type 228 funcDeclaration := jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType) 229 230 if slices.Contains(c.exclude, receiveType) { 231 c.addFunc(funcName, funcDeclaration.Block( 232 jen.Return(jen.Id("n")), 233 )) 234 return nil 235 } 236 237 var fields []jen.Code 238 for i := 0; i < strct.NumFields(); i++ { 239 field := strct.Field(i) 240 if isBasic(field.Type()) || strings.HasPrefix(field.Name(), "_") { 241 continue 242 } 243 // out.Field = CloneType(n.Field) 244 fields = append(fields, 245 jen.Id("out").Dot(field.Name()).Op("=").Add(c.readValueOfType(field.Type(), jen.Id("n").Dot(field.Name()), spi))) 246 } 247 248 stmts := []jen.Code{ 249 // if n == nil { return nil } 250 ifNilReturnNil("n"), 251 // out := *n 252 jen.Id("out").Op(":=").Op("*").Id("n"), 253 } 254 255 // handle all fields with CloneAble types 256 stmts = append(stmts, fields...) 257 258 stmts = append(stmts, 259 // return &out 260 jen.Return(jen.Op("&").Id("out")), 261 ) 262 263 c.addFunc(funcName, funcDeclaration.Block(stmts...)) 264 return nil 265 }