vitess.io/vitess@v0.16.2/go/tools/asthelpergen/copy_on_rewrite_gen.go (about)

     1  /*
     2  Copyright 2023 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
    19  import (
    20  	"go/types"
    21  
    22  	"github.com/dave/jennifer/jen"
    23  )
    24  
    25  type cowGen struct {
    26  	file     *jen.File
    27  	baseType string
    28  }
    29  
    30  var _ generator = (*cowGen)(nil)
    31  
    32  func newCOWGen(pkgname string, nt *types.Named) *cowGen {
    33  	file := jen.NewFile(pkgname)
    34  	file.HeaderComment(licenseFileHeader)
    35  	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
    36  
    37  	return &cowGen{
    38  		file:     file,
    39  		baseType: nt.Obj().Id(),
    40  	}
    41  }
    42  
    43  func (c *cowGen) addFunc(code *jen.Statement) {
    44  	c.file.Add(code)
    45  }
    46  
    47  func (c *cowGen) genFile() (string, *jen.File) {
    48  	return "ast_copy_on_rewrite.go", c.file
    49  }
    50  
    51  const cowName = "copyOnRewrite"
    52  
    53  // readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list
    54  func (c *cowGen) readValueOfType(t types.Type, expr jen.Code, spi generatorSPI) jen.Code {
    55  	switch t.Underlying().(type) {
    56  	case *types.Interface:
    57  		if types.TypeString(t, noQualifier) == "any" {
    58  			// these fields have to be taken care of manually
    59  			return expr
    60  		}
    61  	}
    62  	spi.addType(t)
    63  	return jen.Id("c").Dot(cowName + printableTypeName(t)).Call(expr)
    64  }
    65  
    66  func (c *cowGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
    67  	if !types.Implements(t, spi.iface()) {
    68  		return nil
    69  	}
    70  
    71  	typeString := types.TypeString(t, noQualifier)
    72  
    73  	changedVarName := "changed"
    74  	fieldVar := "res"
    75  	elemTyp := types.TypeString(slice.Elem(), noQualifier)
    76  
    77  	name := printableTypeName(t)
    78  	funcName := cowName + name
    79  	var visitElements *jen.Statement
    80  
    81  	if types.Implements(slice.Elem(), spi.iface()) {
    82  		visitElements = ifPreNotNilOrReturnsTrue().Block(
    83  			jen.Id(fieldVar).Op(":=").Id("make").Params(jen.Id(typeString), jen.Id("len").Params(jen.Id("n"))), // _Foo := make([]Typ, len(n))
    84  			jen.For(jen.List(jen.Id("x"), jen.Id("el")).Op(":=").Id("range n")).Block(
    85  				c.visitFieldOrElement("this", "change", slice.Elem(), jen.Id("el"), spi),
    86  				// jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(types.TypeString(elemTyp, noQualifier))),
    87  				jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(elemTyp)),
    88  				jen.If(jen.Id("change")).Block(
    89  					jen.Id(changedVarName).Op("=").True(),
    90  				),
    91  			),
    92  			jen.If(jen.Id("changed")).Block(
    93  				jen.Id("out").Op("=").Id("res"),
    94  			),
    95  		)
    96  	} else {
    97  		visitElements = jen.If(jen.Id("c.pre != nil")).Block(
    98  			jen.Id("c.pre(n, parent)"),
    99  		)
   100  	}
   101  
   102  	block := c.funcDecl(funcName, typeString).Block(
   103  		ifNilReturnNilAndFalse("n"),
   104  		jen.Id("out").Op("=").Id("n"),
   105  		visitElements,
   106  		ifPostNotNilVisit("out"),
   107  		jen.Return(),
   108  	)
   109  	c.addFunc(block)
   110  	return nil
   111  }
   112  
   113  func (c *cowGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error {
   114  	if !types.Implements(t, spi.iface()) {
   115  		return nil
   116  	}
   117  
   118  	typeString := types.TypeString(t, noQualifier)
   119  	typeName := printableTypeName(t)
   120  
   121  	var stmts []jen.Code
   122  	stmts = append(stmts,
   123  		jen.If(jen.Id("c").Dot("cursor").Dot("stop")).Block(jen.Return(jen.Id("n"), jen.False())),
   124  		ifNotNil("c.pre", jen.Id("c.pre").Params(jen.Id("n"), jen.Id("parent"))),
   125  		ifNotNil("c.post", jen.List(jen.Id("out"), jen.Id("changed")).Op("=").Id("c.postVisit").Params(jen.Id("n"), jen.Id("parent"), jen.Id("changed"))).
   126  			Else().Block(jen.Id("out = n")),
   127  		jen.Return(),
   128  	)
   129  	funcName := cowName + typeName
   130  	funcDecl := c.funcDecl(funcName, typeString).Block(stmts...)
   131  	c.addFunc(funcDecl)
   132  	return nil
   133  }
   134  
   135  func (c *cowGen) copySliceElement(t types.Type, elType types.Type, spi generatorSPI) jen.Code {
   136  	if !isNamed(t) && isBasic(elType) {
   137  		//	copy(res, n)
   138  		return jen.Id("copy").Call(jen.Id("res"), jen.Id("n"))
   139  	}
   140  
   141  	// for i := range n {
   142  	//  res[i] = CloneAST(x)
   143  	// }
   144  	spi.addType(elType)
   145  
   146  	return jen.For(jen.List(jen.Id("i"), jen.Id("x"))).Op(":=").Range().Id("n").Block(
   147  		jen.Id("res").Index(jen.Id("i")).Op("=").Add(c.readValueOfType(elType, jen.Id("x"), spi)),
   148  	)
   149  }
   150  
   151  func ifNotNil(id string, stmts ...jen.Code) *jen.Statement {
   152  	return jen.If(jen.Id(id).Op("!=").Nil()).Block(stmts...)
   153  }
   154  
   155  func ifNilReturnNilAndFalse(id string) *jen.Statement {
   156  	return jen.If(jen.Id(id).Op("==").Nil().Op("||").Id("c").Dot("cursor").Dot("stop")).Block(jen.Return(jen.Id("n"), jen.False()))
   157  }
   158  
   159  func ifPreNotNilOrReturnsTrue() *jen.Statement {
   160  	//	if c.pre == nil || c.pre(n, parent) {
   161  	return jen.If(
   162  		jen.Id("c").Dot("pre").Op("==").Nil().Op("||").Id("c").Dot("pre").Params(
   163  			jen.Id("n"),
   164  			jen.Id("parent"),
   165  		))
   166  
   167  }
   168  
   169  func (c *cowGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
   170  	if !types.Implements(t, spi.iface()) {
   171  		return nil
   172  	}
   173  
   174  	// func (c cow) cowAST(in AST) (AST, bool) {
   175  	//	if in == nil {
   176  	//		return nil, false
   177  	// 	}
   178  	//
   179  	//	if c.old == in {
   180  	//		return c.new, true
   181  	//	}
   182  	//	switch in := in.(type) {
   183  	// 	case *RefContainer:
   184  	//			return c.CowRefOfRefContainer(in)
   185  	// 	}
   186  	//	// this should never happen
   187  	//	return nil
   188  	// }
   189  
   190  	typeString := types.TypeString(t, noQualifier)
   191  	typeName := printableTypeName(t)
   192  
   193  	stmts := []jen.Code{ifNilReturnNilAndFalse("n")}
   194  
   195  	var cases []jen.Code
   196  	_ = findImplementations(spi.scope(), iface, func(t types.Type) error {
   197  		if _, ok := t.Underlying().(*types.Interface); ok {
   198  			return nil
   199  		}
   200  		spi.addType(t)
   201  		typeString := types.TypeString(t, noQualifier)
   202  
   203  		// case Type: return CloneType(in)
   204  		block := jen.Case(jen.Id(typeString)).Block(jen.Return(c.readValueOfType(t, jen.List(jen.Id("n"), jen.Id("parent")), spi)))
   205  		cases = append(cases, block)
   206  
   207  		return nil
   208  	})
   209  
   210  	cases = append(cases,
   211  		jen.Default().Block(
   212  			jen.Comment("this should never happen"),
   213  			jen.Return(jen.Nil(), jen.False()),
   214  		))
   215  
   216  	//	switch n := node.(type) {
   217  	stmts = append(stmts, jen.Switch(jen.Id("n").Op(":=").Id("n").Assert(jen.Id("type")).Block(
   218  		cases...,
   219  	)))
   220  
   221  	funcName := cowName + typeName
   222  	funcDecl := c.funcDecl(funcName, typeString).Block(stmts...)
   223  	c.addFunc(funcDecl)
   224  	return nil
   225  }
   226  
   227  func (c *cowGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   228  	if !types.Implements(t, spi.iface()) {
   229  		return nil
   230  	}
   231  
   232  	ptr := t.Underlying().(*types.Pointer)
   233  	return c.ptrToOtherMethod(t, ptr, spi)
   234  }
   235  
   236  func (c *cowGen) ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error {
   237  	if !types.Implements(t, spi.iface()) {
   238  		return nil
   239  	}
   240  
   241  	receiveType := types.TypeString(t, noQualifier)
   242  
   243  	funcName := cowName + printableTypeName(t)
   244  	c.addFunc(c.funcDecl(funcName, receiveType).Block(
   245  		jen.Comment("apan was here"),
   246  		jen.Return(jen.Id("n"), jen.False()),
   247  	))
   248  	return nil
   249  }
   250  
   251  // func (c cow) COWRefOfType(n *Type) (*Type, bool)
   252  func (c *cowGen) funcDecl(funcName, typeName string) *jen.Statement {
   253  	return jen.Func().Params(jen.Id("c").Id("*cow")).Id(funcName).Call(jen.List(jen.Id("n").Id(typeName), jen.Id("parent").Id(c.baseType))).Params(jen.Id("out").Id(c.baseType), jen.Id("changed").Id("bool"))
   254  }
   255  
   256  func (c *cowGen) visitFieldOrElement(varName, changedVarName string, typ types.Type, el *jen.Statement, spi generatorSPI) *jen.Statement {
   257  	// _Field, changedField := c.COWType(n.<Field>, n)
   258  	return jen.List(jen.Id(varName), jen.Id(changedVarName)).Op(":=").Add(c.readValueOfType(typ, jen.List(el, jen.Id("n")), spi))
   259  }
   260  
   261  func (c *cowGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   262  	if !types.Implements(t, spi.iface()) {
   263  		return nil
   264  	}
   265  
   266  	c.visitStruct(t, strct, spi, nil, false)
   267  	return nil
   268  }
   269  
   270  func (c *cowGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   271  	if !types.Implements(t, spi.iface()) {
   272  		return nil
   273  	}
   274  	start := ifNilReturnNilAndFalse("n")
   275  
   276  	c.visitStruct(t, strct, spi, start, true)
   277  	return nil
   278  }
   279  
   280  func (c *cowGen) visitStruct(t types.Type, strct *types.Struct, spi generatorSPI, start *jen.Statement, ref bool) {
   281  	receiveType := types.TypeString(t, noQualifier)
   282  	funcName := cowName + printableTypeName(t)
   283  
   284  	funcDeclaration := c.funcDecl(funcName, receiveType)
   285  
   286  	var fields []jen.Code
   287  	out := "out"
   288  	changed := "res"
   289  	var fieldSetters []jen.Code
   290  	kopy := jen.Id(changed).Op(":=")
   291  	if ref {
   292  		fieldSetters = append(fieldSetters, kopy.Op("*").Id("n")) // changed := *n
   293  	} else {
   294  		fieldSetters = append(fieldSetters, kopy.Id("n")) // changed := n
   295  	}
   296  	var changedVariables []string
   297  	for i := 0; i < strct.NumFields(); i++ {
   298  		field := strct.Field(i).Name()
   299  		typ := strct.Field(i).Type()
   300  		changedVarName := "changed" + field
   301  
   302  		fieldType := types.TypeString(typ, noQualifier)
   303  		fieldVar := "_" + field
   304  		if types.Implements(typ, spi.iface()) {
   305  			fields = append(fields, c.visitFieldOrElement(fieldVar, changedVarName, typ, jen.Id("n").Dot(field), spi))
   306  			changedVariables = append(changedVariables, changedVarName)
   307  			fieldSetters = append(fieldSetters, jen.List(jen.Id(changed).Dot(field), jen.Op("_")).Op("=").Id(fieldVar).Op(".").Params(jen.Id(fieldType)))
   308  		} else {
   309  			// _Foo := make([]*Type, len(n.Foo))
   310  			// var changedFoo bool
   311  			// for x, el := range n.Foo {
   312  			// 	c, changed := c.COWSliceOfRefOfType(el, n)
   313  			// 	if changed {
   314  			// 		changedFoo = true
   315  			// 	}
   316  			// 	_Foo[i] = c.(*Type)
   317  			// }
   318  
   319  			slice, isSlice := typ.(*types.Slice)
   320  			if isSlice && types.Implements(slice.Elem(), spi.iface()) {
   321  				elemTyp := slice.Elem()
   322  				spi.addType(elemTyp)
   323  				x := jen.Id("x")
   324  				el := jen.Id("el")
   325  				// 	changed := jen.Id("changed")
   326  				fields = append(fields,
   327  					jen.Var().Id(changedVarName).Bool(), // var changedFoo bool
   328  					jen.Id(fieldVar).Op(":=").Id("make").Params(jen.Id(fieldType), jen.Id("len").Params(jen.Id("n").Dot(field))), // _Foo := make([]Typ, len(n.Foo))
   329  					jen.For(jen.List(x, el).Op(":=").Id("range n").Dot(field)).Block(
   330  						c.visitFieldOrElement("this", "changed", elemTyp, jen.Id("el"), spi),
   331  						jen.Id(fieldVar).Index(jen.Id("x")).Op("=").Id("this").Op(".").Params(jen.Id(types.TypeString(elemTyp, noQualifier))),
   332  						jen.If(jen.Id("changed")).Block(
   333  							jen.Id(changedVarName).Op("=").True(),
   334  						),
   335  					),
   336  				)
   337  				changedVariables = append(changedVariables, changedVarName)
   338  				fieldSetters = append(fieldSetters, jen.Id(changed).Dot(field).Op("=").Id(fieldVar))
   339  			}
   340  		}
   341  	}
   342  
   343  	var cond *jen.Statement
   344  	for _, variable := range changedVariables {
   345  		if cond == nil {
   346  			cond = jen.Id(variable)
   347  		} else {
   348  			cond = cond.Op("||").Add(jen.Id(variable))
   349  		}
   350  
   351  	}
   352  
   353  	fieldSetters = append(fieldSetters,
   354  		jen.Id(out).Op("=").Op("&").Id(changed),
   355  		ifNotNil("c.cloned", jen.Id("c.cloned").Params(jen.Id("n, out"))),
   356  		jen.Id("changed").Op("=").True(),
   357  	)
   358  	ifChanged := jen.If(cond).Block(fieldSetters...)
   359  
   360  	var stmts []jen.Code
   361  	if start != nil {
   362  		stmts = append(stmts, start)
   363  	}
   364  
   365  	// handle all fields with CloneAble types
   366  	var visitChildren []jen.Code
   367  	visitChildren = append(visitChildren, fields...)
   368  	if len(fieldSetters) > 4 /*we add three statements always*/ {
   369  		visitChildren = append(visitChildren, ifChanged)
   370  	}
   371  
   372  	children := ifPreNotNilOrReturnsTrue().Block(visitChildren...)
   373  	stmts = append(stmts,
   374  		jen.Id(out).Op("=").Id("n"),
   375  		children,
   376  	)
   377  
   378  	stmts = append(
   379  		stmts,
   380  		ifPostNotNilVisit(out),
   381  		jen.Return(),
   382  	)
   383  
   384  	c.addFunc(funcDeclaration.Block(stmts...))
   385  }
   386  
   387  func ifPostNotNilVisit(out string) *jen.Statement {
   388  	return ifNotNil("c.post", jen.List(jen.Id(out), jen.Id("changed")).Op("=").Id("c").Dot("postVisit").Params(jen.Id(out), jen.Id("parent"), jen.Id("changed")))
   389  }