vitess.io/vitess@v0.16.2/go/tools/asthelpergen/equals_gen.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package asthelpergen 18 19 import ( 20 "fmt" 21 "go/types" 22 "strings" 23 24 "github.com/dave/jennifer/jen" 25 ) 26 27 const Comparator = "Comparator" 28 29 type EqualsOptions struct { 30 AllowCustom []string 31 } 32 33 type equalsGen struct { 34 file *jen.File 35 comparators map[string]types.Type 36 } 37 38 var _ generator = (*equalsGen)(nil) 39 40 func newEqualsGen(pkgname string, options *EqualsOptions) *equalsGen { 41 file := jen.NewFile(pkgname) 42 file.HeaderComment(licenseFileHeader) 43 file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.") 44 45 customComparators := make(map[string]types.Type, len(options.AllowCustom)) 46 for _, tt := range options.AllowCustom { 47 customComparators[tt] = nil 48 } 49 50 return &equalsGen{ 51 file: file, 52 comparators: customComparators, 53 } 54 } 55 56 func (e *equalsGen) addFunc(name string, code *jen.Statement) { 57 e.file.Add(jen.Comment(fmt.Sprintf("%s does deep equals between the two objects.", name))) 58 e.file.Add(code) 59 } 60 61 func (e *equalsGen) customComparatorField(t types.Type) string { 62 return printableTypeName(t) + "_" 63 } 64 65 func (e *equalsGen) genFile() (string, *jen.File) { 66 e.file.Type().Id(Comparator).StructFunc(func(g *jen.Group) { 67 for tname, t := range e.comparators { 68 if t == nil { 69 continue 70 } 71 method := e.customComparatorField(t) 72 g.Add(jen.Id(method).Func().Call(jen.List(jen.Id("a"), jen.Id("b")).Id(tname)).Bool()) 73 } 74 }) 75 76 return "ast_equals.go", e.file 77 } 78 79 func (e *equalsGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error { 80 /* 81 func (cmp *Comparator) AST(inA, inB AST) bool { 82 if inA == inB { 83 return true 84 } 85 if inA == nil || inB8 == nil { 86 return false 87 } 88 switch a := inA.(type) { 89 case *SubImpl: 90 b, ok := inB.(*SubImpl) 91 if !ok { 92 return false 93 } 94 return cmp.SubImpl(a, b) 95 } 96 return false 97 } 98 */ 99 stmts := []jen.Code{ 100 jen.If(jen.Id("inA == nil").Op("&&").Id("inB == nil")).Block(jen.Return(jen.True())), 101 jen.If(jen.Id("inA == nil").Op("||").Id("inB == nil")).Block(jen.Return(jen.False())), 102 } 103 104 var cases []jen.Code 105 _ = spi.findImplementations(iface, func(t types.Type) error { 106 if _, ok := t.Underlying().(*types.Interface); ok { 107 return nil 108 } 109 typeString := types.TypeString(t, noQualifier) 110 caseBlock := jen.Case(jen.Id(typeString)).Block( 111 jen.Id("b, ok := inB.").Call(jen.Id(typeString)), 112 jen.If(jen.Id("!ok")).Block(jen.Return(jen.False())), 113 jen.Return(compareValueType(t, jen.Id("a"), jen.Id("b"), true, spi)), 114 ) 115 cases = append(cases, caseBlock) 116 return nil 117 }) 118 119 cases = append(cases, 120 jen.Default().Block( 121 jen.Comment("this should never happen"), 122 jen.Return(jen.False()), 123 )) 124 125 stmts = append(stmts, jen.Switch(jen.Id("a := inA.(type)").Block( 126 cases..., 127 ))) 128 129 funcDecl, funcName := e.declareFunc(t, "inA", "inB") 130 e.addFunc(funcName, funcDecl.Block(stmts...)) 131 132 return nil 133 } 134 135 func compareValueType(t types.Type, a, b *jen.Statement, eq bool, spi generatorSPI) *jen.Statement { 136 switch t.Underlying().(type) { 137 case *types.Basic: 138 if eq { 139 return a.Op("==").Add(b) 140 } 141 return a.Op("!=").Add(b) 142 } 143 spi.addType(t) 144 fcall := jen.Id("cmp").Dot(printableTypeName(t)).Call(a, b) 145 if !eq { 146 return jen.Op("!").Add(fcall) 147 } 148 return fcall 149 } 150 151 func (e *equalsGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 152 /* 153 func EqualsRefOfRefContainer(inA RefContainer, inB RefContainer, f ASTComparison) bool { 154 return EqualsRefOfLeaf(inA.ASTImplementationType, inB.ASTImplementationType, f) && 155 EqualsAST(inA.ASTType, inB.ASTType, f) && inA.NotASTType == inB.NotASTType 156 } 157 */ 158 159 funcDecl, funcName := e.declareFunc(t, "a", "b") 160 e.addFunc(funcName, funcDecl.Block(jen.Return(compareAllStructFields(strct, spi)))) 161 162 return nil 163 } 164 165 func compareAllStructFields(strct *types.Struct, spi generatorSPI) jen.Code { 166 var basicsPred []*jen.Statement 167 var others []*jen.Statement 168 for i := 0; i < strct.NumFields(); i++ { 169 field := strct.Field(i) 170 if field.Type().Underlying().String() == "any" || strings.HasPrefix(field.Name(), "_") { 171 // we can safely ignore this, we do not want ast to contain `any` types. 172 continue 173 } 174 fieldA := jen.Id("a").Dot(field.Name()) 175 fieldB := jen.Id("b").Dot(field.Name()) 176 pred := compareValueType(field.Type(), fieldA, fieldB, true, spi) 177 if _, ok := field.Type().(*types.Basic); ok { 178 basicsPred = append(basicsPred, pred) 179 continue 180 } 181 others = append(others, pred) 182 } 183 184 var ret *jen.Statement 185 for _, pred := range basicsPred { 186 if ret == nil { 187 ret = pred 188 } else { 189 ret = ret.Op("&&").Line().Add(pred) 190 } 191 } 192 193 for _, pred := range others { 194 if ret == nil { 195 ret = pred 196 } else { 197 ret = ret.Op("&&").Line().Add(pred) 198 } 199 } 200 201 if ret == nil { 202 return jen.True() 203 } 204 return ret 205 } 206 207 func (e *equalsGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error { 208 /* 209 func EqualsRefOfType(a, b *Type, f ASTComparison) *Type { 210 if a == b { 211 return true 212 } 213 if a == nil || b == nil { 214 return false 215 } 216 217 // only if it is a *ColName 218 if f != nil { 219 return f.ColNames(a, b) 220 } 221 222 return compareAllStructFields 223 } 224 */ 225 // func EqualsRefOfType(a,b *Type) *Type 226 funcDeclaration, funcName := e.declareFunc(t, "a", "b") 227 stmts := []jen.Code{ 228 jen.If(jen.Id("a == b")).Block(jen.Return(jen.True())), 229 jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())), 230 } 231 232 typeString := types.TypeString(t, noQualifier) 233 234 if _, ok := e.comparators[typeString]; ok { 235 e.comparators[typeString] = t 236 237 method := e.customComparatorField(t) 238 stmts = append(stmts, 239 jen.If(jen.Id("cmp").Dot(method).Op("!=").Nil()).Block( 240 jen.Return(jen.Id("cmp").Dot(method).Call(jen.Id("a"), jen.Id("b"))), 241 )) 242 } 243 244 stmts = append(stmts, jen.Return(compareAllStructFields(strct, spi))) 245 246 e.addFunc(funcName, funcDeclaration.Block(stmts...)) 247 return nil 248 } 249 250 func (e *equalsGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error { 251 /* 252 func EqualsRefOfBool(a, b *bool, f ASTComparison) bool { 253 if a == b { 254 return true 255 } 256 if a == nil || b == nil { 257 return false 258 } 259 return *a == *b 260 } 261 */ 262 funcDeclaration, funcName := e.declareFunc(t, "a", "b") 263 stmts := []jen.Code{ 264 jen.If(jen.Id("a == b")).Block(jen.Return(jen.True())), 265 jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())), 266 jen.Return(jen.Id("*a == *b")), 267 } 268 e.addFunc(funcName, funcDeclaration.Block(stmts...)) 269 return nil 270 } 271 272 func (e *equalsGen) declareFunc(t types.Type, aArg, bArg string) (*jen.Statement, string) { 273 typeString := types.TypeString(t, noQualifier) 274 funcName := printableTypeName(t) 275 276 // func EqualsFunNameS(a, b <T>, f ASTComparison) bool 277 return jen.Func().Params(jen.Id("cmp").Op("*").Id(Comparator)).Id(funcName).Call(jen.Id(aArg), jen.Id(bArg).Id(typeString)).Bool(), funcName 278 } 279 280 func (e *equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { 281 /* 282 func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool { 283 if len(a) != len(b) { 284 return false 285 } 286 for i := 0; i < len(a); i++ { 287 if !EqualsRefOfLeaf(a[i], b[i]) { 288 return false 289 } 290 } 291 return false 292 } 293 */ 294 295 stmts := []jen.Code{jen.If(jen.Id("len(a) != len(b)")).Block(jen.Return(jen.False())), 296 jen.For(jen.Id("i := 0; i < len(a); i++")).Block( 297 jen.If(compareValueType(slice.Elem(), jen.Id("a[i]"), jen.Id("b[i]"), false, spi)).Block(jen.Return(jen.False()))), 298 jen.Return(jen.True()), 299 } 300 301 funcDecl, funcName := e.declareFunc(t, "a", "b") 302 e.addFunc(funcName, funcDecl.Block(stmts...)) 303 return nil 304 } 305 306 func (e *equalsGen) basicMethod(types.Type, *types.Basic, generatorSPI) error { 307 return nil 308 }