github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/analysis/stubmethods/stubmethods.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stubmethods
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/format"
    12  	"go/token"
    13  	"go/types"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/powerman/golang-tools/go/analysis"
    18  	"github.com/powerman/golang-tools/go/analysis/passes/inspect"
    19  	"github.com/powerman/golang-tools/go/ast/astutil"
    20  	"github.com/powerman/golang-tools/internal/analysisinternal"
    21  	"github.com/powerman/golang-tools/internal/typesinternal"
    22  )
    23  
    24  const Doc = `stub methods analyzer
    25  
    26  This analyzer generates method stubs for concrete types
    27  in order to implement a target interface`
    28  
    29  var Analyzer = &analysis.Analyzer{
    30  	Name:             "stubmethods",
    31  	Doc:              Doc,
    32  	Requires:         []*analysis.Analyzer{inspect.Analyzer},
    33  	Run:              run,
    34  	RunDespiteErrors: true,
    35  }
    36  
    37  func run(pass *analysis.Pass) (interface{}, error) {
    38  	for _, err := range analysisinternal.GetTypeErrors(pass) {
    39  		ifaceErr := strings.Contains(err.Msg, "missing method") || strings.HasPrefix(err.Msg, "cannot convert")
    40  		if !ifaceErr {
    41  			continue
    42  		}
    43  		var file *ast.File
    44  		for _, f := range pass.Files {
    45  			if f.Pos() <= err.Pos && err.Pos < f.End() {
    46  				file = f
    47  				break
    48  			}
    49  		}
    50  		if file == nil {
    51  			continue
    52  		}
    53  		// Get the end position of the error.
    54  		_, _, endPos, ok := typesinternal.ReadGo116ErrorData(err)
    55  		if !ok {
    56  			var buf bytes.Buffer
    57  			if err := format.Node(&buf, pass.Fset, file); err != nil {
    58  				continue
    59  			}
    60  			endPos = analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), err.Pos)
    61  		}
    62  		path, _ := astutil.PathEnclosingInterval(file, err.Pos, endPos)
    63  		si := GetStubInfo(pass.TypesInfo, path, err.Pos)
    64  		if si == nil {
    65  			continue
    66  		}
    67  		qf := RelativeToFiles(si.Concrete.Obj().Pkg(), file, nil, nil)
    68  		pass.Report(analysis.Diagnostic{
    69  			Pos:     err.Pos,
    70  			End:     endPos,
    71  			Message: fmt.Sprintf("Implement %s", types.TypeString(si.Interface.Type(), qf)),
    72  		})
    73  	}
    74  	return nil, nil
    75  }
    76  
// StubInfo represents a concrete type
// that wants to stub out an interface type.
// The From* helpers in this file build one when a type error shows a
// concrete named type being used where an interface is expected.
type StubInfo struct {
	// Interface is the interface that the client wants to implement.
	// When the interface is defined, the underlying object will be a TypeName.
	// Note that we keep track of types.Object instead of types.Type in order
	// to keep a reference to the declaring object's package and the ast file
	// in the case where the concrete type file requires a new import that happens to be renamed
	// in the interface file.
	// TODO(marwan-at-work): implement interface literals.
	Interface types.Object
	// Concrete is the named type of the value being used as Interface,
	// with at most one level of pointer indirection removed.
	Concrete  *types.Named
	// Pointer reports whether that value had type *Concrete rather than
	// Concrete.
	Pointer   bool
}
    91  
    92  // GetStubInfo determines whether the "missing method error"
    93  // can be used to deduced what the concrete and interface types are.
    94  func GetStubInfo(ti *types.Info, path []ast.Node, pos token.Pos) *StubInfo {
    95  	for _, n := range path {
    96  		switch n := n.(type) {
    97  		case *ast.ValueSpec:
    98  			return fromValueSpec(ti, n, pos)
    99  		case *ast.ReturnStmt:
   100  			// An error here may not indicate a real error the user should know about, but it may.
   101  			// Therefore, it would be best to log it out for debugging/reporting purposes instead of ignoring
   102  			// it. However, event.Log takes a context which is not passed via the analysis package.
   103  			// TODO(marwan-at-work): properly log this error.
   104  			si, _ := fromReturnStmt(ti, pos, path, n)
   105  			return si
   106  		case *ast.AssignStmt:
   107  			return fromAssignStmt(ti, n, pos)
   108  		}
   109  	}
   110  	return nil
   111  }
   112  
   113  // fromReturnStmt analyzes a "return" statement to extract
   114  // a concrete type that is trying to be returned as an interface type.
   115  //
   116  // For example, func() io.Writer { return myType{} }
   117  // would return StubInfo with the interface being io.Writer and the concrete type being myType{}.
   118  func fromReturnStmt(ti *types.Info, pos token.Pos, path []ast.Node, rs *ast.ReturnStmt) (*StubInfo, error) {
   119  	returnIdx := -1
   120  	for i, r := range rs.Results {
   121  		if pos >= r.Pos() && pos <= r.End() {
   122  			returnIdx = i
   123  		}
   124  	}
   125  	if returnIdx == -1 {
   126  		return nil, fmt.Errorf("pos %d not within return statement bounds: [%d-%d]", pos, rs.Pos(), rs.End())
   127  	}
   128  	concObj, pointer := concreteType(rs.Results[returnIdx], ti)
   129  	if concObj == nil || concObj.Obj().Pkg() == nil {
   130  		return nil, nil
   131  	}
   132  	ef := enclosingFunction(path, ti)
   133  	if ef == nil {
   134  		return nil, fmt.Errorf("could not find the enclosing function of the return statement")
   135  	}
   136  	iface := ifaceType(ef.Results.List[returnIdx].Type, ti)
   137  	if iface == nil {
   138  		return nil, nil
   139  	}
   140  	return &StubInfo{
   141  		Concrete:  concObj,
   142  		Pointer:   pointer,
   143  		Interface: iface,
   144  	}, nil
   145  }
   146  
   147  // fromValueSpec returns *StubInfo from a variable declaration such as
   148  // var x io.Writer = &T{}
   149  func fromValueSpec(ti *types.Info, vs *ast.ValueSpec, pos token.Pos) *StubInfo {
   150  	var idx int
   151  	for i, vs := range vs.Values {
   152  		if pos >= vs.Pos() && pos <= vs.End() {
   153  			idx = i
   154  			break
   155  		}
   156  	}
   157  
   158  	valueNode := vs.Values[idx]
   159  	ifaceNode := vs.Type
   160  	callExp, ok := valueNode.(*ast.CallExpr)
   161  	// if the ValueSpec is `var _ = myInterface(...)`
   162  	// as opposed to `var _ myInterface = ...`
   163  	if ifaceNode == nil && ok && len(callExp.Args) == 1 {
   164  		ifaceNode = callExp.Fun
   165  		valueNode = callExp.Args[0]
   166  	}
   167  	concObj, pointer := concreteType(valueNode, ti)
   168  	if concObj == nil || concObj.Obj().Pkg() == nil {
   169  		return nil
   170  	}
   171  	ifaceObj := ifaceType(ifaceNode, ti)
   172  	if ifaceObj == nil {
   173  		return nil
   174  	}
   175  	return &StubInfo{
   176  		Concrete:  concObj,
   177  		Interface: ifaceObj,
   178  		Pointer:   pointer,
   179  	}
   180  }
   181  
   182  // fromAssignStmt returns *StubInfo from a variable re-assignment such as
   183  // var x io.Writer
   184  // x = &T{}
   185  func fromAssignStmt(ti *types.Info, as *ast.AssignStmt, pos token.Pos) *StubInfo {
   186  	idx := -1
   187  	var lhs, rhs ast.Expr
   188  	// Given a re-assignment interface conversion error,
   189  	// the compiler error shows up on the right hand side of the expression.
   190  	// For example, x = &T{} where x is io.Writer highlights the error
   191  	// under "&T{}" and not "x".
   192  	for i, hs := range as.Rhs {
   193  		if pos >= hs.Pos() && pos <= hs.End() {
   194  			idx = i
   195  			break
   196  		}
   197  	}
   198  	if idx == -1 {
   199  		return nil
   200  	}
   201  	// Technically, this should never happen as
   202  	// we would get a "cannot assign N values to M variables"
   203  	// before we get an interface conversion error. Nonetheless,
   204  	// guard against out of range index errors.
   205  	if idx >= len(as.Lhs) {
   206  		return nil
   207  	}
   208  	lhs, rhs = as.Lhs[idx], as.Rhs[idx]
   209  	ifaceObj := ifaceType(lhs, ti)
   210  	if ifaceObj == nil {
   211  		return nil
   212  	}
   213  	concType, pointer := concreteType(rhs, ti)
   214  	if concType == nil || concType.Obj().Pkg() == nil {
   215  		return nil
   216  	}
   217  	return &StubInfo{
   218  		Concrete:  concType,
   219  		Interface: ifaceObj,
   220  		Pointer:   pointer,
   221  	}
   222  }
   223  
   224  // RelativeToFiles returns a types.Qualifier that formats package names
   225  // according to the files where the concrete and interface types are defined.
   226  //
   227  // This is similar to types.RelativeTo except if a file imports the package with a different name,
   228  // then it will use it. And if the file does import the package but it is ignored,
   229  // then it will return the original name. It also prefers package names in ifaceFile in case
   230  // an import is missing from concFile but is present in ifaceFile.
   231  //
   232  // Additionally, if missingImport is not nil, the function will be called whenever the concFile
   233  // is presented with a package that is not imported. This is useful so that as types.TypeString is
   234  // formatting a function signature, it is identifying packages that will need to be imported when
   235  // stubbing an interface.
   236  func RelativeToFiles(concPkg *types.Package, concFile, ifaceFile *ast.File, missingImport func(name, path string)) types.Qualifier {
   237  	return func(other *types.Package) string {
   238  		if other == concPkg {
   239  			return ""
   240  		}
   241  
   242  		// Check if the concrete file already has the given import,
   243  		// if so return the default package name or the renamed import statement.
   244  		for _, imp := range concFile.Imports {
   245  			impPath, _ := strconv.Unquote(imp.Path.Value)
   246  			isIgnored := imp.Name != nil && (imp.Name.Name == "." || imp.Name.Name == "_")
   247  			if impPath == other.Path() && !isIgnored {
   248  				importName := other.Name()
   249  				if imp.Name != nil {
   250  					importName = imp.Name.Name
   251  				}
   252  				return importName
   253  			}
   254  		}
   255  
   256  		// If the concrete file does not have the import, check if the package
   257  		// is renamed in the interface file and prefer that.
   258  		var importName string
   259  		if ifaceFile != nil {
   260  			for _, imp := range ifaceFile.Imports {
   261  				impPath, _ := strconv.Unquote(imp.Path.Value)
   262  				isIgnored := imp.Name != nil && (imp.Name.Name == "." || imp.Name.Name == "_")
   263  				if impPath == other.Path() && !isIgnored {
   264  					if imp.Name != nil && imp.Name.Name != concPkg.Name() {
   265  						importName = imp.Name.Name
   266  					}
   267  					break
   268  				}
   269  			}
   270  		}
   271  
   272  		if missingImport != nil {
   273  			missingImport(importName, other.Path())
   274  		}
   275  
   276  		// Up until this point, importName must stay empty when calling missingImport,
   277  		// otherwise we'd end up with `import time "time"` which doesn't look idiomatic.
   278  		if importName == "" {
   279  			importName = other.Name()
   280  		}
   281  		return importName
   282  	}
   283  }
   284  
   285  // ifaceType will try to extract the types.Object that defines
   286  // the interface given the ast.Expr where the "missing method"
   287  // or "conversion" errors happen.
   288  func ifaceType(n ast.Expr, ti *types.Info) types.Object {
   289  	tv, ok := ti.Types[n]
   290  	if !ok {
   291  		return nil
   292  	}
   293  	typ := tv.Type
   294  	named, ok := typ.(*types.Named)
   295  	if !ok {
   296  		return nil
   297  	}
   298  	_, ok = named.Underlying().(*types.Interface)
   299  	if !ok {
   300  		return nil
   301  	}
   302  	// Interfaces defined in the "builtin" package return nil a Pkg().
   303  	// But they are still real interfaces that we need to make a special case for.
   304  	// Therefore, protect gopls from panicking if a new interface type was added in the future.
   305  	if named.Obj().Pkg() == nil && named.Obj().Name() != "error" {
   306  		return nil
   307  	}
   308  	return named.Obj()
   309  }
   310  
   311  // concreteType tries to extract the *types.Named that defines
   312  // the concrete type given the ast.Expr where the "missing method"
   313  // or "conversion" errors happened. If the concrete type is something
   314  // that cannot have methods defined on it (such as basic types), this
   315  // method will return a nil *types.Named. The second return parameter
   316  // is a boolean that indicates whether the concreteType was defined as a
   317  // pointer or value.
   318  func concreteType(n ast.Expr, ti *types.Info) (*types.Named, bool) {
   319  	tv, ok := ti.Types[n]
   320  	if !ok {
   321  		return nil, false
   322  	}
   323  	typ := tv.Type
   324  	ptr, isPtr := typ.(*types.Pointer)
   325  	if isPtr {
   326  		typ = ptr.Elem()
   327  	}
   328  	named, ok := typ.(*types.Named)
   329  	if !ok {
   330  		return nil, false
   331  	}
   332  	return named, isPtr
   333  }
   334  
   335  // enclosingFunction returns the signature and type of the function
   336  // enclosing the given position.
   337  func enclosingFunction(path []ast.Node, info *types.Info) *ast.FuncType {
   338  	for _, node := range path {
   339  		switch t := node.(type) {
   340  		case *ast.FuncDecl:
   341  			if _, ok := info.Defs[t.Name]; ok {
   342  				return t.Type
   343  			}
   344  		case *ast.FuncLit:
   345  			if _, ok := info.Types[t]; ok {
   346  				return t.Type
   347  			}
   348  		}
   349  	}
   350  	return nil
   351  }