github.com/petergtz/pegomock@v2.9.1-0.20230424204322-eb0e044013df+incompatible/modelgen/loader/loader.go (about)

     1  package loader
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/types"
     8  
     9  	"github.com/petergtz/pegomock/model"
    10  	"golang.org/x/tools/go/loader"
    11  )
    12  
    13  func GenerateModel(importPath string, interfaceName string) (*model.Package, error) {
    14  	var conf loader.Config
    15  	conf.Import(importPath)
    16  	program, e := conf.Load()
    17  	if e != nil {
    18  		panic(e)
    19  	}
    20  	info := program.Imported[importPath]
    21  
    22  	for def := range info.Defs {
    23  		if def.Name == interfaceName && def.Obj.Kind == ast.Typ {
    24  			interfacetype, ok := def.Obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType)
    25  			if ok {
    26  				g := &modelGenerator{info: info}
    27  				iface := &model.Interface{
    28  					Name:    interfaceName,
    29  					Methods: g.modelMethodsFrom(interfacetype.Methods),
    30  				}
    31  				return &model.Package{
    32  					Name:       info.Pkg.Name(),
    33  					Interfaces: []*model.Interface{iface},
    34  				}, nil
    35  			}
    36  		}
    37  	}
    38  
    39  	return nil, errors.New("Did not find interface name \"" + interfaceName + "\"")
    40  }
    41  
    42  type modelGenerator struct {
    43  	info *loader.PackageInfo
    44  }
    45  
    46  func (g *modelGenerator) modelMethodsFrom(fields *ast.FieldList) (modelMethods []*model.Method) {
    47  	for _, field := range fields.List {
    48  		switch field.Type.(type) {
    49  		case *ast.FuncType:
    50  			modelMethods = append(modelMethods, g.modelMethodFrom(field))
    51  		case *ast.Ident:
    52  			modelMethods = append(modelMethods, g.modelMethodsFrom(field.Type.(*ast.Ident).Obj.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods)...)
    53  		default:
    54  			panic(fmt.Sprintf("Unexpected expression in interface definition. Only methods and embedded interfaces are allowed, but got: %#v", field.Type))
    55  		}
    56  	}
    57  	return
    58  }
    59  
    60  func (g *modelGenerator) modelMethodFrom(astMethod *ast.Field) *model.Method {
    61  	in, out, variadic := g.signatureFrom(astMethod.Type.(*ast.FuncType))
    62  	return &model.Method{Name: astMethod.Names[0].Name, In: in, Variadic: variadic, Out: out}
    63  }
    64  
    65  func (g *modelGenerator) signatureFrom(astFuncType *ast.FuncType) (in, out []*model.Parameter, variadic *model.Parameter) {
    66  	in, variadic = g.inParamsFrom(astFuncType.Params)
    67  	out = g.outParamsFrom(astFuncType.Results)
    68  	return
    69  }
    70  
    71  func (g *modelGenerator) inParamsFrom(params *ast.FieldList) (in []*model.Parameter, variadic *model.Parameter) {
    72  	for _, param := range params.List {
    73  		for _, name := range param.Names {
    74  			if ellipsisType, isEllipsisType := param.Type.(*ast.Ellipsis); isEllipsisType {
    75  				variadic = g.newParam(name.Name, ellipsisType.Elt)
    76  			} else {
    77  				in = append(in, g.newParam(name.Name, param.Type))
    78  			}
    79  		}
    80  		if len(param.Names) == 0 {
    81  			if ellipsisType, isEllipsisType := param.Type.(*ast.Ellipsis); isEllipsisType {
    82  				variadic = g.newParam("", ellipsisType.Elt)
    83  			} else {
    84  				in = append(in, g.newParam("", param.Type))
    85  			}
    86  		}
    87  	}
    88  	return
    89  }
    90  
    91  func (g *modelGenerator) outParamsFrom(results *ast.FieldList) (out []*model.Parameter) {
    92  	if results != nil {
    93  		for _, param := range results.List {
    94  			for _, name := range param.Names {
    95  				out = append(out, g.newParam(name.Name, param.Type))
    96  			}
    97  			if len(param.Names) == 0 {
    98  				out = append(out, g.newParam("", param.Type))
    99  			}
   100  		}
   101  	}
   102  	return
   103  }
   104  
   105  func (g *modelGenerator) newParam(name string, typ ast.Expr) *model.Parameter {
   106  	return &model.Parameter{
   107  		Name: name,
   108  		Type: g.modelTypeFrom(g.info.TypeOf(typ)),
   109  	}
   110  }
   111  
   112  func (g *modelGenerator) modelTypeFrom(typesType types.Type) model.Type {
   113  	switch typedTyp := typesType.(type) {
   114  	case *types.Basic:
   115  		if !predeclared(typedTyp.Kind()) {
   116  			panic(fmt.Sprintf("Unexpected Basic Type %v", typedTyp.Name()))
   117  		}
   118  		return model.PredeclaredType(typedTyp.Name())
   119  	case *types.Pointer:
   120  		return &model.PointerType{
   121  			Type: g.modelTypeFrom(typedTyp.Elem()),
   122  		}
   123  	case *types.Array:
   124  		return &model.ArrayType{
   125  			Len:  int(typedTyp.Len()),
   126  			Type: g.modelTypeFrom(typedTyp.Elem()),
   127  		}
   128  	case *types.Slice:
   129  		return &model.ArrayType{
   130  			Len:  -1,
   131  			Type: g.modelTypeFrom(typedTyp.Elem()),
   132  		}
   133  	case *types.Map:
   134  		return &model.MapType{
   135  			Key:   g.modelTypeFrom(typedTyp.Key()),
   136  			Value: g.modelTypeFrom(typedTyp.Elem()),
   137  		}
   138  	case *types.Chan:
   139  		var dir model.ChanDir
   140  		switch typedTyp.Dir() {
   141  		case types.SendOnly:
   142  			dir = model.SendDir
   143  		case types.RecvOnly:
   144  			dir = model.RecvDir
   145  		default:
   146  			dir = 0
   147  		}
   148  		return &model.ChanType{
   149  			Dir:  dir,
   150  			Type: g.modelTypeFrom(typedTyp.Elem()),
   151  		}
   152  	case *types.Named:
   153  		if typedTyp.Obj().Pkg() == nil {
   154  			return model.PredeclaredType(typedTyp.Obj().Name())
   155  		}
   156  		return &model.NamedType{
   157  			Package: typedTyp.Obj().Pkg().Path(),
   158  			Type:    typedTyp.Obj().Name(),
   159  		}
   160  	case *types.Interface, *types.Struct:
   161  		return model.PredeclaredType(typedTyp.String())
   162  	case *types.Signature:
   163  		in, variadic := g.generateInParamsFrom(typedTyp.Params())
   164  		out := g.generateOutParamsFrom(typedTyp.Results())
   165  		return &model.FuncType{In: in, Out: out, Variadic: variadic}
   166  	default:
   167  		panic(fmt.Sprintf("Unknown types.Type: %v (%T)", typesType, typesType))
   168  	}
   169  }
   170  
   171  func (g *modelGenerator) generateInParamsFrom(params *types.Tuple) (in []*model.Parameter, variadic *model.Parameter) {
   172  	// TODO: variadic
   173  
   174  	for i := 0; i < params.Len(); i++ {
   175  		in = append(in, &model.Parameter{
   176  			Name: params.At(i).Name(),
   177  			Type: g.modelTypeFrom(params.At(i).Type()),
   178  		})
   179  	}
   180  	return
   181  }
   182  
   183  func (g *modelGenerator) generateOutParamsFrom(params *types.Tuple) (out []*model.Parameter) {
   184  	for i := 0; i < params.Len(); i++ {
   185  		out = append(out, &model.Parameter{
   186  			Name: params.At(i).Name(),
   187  			Type: g.modelTypeFrom(params.At(i).Type()),
   188  		})
   189  	}
   190  	return
   191  }
   192  
   193  func predeclared(basicKind types.BasicKind) bool {
   194  	return basicKind >= types.Bool && basicKind <= types.String
   195  }