vitess.io/vitess@v0.16.2/go/tools/asthelpergen/clone_gen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
    19  import (
    20  	"fmt"
    21  	"go/types"
    22  	"log"
    23  	"strings"
    24  
    25  	"github.com/dave/jennifer/jen"
    26  	"golang.org/x/exp/slices"
    27  )
    28  
    29  type CloneOptions struct {
    30  	Exclude []string
    31  }
    32  
// cloneGen creates the deep clone functions for the AST. It works by discovering the types that it needs to support,
// starting from a root interface type. While creating the clone function for this root interface, more types that need
// to be cloned are discovered. This continues type by type until all necessary types have been traversed.
type cloneGen struct {
	// exclude holds the type strings whose clone function just returns its
	// input (copied from CloneOptions.Exclude by newCloneGen).
	exclude []string
	// file accumulates every generated clone function.
	file    *jen.File
}

// Compile-time check that *cloneGen satisfies the generator interface.
var _ generator = (*cloneGen)(nil)
    42  
    43  func newCloneGen(pkgname string, options *CloneOptions) *cloneGen {
    44  	file := jen.NewFile(pkgname)
    45  	file.HeaderComment(licenseFileHeader)
    46  	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
    47  
    48  	return &cloneGen{
    49  		exclude: options.Exclude,
    50  		file:    file,
    51  	}
    52  }
    53  
    54  func (c *cloneGen) addFunc(name string, code *jen.Statement) {
    55  	c.file.Add(jen.Comment(fmt.Sprintf("%s creates a deep clone of the input.", name)))
    56  	c.file.Add(code)
    57  }
    58  
    59  func (c *cloneGen) genFile() (string, *jen.File) {
    60  	return "ast_clone.go", c.file
    61  }
    62  
    63  const cloneName = "Clone"
    64  
    65  // readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list
    66  func (c *cloneGen) readValueOfType(t types.Type, expr jen.Code, spi generatorSPI) jen.Code {
    67  	switch t.Underlying().(type) {
    68  	case *types.Basic:
    69  		return expr
    70  	case *types.Interface:
    71  		if types.TypeString(t, noQualifier) == "any" {
    72  			// these fields have to be taken care of manually
    73  			return expr
    74  		}
    75  	}
    76  	spi.addType(t)
    77  	return jen.Id(cloneName + printableTypeName(t)).Call(expr)
    78  }
    79  
    80  func (c *cloneGen) structMethod(t types.Type, _ *types.Struct, spi generatorSPI) error {
    81  	typeString := types.TypeString(t, noQualifier)
    82  	funcName := cloneName + printableTypeName(t)
    83  	c.addFunc(funcName,
    84  		jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block(
    85  			jen.Return(jen.Op("*").Add(c.readValueOfType(types.NewPointer(t), jen.Op("&").Id("n"), spi))),
    86  		))
    87  	return nil
    88  }
    89  
    90  func (c *cloneGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
    91  	typeString := types.TypeString(t, noQualifier)
    92  	name := printableTypeName(t)
    93  	funcName := cloneName + name
    94  
    95  	c.addFunc(funcName,
    96  		// func (n Bytes) Clone() Bytes {
    97  		jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block(
    98  			// if n == nil { return nil }
    99  			ifNilReturnNil("n"),
   100  			//	res := make(Bytes, len(n))
   101  			jen.Id("res").Op(":=").Id("make").Call(jen.Id(typeString), jen.Id("len").Call(jen.Id("n"))),
   102  			c.copySliceElement(t, slice.Elem(), spi),
   103  			//	return res
   104  			jen.Return(jen.Id("res")),
   105  		))
   106  	return nil
   107  }
   108  
   109  func (c *cloneGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error {
   110  	return nil
   111  }
   112  
   113  func (c *cloneGen) copySliceElement(t types.Type, elType types.Type, spi generatorSPI) jen.Code {
   114  	if !isNamed(t) && isBasic(elType) {
   115  		//	copy(res, n)
   116  		return jen.Id("copy").Call(jen.Id("res"), jen.Id("n"))
   117  	}
   118  
   119  	// for i := range n {
   120  	//  res[i] = CloneAST(x)
   121  	// }
   122  	spi.addType(elType)
   123  
   124  	return jen.For(jen.List(jen.Id("i"), jen.Id("x"))).Op(":=").Range().Id("n").Block(
   125  		jen.Id("res").Index(jen.Id("i")).Op("=").Add(c.readValueOfType(elType, jen.Id("x"), spi)),
   126  	)
   127  }
   128  
   129  func (c *cloneGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
   130  
   131  	// func CloneAST(in AST) AST {
   132  	//	if in == nil {
   133  	//	return nil
   134  	// }
   135  	//	switch in := in.(type) {
   136  	// case *RefContainer:
   137  	//	return in.CloneRefOfRefContainer()
   138  	// }
   139  	//	// this should never happen
   140  	//	return nil
   141  	// }
   142  
   143  	typeString := types.TypeString(t, noQualifier)
   144  	typeName := printableTypeName(t)
   145  
   146  	stmts := []jen.Code{ifNilReturnNil("in")}
   147  
   148  	var cases []jen.Code
   149  	_ = findImplementations(spi.scope(), iface, func(t types.Type) error {
   150  		typeString := types.TypeString(t, noQualifier)
   151  
   152  		// case Type: return CloneType(in)
   153  		block := jen.Case(jen.Id(typeString)).Block(jen.Return(c.readValueOfType(t, jen.Id("in"), spi)))
   154  		switch t := t.(type) {
   155  		case *types.Pointer:
   156  			_, isIface := t.Elem().(*types.Interface)
   157  			if !isIface {
   158  				cases = append(cases, block)
   159  			}
   160  
   161  		case *types.Named:
   162  			_, isIface := t.Underlying().(*types.Interface)
   163  			if !isIface {
   164  				cases = append(cases, block)
   165  			}
   166  
   167  		default:
   168  			log.Fatalf("unexpected type encountered: %s", typeString)
   169  		}
   170  
   171  		return nil
   172  	})
   173  
   174  	cases = append(cases,
   175  		jen.Default().Block(
   176  			jen.Comment("this should never happen"),
   177  			jen.Return(jen.Nil()),
   178  		))
   179  
   180  	//	switch n := node.(type) {
   181  	stmts = append(stmts, jen.Switch(jen.Id("in").Op(":=").Id("in").Assert(jen.Id("type")).Block(
   182  		cases...,
   183  	)))
   184  
   185  	funcName := cloneName + typeName
   186  	funcDecl := jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString)).Id(typeString).Block(stmts...)
   187  	c.addFunc(funcName, funcDecl)
   188  	return nil
   189  }
   190  
   191  func (c *cloneGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   192  	ptr := t.Underlying().(*types.Pointer)
   193  	return c.ptrToOtherMethod(t, ptr, spi)
   194  }
   195  
   196  func (c *cloneGen) ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error {
   197  	receiveType := types.TypeString(t, noQualifier)
   198  
   199  	funcName := cloneName + printableTypeName(t)
   200  	c.addFunc(funcName,
   201  		jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block(
   202  			ifNilReturnNil("n"),
   203  			jen.Id("out").Op(":=").Add(c.readValueOfType(ptr.Elem(), jen.Op("*").Id("n"), spi)),
   204  			jen.Return(jen.Op("&").Id("out")),
   205  		))
   206  	return nil
   207  }
   208  
   209  func ifNilReturnNil(id string) *jen.Statement {
   210  	return jen.If(jen.Id(id).Op("==").Nil()).Block(jen.Return(jen.Nil()))
   211  }
   212  
   213  func isNamed(t types.Type) bool {
   214  	_, x := t.(*types.Named)
   215  	return x
   216  }
   217  
   218  func isBasic(t types.Type) bool {
   219  	_, x := t.Underlying().(*types.Basic)
   220  	return x
   221  }
   222  
   223  func (c *cloneGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   224  	receiveType := types.TypeString(t, noQualifier)
   225  	funcName := cloneName + printableTypeName(t)
   226  
   227  	// func CloneRefOfType(n *Type) *Type
   228  	funcDeclaration := jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType)
   229  
   230  	if slices.Contains(c.exclude, receiveType) {
   231  		c.addFunc(funcName, funcDeclaration.Block(
   232  			jen.Return(jen.Id("n")),
   233  		))
   234  		return nil
   235  	}
   236  
   237  	var fields []jen.Code
   238  	for i := 0; i < strct.NumFields(); i++ {
   239  		field := strct.Field(i)
   240  		if isBasic(field.Type()) || strings.HasPrefix(field.Name(), "_") {
   241  			continue
   242  		}
   243  		// out.Field = CloneType(n.Field)
   244  		fields = append(fields,
   245  			jen.Id("out").Dot(field.Name()).Op("=").Add(c.readValueOfType(field.Type(), jen.Id("n").Dot(field.Name()), spi)))
   246  	}
   247  
   248  	stmts := []jen.Code{
   249  		// if n == nil { return nil }
   250  		ifNilReturnNil("n"),
   251  		// 	out := *n
   252  		jen.Id("out").Op(":=").Op("*").Id("n"),
   253  	}
   254  
   255  	// handle all fields with CloneAble types
   256  	stmts = append(stmts, fields...)
   257  
   258  	stmts = append(stmts,
   259  		// return &out
   260  		jen.Return(jen.Op("&").Id("out")),
   261  	)
   262  
   263  	c.addFunc(funcName, funcDeclaration.Block(stmts...))
   264  	return nil
   265  }