github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/source/stub.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/token"
    15  	"go/types"
    16  	"strings"
    17  
    18  	"github.com/jhump/golang-x-tools/go/analysis"
    19  	"github.com/jhump/golang-x-tools/go/ast/astutil"
    20  	"github.com/jhump/golang-x-tools/internal/lsp/analysis/stubmethods"
    21  	"github.com/jhump/golang-x-tools/internal/lsp/protocol"
    22  	"github.com/jhump/golang-x-tools/internal/span"
    23  )
    24  
    25  func stubSuggestedFixFunc(ctx context.Context, snapshot Snapshot, fh VersionedFileHandle, rng protocol.Range) (*analysis.SuggestedFix, error) {
    26  	pkg, pgf, err := GetParsedFile(ctx, snapshot, fh, NarrowestPackage)
    27  	if err != nil {
    28  		return nil, fmt.Errorf("GetParsedFile: %w", err)
    29  	}
    30  	nodes, pos, err := getStubNodes(pgf, rng)
    31  	if err != nil {
    32  		return nil, fmt.Errorf("getNodes: %w", err)
    33  	}
    34  	si := stubmethods.GetStubInfo(pkg.GetTypesInfo(), nodes, pkg.GetTypes(), pos)
    35  	if si == nil {
    36  		return nil, fmt.Errorf("nil interface request")
    37  	}
    38  	parsedConcreteFile, concreteFH, err := getStubFile(ctx, si.Concrete.Obj(), snapshot)
    39  	if err != nil {
    40  		return nil, fmt.Errorf("getFile(concrete): %w", err)
    41  	}
    42  	var (
    43  		methodsSrc  []byte
    44  		stubImports []*stubImport // additional imports needed for method stubs
    45  	)
    46  	if si.Interface.Pkg() == nil && si.Interface.Name() == "error" && si.Interface.Parent() == types.Universe {
    47  		methodsSrc = stubErr(ctx, parsedConcreteFile.File, si, snapshot)
    48  	} else {
    49  		methodsSrc, stubImports, err = stubMethods(ctx, parsedConcreteFile.File, si, snapshot)
    50  	}
    51  	if err != nil {
    52  		return nil, fmt.Errorf("stubMethods: %w", err)
    53  	}
    54  	nodes, _ = astutil.PathEnclosingInterval(parsedConcreteFile.File, si.Concrete.Obj().Pos(), si.Concrete.Obj().Pos())
    55  	concreteSrc, err := concreteFH.Read()
    56  	if err != nil {
    57  		return nil, fmt.Errorf("error reading concrete file source: %w", err)
    58  	}
    59  	insertPos := snapshot.FileSet().Position(nodes[1].End()).Offset
    60  	if insertPos >= len(concreteSrc) {
    61  		return nil, fmt.Errorf("insertion position is past the end of the file")
    62  	}
    63  	var buf bytes.Buffer
    64  	buf.Write(concreteSrc[:insertPos])
    65  	buf.WriteByte('\n')
    66  	buf.Write(methodsSrc)
    67  	buf.Write(concreteSrc[insertPos:])
    68  	fset := token.NewFileSet()
    69  	newF, err := parser.ParseFile(fset, parsedConcreteFile.File.Name.Name, buf.Bytes(), parser.ParseComments)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("could not reparse file: %w", err)
    72  	}
    73  	for _, imp := range stubImports {
    74  		astutil.AddNamedImport(fset, newF, imp.Name, imp.Path)
    75  	}
    76  	var source bytes.Buffer
    77  	err = format.Node(&source, fset, newF)
    78  	if err != nil {
    79  		return nil, fmt.Errorf("format.Node: %w", err)
    80  	}
    81  	diffEdits, err := snapshot.View().Options().ComputeEdits(parsedConcreteFile.URI, string(parsedConcreteFile.Src), source.String())
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	var edits []analysis.TextEdit
    86  	for _, edit := range diffEdits {
    87  		rng, err := edit.Span.Range(parsedConcreteFile.Mapper.Converter)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  		edits = append(edits, analysis.TextEdit{
    92  			Pos:     rng.Start,
    93  			End:     rng.End,
    94  			NewText: []byte(edit.NewText),
    95  		})
    96  	}
    97  	return &analysis.SuggestedFix{
    98  		TextEdits: edits,
    99  	}, nil
   100  }
   101  
   102  // stubMethods returns the Go code of all methods
   103  // that implement the given interface
   104  func stubMethods(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) ([]byte, []*stubImport, error) {
   105  	ifacePkg, err := deducePkgFromTypes(ctx, snapshot, si.Interface)
   106  	if err != nil {
   107  		return nil, nil, err
   108  	}
   109  	si.Concrete.Obj().Type()
   110  	concMS := types.NewMethodSet(types.NewPointer(si.Concrete.Obj().Type()))
   111  	missing, err := missingMethods(ctx, snapshot, concMS, si.Concrete.Obj().Pkg(), si.Interface, ifacePkg, map[string]struct{}{})
   112  	if err != nil {
   113  		return nil, nil, fmt.Errorf("missingMethods: %w", err)
   114  	}
   115  	if len(missing) == 0 {
   116  		return nil, nil, fmt.Errorf("no missing methods found")
   117  	}
   118  	var (
   119  		stubImports   []*stubImport
   120  		methodsBuffer bytes.Buffer
   121  	)
   122  	for _, mi := range missing {
   123  		for _, m := range mi.missing {
   124  			// TODO(marwan-at-work): this should share the same logic with source.FormatVarType
   125  			// as it also accounts for type aliases.
   126  			sig := types.TypeString(m.Type(), stubmethods.RelativeToFiles(si.Concrete.Obj().Pkg(), concreteFile, mi.file, func(name, path string) {
   127  				for _, imp := range stubImports {
   128  					if imp.Name == name && imp.Path == path {
   129  						return
   130  					}
   131  				}
   132  				stubImports = append(stubImports, &stubImport{name, path})
   133  			}))
   134  			_, err = methodsBuffer.Write(printStubMethod(methodData{
   135  				Method:    m.Name(),
   136  				Concrete:  getStubReceiver(si),
   137  				Interface: deduceIfaceName(si.Concrete.Obj().Pkg(), si.Interface.Pkg(), si.Concrete.Obj(), si.Interface),
   138  				Signature: strings.TrimPrefix(sig, "func"),
   139  			}))
   140  			if err != nil {
   141  				return nil, nil, fmt.Errorf("error printing method: %w", err)
   142  			}
   143  			methodsBuffer.WriteRune('\n')
   144  		}
   145  	}
   146  	return methodsBuffer.Bytes(), stubImports, nil
   147  }
   148  
   149  // stubErr reurns the Go code implementation
   150  // of an error interface relevant to the
   151  // concrete type
   152  func stubErr(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) []byte {
   153  	return printStubMethod(methodData{
   154  		Method:    "Error",
   155  		Interface: "error",
   156  		Concrete:  getStubReceiver(si),
   157  		Signature: "() string",
   158  	})
   159  }
   160  
   161  // getStubReceiver returns the concrete type's name as a method receiver.
   162  // TODO(marwan-at-work): add type parameters to the receiver when the concrete type
   163  // is a generic one.
   164  func getStubReceiver(si *stubmethods.StubInfo) string {
   165  	concrete := si.Concrete.Obj().Name()
   166  	if si.Pointer {
   167  		concrete = "*" + concrete
   168  	}
   169  	return concrete
   170  }
   171  
   172  type methodData struct {
   173  	Method    string
   174  	Interface string
   175  	Concrete  string
   176  	Signature string
   177  }
   178  
   179  // printStubMethod takes methodData and returns Go code that represents the given method such as:
   180  // 	// {{ .Method }} implements {{ .Interface }}
   181  // 	func ({{ .Concrete }}) {{ .Method }}{{ .Signature }} {
   182  // 		panic("unimplemented")
   183  // 	}
   184  func printStubMethod(md methodData) []byte {
   185  	var b bytes.Buffer
   186  	fmt.Fprintf(&b, "// %s implements %s\n", md.Method, md.Interface)
   187  	fmt.Fprintf(&b, "func (%s) %s%s {\n\t", md.Concrete, md.Method, md.Signature)
   188  	fmt.Fprintln(&b, `panic("unimplemented")`)
   189  	fmt.Fprintln(&b, "}")
   190  	return b.Bytes()
   191  }
   192  
   193  func deducePkgFromTypes(ctx context.Context, snapshot Snapshot, ifaceObj types.Object) (Package, error) {
   194  	pkgs, err := snapshot.KnownPackages(ctx)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	for _, p := range pkgs {
   199  		if p.PkgPath() == ifaceObj.Pkg().Path() {
   200  			return p, nil
   201  		}
   202  	}
   203  	return nil, fmt.Errorf("pkg %q not found", ifaceObj.Pkg().Path())
   204  }
   205  
   206  func deduceIfaceName(concretePkg, ifacePkg *types.Package, concreteObj, ifaceObj types.Object) string {
   207  	if concretePkg.Path() == ifacePkg.Path() {
   208  		return ifaceObj.Name()
   209  	}
   210  	return fmt.Sprintf("%s.%s", ifacePkg.Name(), ifaceObj.Name())
   211  }
   212  
   213  func getStubNodes(pgf *ParsedGoFile, pRng protocol.Range) ([]ast.Node, token.Pos, error) {
   214  	spn, err := pgf.Mapper.RangeSpan(pRng)
   215  	if err != nil {
   216  		return nil, 0, err
   217  	}
   218  	rng, err := spn.Range(pgf.Mapper.Converter)
   219  	if err != nil {
   220  		return nil, 0, err
   221  	}
   222  	nodes, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start, rng.End)
   223  	return nodes, rng.Start, nil
   224  }
   225  
   226  /*
   227  missingMethods takes a concrete type and returns any missing methods for the given interface as well as
   228  any missing interface that might have been embedded to its parent. For example:
   229  
   230  type I interface {
   231  	io.Writer
   232  	Hello()
   233  }
   234  returns []*missingInterface{
   235  	{
   236  		iface: *types.Interface (io.Writer),
   237  		file: *ast.File: io.go,
   238  		missing []*types.Func{Write},
   239  	},
   240  	{
   241  		iface: *types.Interface (I),
   242  		file: *ast.File: myfile.go,
   243  		missing: []*types.Func{Hello}
   244  	},
   245  }
   246  */
   247  func missingMethods(ctx context.Context, snapshot Snapshot, concMS *types.MethodSet, concPkg *types.Package, ifaceObj types.Object, ifacePkg Package, visited map[string]struct{}) ([]*missingInterface, error) {
   248  	iface, ok := ifaceObj.Type().Underlying().(*types.Interface)
   249  	if !ok {
   250  		return nil, fmt.Errorf("expected %v to be an interface but got %T", iface, ifaceObj.Type().Underlying())
   251  	}
   252  	missing := []*missingInterface{}
   253  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   254  		eiface := iface.Embedded(i).Obj()
   255  		depPkg := ifacePkg
   256  		if eiface.Pkg().Path() != ifacePkg.PkgPath() {
   257  			var err error
   258  			depPkg, err = ifacePkg.GetImport(eiface.Pkg().Path())
   259  			if err != nil {
   260  				return nil, err
   261  			}
   262  		}
   263  		em, err := missingMethods(ctx, snapshot, concMS, concPkg, eiface, depPkg, visited)
   264  		if err != nil {
   265  			return nil, err
   266  		}
   267  		missing = append(missing, em...)
   268  	}
   269  	parsedFile, _, err := getStubFile(ctx, ifaceObj, snapshot)
   270  	if err != nil {
   271  		return nil, fmt.Errorf("error getting iface file: %w", err)
   272  	}
   273  	mi := &missingInterface{
   274  		pkg:   ifacePkg,
   275  		iface: iface,
   276  		file:  parsedFile.File,
   277  	}
   278  	if mi.file == nil {
   279  		return nil, fmt.Errorf("could not find ast.File for %v", ifaceObj.Name())
   280  	}
   281  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   282  		method := iface.ExplicitMethod(i)
   283  		// if the concrete type does not have the interface method
   284  		if concMS.Lookup(concPkg, method.Name()) == nil {
   285  			if _, ok := visited[method.Name()]; !ok {
   286  				mi.missing = append(mi.missing, method)
   287  				visited[method.Name()] = struct{}{}
   288  			}
   289  		}
   290  		if sel := concMS.Lookup(concPkg, method.Name()); sel != nil {
   291  			implSig := sel.Type().(*types.Signature)
   292  			ifaceSig := method.Type().(*types.Signature)
   293  			if !types.Identical(ifaceSig, implSig) {
   294  				return nil, fmt.Errorf("mimsatched %q function signatures:\nhave: %s\nwant: %s", method.Name(), implSig, ifaceSig)
   295  			}
   296  		}
   297  	}
   298  	if len(mi.missing) > 0 {
   299  		missing = append(missing, mi)
   300  	}
   301  	return missing, nil
   302  }
   303  
   304  func getStubFile(ctx context.Context, obj types.Object, snapshot Snapshot) (*ParsedGoFile, VersionedFileHandle, error) {
   305  	objPos := snapshot.FileSet().Position(obj.Pos())
   306  	objFile := span.URIFromPath(objPos.Filename)
   307  	objectFH := snapshot.FindFile(objFile)
   308  	_, goFile, err := GetParsedFile(ctx, snapshot, objectFH, WidestPackage)
   309  	if err != nil {
   310  		return nil, nil, fmt.Errorf("GetParsedFile: %w", err)
   311  	}
   312  	return goFile, objectFH, nil
   313  }
   314  
// missingInterface represents an interface
// that has all or some of its methods missing
// from the destination concrete type
type missingInterface struct {
	iface   *types.Interface // underlying type of the interface
	file    *ast.File        // file declaring the interface; passed to stubmethods.RelativeToFiles when formatting stubs
	pkg     Package          // package declaring the interface
	missing []*types.Func    // interface methods the concrete type lacks (not already reported for another interface)
}
   324  
   325  // stubImport represents a newly added import
   326  // statement to the concrete type. If name is not
   327  // empty, then that import is required to have that name.
   328  type stubImport struct{ Name, Path string }