github.com/julianthome/gore@v0.0.0-20231109011145-b3a6bbe6fe55/commands.go (about)

     1  package gore
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/build"
     7  	"go/types"
     8  	"os"
     9  	"os/exec"
    10  	"path"
    11  	"path/filepath"
    12  	"runtime"
    13  	"strings"
    14  	"text/tabwriter"
    15  	"time"
    16  	"unicode"
    17  
    18  	"golang.org/x/tools/go/ast/astutil"
    19  	"golang.org/x/tools/go/packages"
    20  )
    21  
// command describes one REPL meta-command. actionHelp lists each entry as
// ":" + name, followed by arg when it is non-empty, and then document.
type command struct {
	// name is the command's name. The bracketed part in entries such as
	// "i[mport]" presumably marks an optional suffix; the matching logic is
	// not in this file — confirm against the command parser.
	name     commandName
	// action runs the command with the session and the argument text.
	action   func(*Session, string) error
	// complete returns completion candidates for the command's argument;
	// nil when the command has no completion.
	complete func(*Session, string) []string
	// arg is the argument synopsis shown by :help, e.g. "<package>".
	arg      string
	// document is the one-line description shown by :help.
	document string
}
    29  
    30  var commands []command
    31  
    32  func init() {
    33  	commands = []command{
    34  		{
    35  			name:     commandName("i[mport]"),
    36  			action:   actionImport,
    37  			complete: completeImport,
    38  			arg:      "<package>",
    39  			document: "import a package",
    40  		},
    41  		{
    42  			name:     commandName("t[ype]"),
    43  			action:   actionType,
    44  			arg:      "<expr>",
    45  			complete: completeDoc,
    46  			document: "print the type of expression",
    47  		},
    48  		{
    49  			name:     commandName("print"),
    50  			action:   actionPrint,
    51  			document: "print current source",
    52  		},
    53  		{
    54  			name:     commandName("w[rite]"),
    55  			action:   actionWrite,
    56  			complete: nil, // TODO implement
    57  			arg:      "[<file>]",
    58  			document: "write out current source",
    59  		},
    60  		{
    61  			name:     commandName("clear"),
    62  			action:   actionClear,
    63  			document: "clear the codes",
    64  		},
    65  		{
    66  			name:     commandName("d[oc]"),
    67  			action:   actionDoc,
    68  			complete: completeDoc,
    69  			arg:      "<expr or pkg>",
    70  			document: "show documentation",
    71  		},
    72  		{
    73  			name:     commandName("h[elp]"),
    74  			action:   actionHelp,
    75  			document: "show this help",
    76  		},
    77  		{
    78  			name:     commandName("q[uit]"),
    79  			action:   actionQuit,
    80  			document: "quit the session",
    81  		},
    82  	}
    83  }
    84  
    85  func actionImport(s *Session, arg string) error {
    86  	if arg == "" {
    87  		return fmt.Errorf("argument is required")
    88  	}
    89  
    90  	if strings.Contains(arg, " ") {
    91  		for _, v := range strings.Fields(arg) {
    92  			if v == "" {
    93  				continue
    94  			}
    95  			if err := actionImport(s, v); err != nil {
    96  				return err
    97  			}
    98  		}
    99  
   100  		return nil
   101  	}
   102  
   103  	arg = strings.Trim(arg, `"`)
   104  
   105  	// check if the package specified by path is importable
   106  	_, err := packages.Load(
   107  		&packages.Config{
   108  			Dir:        s.tempDir,
   109  			BuildFlags: []string{"-mod=mod"},
   110  		},
   111  		arg,
   112  	)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	var found bool
   118  	for _, i := range s.file.Imports {
   119  		if strings.Trim(i.Path.Value, `"`) == arg {
   120  			found = true
   121  			break
   122  		}
   123  	}
   124  	if !found {
   125  		astutil.AddNamedImport(s.fset, s.file, "_", arg)
   126  		_, err = s.types.Check("_tmp", s.fset, append(s.extraFiles, s.file), nil)
   127  		if err != nil && strings.Contains(err.Error(), "could not import "+arg) {
   128  			astutil.DeleteNamedImport(s.fset, s.file, "_", arg)
   129  			return fmt.Errorf("could not import %q", arg)
   130  		}
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  var gorootSrc = filepath.Join(filepath.Clean(runtime.GOROOT()), "src")
   137  
   138  func completeImport(_ *Session, prefix string) []string {
   139  	result := []string{}
   140  	seen := map[string]bool{}
   141  
   142  	p := strings.LastIndexFunc(prefix, unicode.IsSpace) + 1
   143  
   144  	d, fn := path.Split(prefix[p:])
   145  
   146  	// complete candidates from the current module
   147  	if modules, err := goListAll(); err == nil {
   148  		for _, m := range modules {
   149  
   150  			matchPath := func(fn string) bool {
   151  				if len(fn) < 2 {
   152  					return false
   153  				}
   154  				for _, s := range strings.Split(m.Path, "/") {
   155  					if strings.HasPrefix(s, fn) || strings.HasPrefix(strings.TrimPrefix(s, "go-"), fn) {
   156  						return true
   157  					}
   158  				}
   159  				return false
   160  			}
   161  			if strings.HasPrefix(m.Path, prefix[p:]) || d == "" && matchPath(fn) {
   162  				result = append(result, prefix[:p]+m.Path)
   163  				seen[m.Path] = true
   164  				continue
   165  			}
   166  
   167  			if strings.HasPrefix(d, m.Path) {
   168  				dir := filepath.Join(m.Dir, strings.Replace(d, m.Path, "", 1))
   169  				if fi, err := os.Stat(dir); err != nil || !fi.IsDir() {
   170  					continue
   171  				}
   172  				entries, err := os.ReadDir(dir)
   173  				if err != nil {
   174  					continue
   175  				}
   176  				for _, fi := range entries {
   177  					if !fi.IsDir() {
   178  						continue
   179  					}
   180  					name := fi.Name()
   181  					if skipCompleteDir(name) {
   182  						continue
   183  					}
   184  					if strings.HasPrefix(name, fn) {
   185  						r := path.Join(d, name)
   186  						if !seen[r] {
   187  							result = append(result, prefix[:p]+r)
   188  							seen[r] = true
   189  						}
   190  					}
   191  				}
   192  			}
   193  
   194  		}
   195  	}
   196  
   197  	// complete candidates from GOPATH/src/
   198  	for _, srcDir := range build.Default.SrcDirs() {
   199  		dir := filepath.Join(srcDir, d)
   200  
   201  		if fi, err := os.Stat(dir); err != nil || !fi.IsDir() {
   202  			if err != nil && !os.IsNotExist(err) {
   203  				errorf("Stat %s: %s", dir, err)
   204  			}
   205  			continue
   206  		}
   207  
   208  		entries, err := os.ReadDir(dir)
   209  		if err != nil {
   210  			errorf("ReadDir %s: %s", dir, err)
   211  			continue
   212  		}
   213  		for _, fi := range entries {
   214  			if !fi.IsDir() {
   215  				continue
   216  			}
   217  
   218  			name := fi.Name()
   219  			if skipCompleteDir(name) {
   220  				continue
   221  			}
   222  
   223  			if strings.HasPrefix(name, fn) {
   224  				r := path.Join(d, name)
   225  				if srcDir != gorootSrc {
   226  					// append "/" if this directory is not a repository
   227  					// e.g. does not have VCS directory such as .git or .hg
   228  					// TODO: do not append "/" to subdirectories of repos
   229  					var isRepo bool
   230  					for _, vcsDir := range []string{".git", ".hg", ".svn", ".bzr"} {
   231  						_, err := os.Stat(filepath.Join(srcDir, filepath.FromSlash(r), vcsDir))
   232  						if err == nil {
   233  							isRepo = true
   234  							break
   235  						}
   236  					}
   237  					if !isRepo {
   238  						r += "/"
   239  					}
   240  				}
   241  
   242  				if !seen[r] {
   243  					result = append(result, prefix[:p]+r)
   244  					seen[r] = true
   245  				}
   246  			}
   247  		}
   248  	}
   249  
   250  	return result
   251  }
   252  
   253  func skipCompleteDir(dir string) bool {
   254  	return strings.HasPrefix(dir, ".") || strings.HasPrefix(dir, "_") || dir == "testdata"
   255  }
   256  
   257  func completeDoc(s *Session, prefix string) []string {
   258  	pos, cands, err := s.completeCode(prefix, len(prefix), false)
   259  	if err != nil {
   260  		errorf("completeCode: %s", err)
   261  		return nil
   262  	}
   263  
   264  	result := make([]string, 0, len(cands))
   265  	for _, c := range cands {
   266  		result = append(result, prefix[0:pos]+c)
   267  	}
   268  
   269  	return result
   270  }
   271  
   272  func actionPrint(s *Session, _ string) error {
   273  	source, err := s.source(true)
   274  	if err != nil {
   275  		return err
   276  	}
   277  
   278  	fmt.Println(source)
   279  
   280  	return nil
   281  }
   282  
   283  func actionType(s *Session, in string) error {
   284  	if in == "" {
   285  		return fmt.Errorf("argument is required")
   286  	}
   287  
   288  	s.clearQuickFix()
   289  
   290  	s.storeCode()
   291  	defer s.restoreCode()
   292  
   293  	expr, err := s.evalExpr(in)
   294  	if err != nil {
   295  		return err
   296  	}
   297  
   298  	s.typeInfo = types.Info{
   299  		Types:  make(map[ast.Expr]types.TypeAndValue),
   300  		Uses:   make(map[*ast.Ident]types.Object),
   301  		Defs:   make(map[*ast.Ident]types.Object),
   302  		Scopes: make(map[ast.Node]*types.Scope),
   303  	}
   304  	_, err = s.types.Check("_tmp", s.fset, append(s.extraFiles, s.file), &s.typeInfo)
   305  	if err != nil {
   306  		debugf("typecheck error (ignored): %s", err)
   307  	}
   308  
   309  	typ := s.typeInfo.TypeOf(expr)
   310  	if typ == nil {
   311  		return fmt.Errorf("cannot get type: %v", expr)
   312  	}
   313  	if typ, ok := typ.(*types.Basic); ok && typ.Kind() == types.Invalid {
   314  		return fmt.Errorf("cannot get type: %v", expr)
   315  	}
   316  	fmt.Fprintf(s.stdout, "%v\n", typ)
   317  	return nil
   318  }
   319  
   320  func actionWrite(s *Session, filename string) error {
   321  	source, err := s.source(false)
   322  	if err != nil {
   323  		return err
   324  	}
   325  
   326  	if filename == "" {
   327  		filename = fmt.Sprintf("gore_session_%s.go", time.Now().Format("20060102_150405"))
   328  	}
   329  
   330  	err = os.WriteFile(filename, []byte(source), 0o644)
   331  	if err != nil {
   332  		return err
   333  	}
   334  
   335  	infof("Source wrote to %s", filename)
   336  
   337  	return nil
   338  }
   339  
   340  func actionClear(s *Session, _ string) error {
   341  	return s.init()
   342  }
   343  
   344  func actionDoc(s *Session, in string) error {
   345  	if in == "" {
   346  		return fmt.Errorf("argument is required")
   347  	}
   348  
   349  	s.clearQuickFix()
   350  
   351  	s.storeCode()
   352  	defer s.restoreCode()
   353  
   354  	expr, err := s.evalExpr(in)
   355  	if err != nil {
   356  		return err
   357  	}
   358  
   359  	s.typeInfo = types.Info{
   360  		Types:  make(map[ast.Expr]types.TypeAndValue),
   361  		Uses:   make(map[*ast.Ident]types.Object),
   362  		Defs:   make(map[*ast.Ident]types.Object),
   363  		Scopes: make(map[ast.Node]*types.Scope),
   364  	}
   365  	_, err = s.types.Check("_tmp", s.fset, append(s.extraFiles, s.file), &s.typeInfo)
   366  	if err != nil {
   367  		debugf("typecheck error (ignored): %s", err)
   368  	}
   369  
   370  	// :doc patterns:
   371  	// - "json" -> "encoding/json" (package name)
   372  	// - "json.Encoder" -> "encoding/json", "Encoder" (package member)
   373  	// - "json.NewEncoder(nil).Encode" -> "encoding/json", "Decode" (package type member)
   374  	var docObj types.Object
   375  	if sel, ok := expr.(*ast.SelectorExpr); ok {
   376  		// package member, package type member
   377  		docObj = s.typeInfo.ObjectOf(sel.Sel)
   378  	} else if t := s.typeInfo.TypeOf(expr); t != nil && t != types.Typ[types.Invalid] {
   379  		for {
   380  			if pt, ok := t.(*types.Pointer); ok {
   381  				t = pt.Elem()
   382  			} else {
   383  				break
   384  			}
   385  		}
   386  		switch t := t.(type) {
   387  		case *types.Named:
   388  			docObj = t.Obj()
   389  		case *types.Basic:
   390  			// builtin types
   391  			docObj = types.Universe.Lookup(t.Name())
   392  		}
   393  	} else if ident, ok := expr.(*ast.Ident); ok {
   394  		// package name
   395  		mainScope := s.typeInfo.Scopes[s.mainFunc().Type]
   396  		_, docObj = mainScope.LookupParent(ident.Name, ident.NamePos)
   397  	}
   398  
   399  	if docObj == nil {
   400  		return fmt.Errorf("cannot determine the document location")
   401  	}
   402  
   403  	debugf("doc :: obj=%#v", docObj)
   404  
   405  	var pkgPath, objName string
   406  	if pkgName, ok := docObj.(*types.PkgName); ok {
   407  		pkgPath = pkgName.Imported().Path()
   408  	} else {
   409  		if pkg := docObj.Pkg(); pkg != nil {
   410  			pkgPath = pkg.Path()
   411  		} else {
   412  			pkgPath = "builtin"
   413  		}
   414  		objName = docObj.Name()
   415  	}
   416  
   417  	debugf("doc :: %q %q", pkgPath, objName)
   418  
   419  	args := []string{"doc", pkgPath}
   420  	if objName != "" {
   421  		args = append(args, objName)
   422  	}
   423  
   424  	godoc := exec.Command("go", args...)
   425  	godoc.Dir = s.tempDir
   426  	godoc.Env = append(os.Environ(), "GO111MODULE=on")
   427  	ef := newErrFilter(s.stderr)
   428  	godoc.Stderr = ef
   429  	defer ef.Close()
   430  
   431  	// TODO just use PAGER?
   432  	if pagerCmd := os.Getenv("GORE_PAGER"); pagerCmd != "" {
   433  		r, err := godoc.StdoutPipe()
   434  		if err != nil {
   435  			return err
   436  		}
   437  
   438  		pager := exec.Command(pagerCmd)
   439  		pager.Stdin = r
   440  		pager.Stdout = s.stdout
   441  		pager.Stderr = s.stderr
   442  
   443  		err = pager.Start()
   444  		if err != nil {
   445  			return err
   446  		}
   447  
   448  		err = godoc.Run()
   449  		if err != nil {
   450  			return err
   451  		}
   452  
   453  		return pager.Wait()
   454  	}
   455  	godoc.Stdout = s.stdout
   456  	return godoc.Run()
   457  }
   458  
   459  func actionHelp(s *Session, _ string) error {
   460  	w := tabwriter.NewWriter(s.stdout, 0, 8, 4, ' ', 0)
   461  	for _, command := range commands {
   462  		cmd := fmt.Sprintf(":%s", command.name)
   463  		if command.arg != "" {
   464  			cmd = cmd + " " + command.arg
   465  		}
   466  		w.Write([]byte("    " + cmd + "\t" + command.document + "\n"))
   467  	}
   468  	w.Flush()
   469  
   470  	return nil
   471  }
   472  
   473  func actionQuit(_ *Session, _ string) error {
   474  	return ErrQuit
   475  }