vitess.io/vitess@v0.16.2/go/tools/asthelpergen/equals_gen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
import (
	"fmt"
	"go/types"
	"sort"
	"strings"

	"github.com/dave/jennifer/jen"
)
    26  
// Comparator is the name of the struct type emitted into the generated file;
// every generated equality function is declared as a method on *Comparator.
const Comparator = "Comparator"

// EqualsOptions configures the equality generator.
type EqualsOptions struct {
	// AllowCustom lists pointer-to-struct type names, spelled the way
	// types.TypeString prints them with noQualifier (e.g. "*Foo"). For each
	// such type that gets a generated method, Comparator gains an optional
	// func field that, when set, replaces the field-by-field comparison.
	AllowCustom []string
}

// equalsGen generates deep-equality methods on Comparator for the visited types.
type equalsGen struct {
	// file accumulates all generated code.
	file        *jen.File
	// comparators maps each AllowCustom name to its resolved type, or to nil
	// while no method for a type with that name has been generated yet.
	comparators map[string]types.Type
}

// Compile-time check that *equalsGen implements generator.
var _ generator = (*equalsGen)(nil)
    39  
    40  func newEqualsGen(pkgname string, options *EqualsOptions) *equalsGen {
    41  	file := jen.NewFile(pkgname)
    42  	file.HeaderComment(licenseFileHeader)
    43  	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
    44  
    45  	customComparators := make(map[string]types.Type, len(options.AllowCustom))
    46  	for _, tt := range options.AllowCustom {
    47  		customComparators[tt] = nil
    48  	}
    49  
    50  	return &equalsGen{
    51  		file:        file,
    52  		comparators: customComparators,
    53  	}
    54  }
    55  
    56  func (e *equalsGen) addFunc(name string, code *jen.Statement) {
    57  	e.file.Add(jen.Comment(fmt.Sprintf("%s does deep equals between the two objects.", name)))
    58  	e.file.Add(code)
    59  }
    60  
    61  func (e *equalsGen) customComparatorField(t types.Type) string {
    62  	return printableTypeName(t) + "_"
    63  }
    64  
    65  func (e *equalsGen) genFile() (string, *jen.File) {
    66  	e.file.Type().Id(Comparator).StructFunc(func(g *jen.Group) {
    67  		for tname, t := range e.comparators {
    68  			if t == nil {
    69  				continue
    70  			}
    71  			method := e.customComparatorField(t)
    72  			g.Add(jen.Id(method).Func().Call(jen.List(jen.Id("a"), jen.Id("b")).Id(tname)).Bool())
    73  		}
    74  	})
    75  
    76  	return "ast_equals.go", e.file
    77  }
    78  
    79  func (e *equalsGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
    80  	/*
    81  		func (cmp *Comparator) AST(inA, inB AST) bool {
    82  			if inA == inB {
    83  				return true
    84  			}
    85  			if inA == nil || inB8 == nil {
    86  				return false
    87  			}
    88  			switch a := inA.(type) {
    89  			case *SubImpl:
    90  				b, ok := inB.(*SubImpl)
    91  				if !ok {
    92  					return false
    93  				}
    94  				return cmp.SubImpl(a, b)
    95  			}
    96  			return false
    97  		}
    98  	*/
    99  	stmts := []jen.Code{
   100  		jen.If(jen.Id("inA == nil").Op("&&").Id("inB == nil")).Block(jen.Return(jen.True())),
   101  		jen.If(jen.Id("inA == nil").Op("||").Id("inB == nil")).Block(jen.Return(jen.False())),
   102  	}
   103  
   104  	var cases []jen.Code
   105  	_ = spi.findImplementations(iface, func(t types.Type) error {
   106  		if _, ok := t.Underlying().(*types.Interface); ok {
   107  			return nil
   108  		}
   109  		typeString := types.TypeString(t, noQualifier)
   110  		caseBlock := jen.Case(jen.Id(typeString)).Block(
   111  			jen.Id("b, ok := inB.").Call(jen.Id(typeString)),
   112  			jen.If(jen.Id("!ok")).Block(jen.Return(jen.False())),
   113  			jen.Return(compareValueType(t, jen.Id("a"), jen.Id("b"), true, spi)),
   114  		)
   115  		cases = append(cases, caseBlock)
   116  		return nil
   117  	})
   118  
   119  	cases = append(cases,
   120  		jen.Default().Block(
   121  			jen.Comment("this should never happen"),
   122  			jen.Return(jen.False()),
   123  		))
   124  
   125  	stmts = append(stmts, jen.Switch(jen.Id("a := inA.(type)").Block(
   126  		cases...,
   127  	)))
   128  
   129  	funcDecl, funcName := e.declareFunc(t, "inA", "inB")
   130  	e.addFunc(funcName, funcDecl.Block(stmts...))
   131  
   132  	return nil
   133  }
   134  
   135  func compareValueType(t types.Type, a, b *jen.Statement, eq bool, spi generatorSPI) *jen.Statement {
   136  	switch t.Underlying().(type) {
   137  	case *types.Basic:
   138  		if eq {
   139  			return a.Op("==").Add(b)
   140  		}
   141  		return a.Op("!=").Add(b)
   142  	}
   143  	spi.addType(t)
   144  	fcall := jen.Id("cmp").Dot(printableTypeName(t)).Call(a, b)
   145  	if !eq {
   146  		return jen.Op("!").Add(fcall)
   147  	}
   148  	return fcall
   149  }
   150  
   151  func (e *equalsGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   152  	/*
   153  		func EqualsRefOfRefContainer(inA RefContainer, inB RefContainer, f ASTComparison) bool {
   154  			return EqualsRefOfLeaf(inA.ASTImplementationType, inB.ASTImplementationType, f) &&
   155  				EqualsAST(inA.ASTType, inB.ASTType, f) && inA.NotASTType == inB.NotASTType
   156  		}
   157  	*/
   158  
   159  	funcDecl, funcName := e.declareFunc(t, "a", "b")
   160  	e.addFunc(funcName, funcDecl.Block(jen.Return(compareAllStructFields(strct, spi))))
   161  
   162  	return nil
   163  }
   164  
   165  func compareAllStructFields(strct *types.Struct, spi generatorSPI) jen.Code {
   166  	var basicsPred []*jen.Statement
   167  	var others []*jen.Statement
   168  	for i := 0; i < strct.NumFields(); i++ {
   169  		field := strct.Field(i)
   170  		if field.Type().Underlying().String() == "any" || strings.HasPrefix(field.Name(), "_") {
   171  			// we can safely ignore this, we do not want ast to contain `any` types.
   172  			continue
   173  		}
   174  		fieldA := jen.Id("a").Dot(field.Name())
   175  		fieldB := jen.Id("b").Dot(field.Name())
   176  		pred := compareValueType(field.Type(), fieldA, fieldB, true, spi)
   177  		if _, ok := field.Type().(*types.Basic); ok {
   178  			basicsPred = append(basicsPred, pred)
   179  			continue
   180  		}
   181  		others = append(others, pred)
   182  	}
   183  
   184  	var ret *jen.Statement
   185  	for _, pred := range basicsPred {
   186  		if ret == nil {
   187  			ret = pred
   188  		} else {
   189  			ret = ret.Op("&&").Line().Add(pred)
   190  		}
   191  	}
   192  
   193  	for _, pred := range others {
   194  		if ret == nil {
   195  			ret = pred
   196  		} else {
   197  			ret = ret.Op("&&").Line().Add(pred)
   198  		}
   199  	}
   200  
   201  	if ret == nil {
   202  		return jen.True()
   203  	}
   204  	return ret
   205  }
   206  
   207  func (e *equalsGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   208  	/*
   209  		func EqualsRefOfType(a, b *Type, f ASTComparison) *Type {
   210  			if a == b {
   211  				return true
   212  			}
   213  			if a == nil || b == nil {
   214  				return false
   215  			}
   216  
   217  			// only if it is a *ColName
   218  			if f != nil {
   219  				return f.ColNames(a, b)
   220  			}
   221  
   222  			return compareAllStructFields
   223  		}
   224  	*/
   225  	// func EqualsRefOfType(a,b  *Type) *Type
   226  	funcDeclaration, funcName := e.declareFunc(t, "a", "b")
   227  	stmts := []jen.Code{
   228  		jen.If(jen.Id("a == b")).Block(jen.Return(jen.True())),
   229  		jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())),
   230  	}
   231  
   232  	typeString := types.TypeString(t, noQualifier)
   233  
   234  	if _, ok := e.comparators[typeString]; ok {
   235  		e.comparators[typeString] = t
   236  
   237  		method := e.customComparatorField(t)
   238  		stmts = append(stmts,
   239  			jen.If(jen.Id("cmp").Dot(method).Op("!=").Nil()).Block(
   240  				jen.Return(jen.Id("cmp").Dot(method).Call(jen.Id("a"), jen.Id("b"))),
   241  			))
   242  	}
   243  
   244  	stmts = append(stmts, jen.Return(compareAllStructFields(strct, spi)))
   245  
   246  	e.addFunc(funcName, funcDeclaration.Block(stmts...))
   247  	return nil
   248  }
   249  
   250  func (e *equalsGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   251  	/*
   252  		func EqualsRefOfBool(a, b *bool, f ASTComparison) bool {
   253  			if a == b {
   254  				return true
   255  			}
   256  			if a == nil || b == nil {
   257  				return false
   258  			}
   259  			return *a == *b
   260  		}
   261  	*/
   262  	funcDeclaration, funcName := e.declareFunc(t, "a", "b")
   263  	stmts := []jen.Code{
   264  		jen.If(jen.Id("a == b")).Block(jen.Return(jen.True())),
   265  		jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())),
   266  		jen.Return(jen.Id("*a == *b")),
   267  	}
   268  	e.addFunc(funcName, funcDeclaration.Block(stmts...))
   269  	return nil
   270  }
   271  
   272  func (e *equalsGen) declareFunc(t types.Type, aArg, bArg string) (*jen.Statement, string) {
   273  	typeString := types.TypeString(t, noQualifier)
   274  	funcName := printableTypeName(t)
   275  
   276  	// func EqualsFunNameS(a, b <T>, f ASTComparison) bool
   277  	return jen.Func().Params(jen.Id("cmp").Op("*").Id(Comparator)).Id(funcName).Call(jen.Id(aArg), jen.Id(bArg).Id(typeString)).Bool(), funcName
   278  }
   279  
   280  func (e *equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
   281  	/*
   282  		func EqualsSliceOfRefOfLeaf(a, b []*Leaf) bool {
   283  			if len(a) != len(b) {
   284  				return false
   285  			}
   286  			for i := 0; i < len(a); i++ {
   287  				if !EqualsRefOfLeaf(a[i], b[i]) {
   288  					return false
   289  				}
   290  			}
   291  			return false
   292  		}
   293  	*/
   294  
   295  	stmts := []jen.Code{jen.If(jen.Id("len(a) != len(b)")).Block(jen.Return(jen.False())),
   296  		jen.For(jen.Id("i := 0; i < len(a); i++")).Block(
   297  			jen.If(compareValueType(slice.Elem(), jen.Id("a[i]"), jen.Id("b[i]"), false, spi)).Block(jen.Return(jen.False()))),
   298  		jen.Return(jen.True()),
   299  	}
   300  
   301  	funcDecl, funcName := e.declareFunc(t, "a", "b")
   302  	e.addFunc(funcName, funcDecl.Block(stmts...))
   303  	return nil
   304  }
   305  
   306  func (e *equalsGen) basicMethod(types.Type, *types.Basic, generatorSPI) error {
   307  	return nil
   308  }