vitess.io/vitess@v0.16.2/go/tools/asthelpergen/visit_gen.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package asthelpergen
    18  
    19  import (
    20  	"go/types"
    21  
    22  	"github.com/dave/jennifer/jen"
    23  )
    24  
// visitName is the prefix of every generated visitor function; the functions
// below append printableTypeName(t) to it to build the full function name.
const visitName = "Visit"

// visitGen collects the generated visitor functions in a single jennifer file.
type visitGen struct {
	// file is the output file that visitFunc adds functions to and that
	// genFile hands back to the caller.
	file *jen.File
}

// Compile-time check that *visitGen satisfies the generator interface.
var _ generator = (*visitGen)(nil)
    32  
    33  func newVisitGen(pkgname string) *visitGen {
    34  	file := jen.NewFile(pkgname)
    35  	file.HeaderComment(licenseFileHeader)
    36  	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
    37  
    38  	return &visitGen{
    39  		file: file,
    40  	}
    41  }
    42  
    43  func (v *visitGen) genFile() (string, *jen.File) {
    44  	return "ast_visit.go", v.file
    45  }
    46  
    47  func shouldAdd(t types.Type, i *types.Interface) bool {
    48  	return types.Implements(t, i)
    49  }
    50  
    51  func (v *visitGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
    52  	if !shouldAdd(t, spi.iface()) {
    53  		return nil
    54  	}
    55  	/*
    56  		func VisitAST(in AST) (bool, error) {
    57  			if in == nil {
    58  				return false, nil
    59  			}
    60  			switch a := inA.(type) {
    61  			case *SubImpl:
    62  				return VisitSubImpl(a, b)
    63  			default:
    64  				return false, nil
    65  			}
    66  		}
    67  	*/
    68  	stmts := []jen.Code{
    69  		jen.If(jen.Id("in == nil").Block(returnNil())),
    70  	}
    71  
    72  	var cases []jen.Code
    73  	_ = spi.findImplementations(iface, func(t types.Type) error {
    74  		if _, ok := t.Underlying().(*types.Interface); ok {
    75  			return nil
    76  		}
    77  		typeString := types.TypeString(t, noQualifier)
    78  		funcName := visitName + printableTypeName(t)
    79  		spi.addType(t)
    80  		caseBlock := jen.Case(jen.Id(typeString)).Block(
    81  			jen.Return(jen.Id(funcName).Call(jen.Id("in"), jen.Id("f"))),
    82  		)
    83  		cases = append(cases, caseBlock)
    84  		return nil
    85  	})
    86  
    87  	cases = append(cases,
    88  		jen.Default().Block(
    89  			jen.Comment("this should never happen"),
    90  			returnNil(),
    91  		))
    92  
    93  	stmts = append(stmts, jen.Switch(jen.Id("in := in.(type)").Block(
    94  		cases...,
    95  	)))
    96  
    97  	v.visitFunc(t, stmts)
    98  	return nil
    99  }
   100  
   101  func returnNil() jen.Code {
   102  	return jen.Return(jen.Nil())
   103  }
   104  
   105  func (v *visitGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   106  	if !shouldAdd(t, spi.iface()) {
   107  		return nil
   108  	}
   109  
   110  	/*
   111  		func VisitRefOfRefContainer(in *RefContainer, f func(node AST) (kontinue bool, err error)) (bool, error) {
   112  			if cont, err := f(in); err != nil || !cont {
   113  				return false, err
   114  			}
   115  			if k, err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil || !k {
   116  				return false, err
   117  			}
   118  			if k, err := VisitAST(in.ASTType, f); err != nil || !k {
   119  				return false, err
   120  			}
   121  			return true, nil
   122  		}
   123  	*/
   124  
   125  	stmts := visitAllStructFields(strct, spi)
   126  	v.visitFunc(t, stmts)
   127  
   128  	return nil
   129  }
   130  
   131  func (v *visitGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
   132  	if !shouldAdd(t, spi.iface()) {
   133  		return nil
   134  	}
   135  
   136  	/*
   137  		func VisitRefOfRefContainer(in *RefContainer, f func(node AST) (kontinue bool, err error)) (bool, error) {
   138  			if in == nil {
   139  				return true, nil
   140  			}
   141  			if cont, err := f(in); err != nil || !cont {
   142  				return false, err
   143  			}
   144  			if k, err := VisitRefOfLeaf(in.ASTImplementationType, f); err != nil || !k {
   145  				return false, err
   146  			}
   147  			if k, err := VisitAST(in.ASTType, f); err != nil || !k {
   148  				return false, err
   149  			}
   150  			return true, nil
   151  		}
   152  	*/
   153  
   154  	stmts := []jen.Code{
   155  		jen.If(jen.Id("in == nil").Block(returnNil())),
   156  	}
   157  	stmts = append(stmts, visitAllStructFields(strct, spi)...)
   158  	v.visitFunc(t, stmts)
   159  
   160  	return nil
   161  }
   162  
   163  func (v *visitGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
   164  	if !shouldAdd(t, spi.iface()) {
   165  		return nil
   166  	}
   167  
   168  	stmts := []jen.Code{
   169  		jen.Comment("ptrToBasicMethod"),
   170  	}
   171  
   172  	v.visitFunc(t, stmts)
   173  
   174  	return nil
   175  }
   176  
   177  func (v *visitGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
   178  	if !shouldAdd(t, spi.iface()) {
   179  		return nil
   180  	}
   181  
   182  	if !shouldAdd(slice.Elem(), spi.iface()) {
   183  		return v.visitNoChildren(t, spi)
   184  	}
   185  
   186  	stmts := []jen.Code{
   187  		jen.If(jen.Id("in == nil").Block(returnNil())),
   188  		visitIn(),
   189  		jen.For(jen.Id("_, el := range in")).Block(
   190  			visitChild(slice.Elem(), jen.Id("el")),
   191  		),
   192  		returnNil(),
   193  	}
   194  
   195  	v.visitFunc(t, stmts)
   196  
   197  	return nil
   198  }
   199  
   200  func (v *visitGen) basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error {
   201  	if !shouldAdd(t, spi.iface()) {
   202  		return nil
   203  	}
   204  
   205  	return v.visitNoChildren(t, spi)
   206  }
   207  
   208  func (v *visitGen) visitNoChildren(t types.Type, spi generatorSPI) error {
   209  	stmts := []jen.Code{
   210  		jen.Id("_, err := f(in)"),
   211  		jen.Return(jen.Err()),
   212  	}
   213  
   214  	v.visitFunc(t, stmts)
   215  
   216  	return nil
   217  }
   218  
   219  func visitAllStructFields(strct *types.Struct, spi generatorSPI) []jen.Code {
   220  	output := []jen.Code{
   221  		visitIn(),
   222  	}
   223  	for i := 0; i < strct.NumFields(); i++ {
   224  		field := strct.Field(i)
   225  		if types.Implements(field.Type(), spi.iface()) {
   226  			spi.addType(field.Type())
   227  			visitField := visitChild(field.Type(), jen.Id("in").Dot(field.Name()))
   228  			output = append(output, visitField)
   229  			continue
   230  		}
   231  		slice, isSlice := field.Type().(*types.Slice)
   232  		if isSlice && types.Implements(slice.Elem(), spi.iface()) {
   233  			spi.addType(slice.Elem())
   234  			output = append(output, jen.For(jen.Id("_, el := range in."+field.Name())).Block(
   235  				visitChild(slice.Elem(), jen.Id("el")),
   236  			))
   237  		}
   238  	}
   239  	output = append(output, returnNil())
   240  	return output
   241  }
   242  
   243  func visitChild(t types.Type, id jen.Code) *jen.Statement {
   244  	funcName := visitName + printableTypeName(t)
   245  	visitField := jen.If(
   246  		jen.Id("err := ").Id(funcName).Call(id, jen.Id("f")),
   247  		jen.Id("err != nil "),
   248  	).Block(jen.Return(jen.Err()))
   249  	return visitField
   250  }
   251  
   252  func visitIn() *jen.Statement {
   253  	return jen.If(
   254  		jen.Id("cont, err := ").Id("f").Call(jen.Id("in")),
   255  		jen.Id("err != nil || !cont"),
   256  	).Block(jen.Return(jen.Err()))
   257  }
   258  
   259  func (v *visitGen) visitFunc(t types.Type, stmts []jen.Code) {
   260  	typeString := types.TypeString(t, noQualifier)
   261  	funcName := visitName + printableTypeName(t)
   262  	v.file.Add(jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString), jen.Id("f Visit")).Error().Block(stmts...))
   263  }