github.com/gmemcc/yaegi@v0.12.1-0.20221128122509-aa99124c5d16/extract/extract.go (about)

     1  /*
     2  Package extract generates wrappers of package exported symbols.
     3  */
     4  package extract
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"github.com/gmemcc/yaegi/internal/cmd/extract/srcimporter"
    12  	"github.com/gmemcc/yaegi/internal/cmd/extract/types"
    13  	"go/build"
    14  	"go/constant"
    15  	"go/format"
    16  	"go/token"
    17  	"io"
    18  	"math/big"
    19  	"os"
    20  	"path"
    21  	"path/filepath"
    22  	"regexp"
    23  	"runtime"
    24  	"strconv"
    25  	"strings"
    26  	"text/template"
    27  )
    28  
// model is the text/template executed by genContent to render a symbol
// wrapper file. The data is a map providing:
//   - ImportPath: import path of the extracted package (always imported);
//   - License: optional text copied verbatim after the header line;
//   - BuildTags: comma separated constraints, emitted as a "// +build" line
//     only when non-empty;
//   - Dest: package name of the generated file;
//   - Imports: import path -> whether it is needed; only true entries are
//     emitted;
//   - PkgName: key under which the symbols are registered;
//   - Val, Typ, Wrap: the symbol tables built by genContent.
//
// The generated init function fills a Symbols map that this template does not
// declare, so it must exist elsewhere in the destination package. Variables
// (Val.Addr) are registered as the Elem of a pointer so the reflect.Value is
// addressable. Each interface also gets a wrapper struct whose W<Method>
// function fields implement the interface; its String method returns "" when
// WString is nil.
const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT.

{{.License}}

{{if .BuildTags}}// +build {{.BuildTags}}{{end}}

package {{.Dest}}

import (
{{- range $key, $value := .Imports }}
	{{- if $value}}
	"{{$key}}"
	{{- end}}
{{- end}}
	"{{.ImportPath}}"
	"reflect"
)

func init() {
	Symbols["{{.PkgName}}"] = map[string]reflect.Value{
		{{- if .Val}}
		// function, constant and variable definitions
		{{range $key, $value := .Val -}}
			{{- if $value.Addr -}}
				"{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(),
			{{else -}}
				"{{$key}}": reflect.ValueOf({{$value.Name}}),
			{{end -}}
		{{end}}

		{{- end}}
		{{- if .Typ}}
		// type definitions
		{{range $key, $value := .Typ -}}
			"{{$key}}": reflect.ValueOf((*{{$value}})(nil)),
		{{end}}

		{{- end}}
		{{- if .Wrap}}
		// interface wrapper definitions
		{{range $key, $value := .Wrap -}}
			"_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)),
		{{end}}
		{{- end}}
	}
}
{{range $key, $value := .Wrap -}}
	// {{$value.Name}} is an interface wrapper for {{$key}} type
	type {{$value.Name}} struct {
		IValue interface{}
		{{range $m := $value.Method -}}
		W{{$m.Name}} func{{$m.Param}} {{$m.Result}}
		{{end}}
	}
	{{range $m := $value.Method -}}
		func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} {
			{{- if eq $m.Name "String"}}
			if W.WString == nil {
				return ""
			}
			{{end -}}
			{{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}}
		}
	{{end}}
{{end}}
`
    93  {{end}}
    94  `
    95  
    96  // Val stores the value name and addressable status of symbols.
    97  type Val struct {
    98  	Name string // "package.name"
    99  	Addr bool   // true if symbol is a Var
   100  }
   101  
   102  // Method stores information for generating interface wrapper method.
   103  type Method struct {
   104  	Name, Param, Result, Arg, Ret string
   105  }
   106  
   107  // Wrap stores information for generating interface wrapper.
   108  type Wrap struct {
   109  	Name   string
   110  	Method []Method
   111  }
   112  
// restricted map defines symbols for which a special implementation is provided.
//
// Keys are the package name immediately followed by the symbol name (e.g.
// "osExit" for os.Exit). When genContent meets such a symbol, the generated
// code references the bare key identifier instead of "package.Name". The
// destination package must therefore define an identifier with that name,
// presumably the locally provided stdlib replacement; confirm in the stdlib
// wrapper package.
var restricted = map[string]bool{
	"osExit":        true,
	"osFindProcess": true,
	"logFatal":      true,
	"logFatalf":     true,
	"logFatalln":    true,
	"logLogger":     true,
	"logNew":        true,
}
   123  
   124  func matchList(name string, list []string) (match bool, err error) {
   125  	for _, re := range list {
   126  		match, err = regexp.MatchString(re, name)
   127  		if err != nil || match {
   128  			return
   129  		}
   130  	}
   131  	return
   132  }
   133  
   134  func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) {
   135  	prefix := "_" + importPath + "_"
   136  	prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix)
   137  
   138  	typ := map[string]string{}
   139  	val := map[string]Val{}
   140  	wrap := map[string]Wrap{}
   141  	imports := map[string]bool{}
   142  	sc := p.Scope()
   143  
   144  	for _, pkg := range p.Imports() {
   145  		imports[pkg.Path()] = false
   146  	}
   147  	qualify := func(pkg *types.Package) string {
   148  		if pkg.Path() != importPath {
   149  			imports[pkg.Path()] = true
   150  		}
   151  		return pkg.Name()
   152  	}
   153  
   154  	for _, name := range sc.Names() {
   155  		o := sc.Lookup(name)
   156  		if !o.Exported() {
   157  			continue
   158  		}
   159  
   160  		if len(e.Include) > 0 {
   161  			match, err := matchList(name, e.Include)
   162  			if err != nil {
   163  				return nil, err
   164  			}
   165  			if !match {
   166  				// Explicitly defined include expressions force non matching symbols to be skipped.
   167  				continue
   168  			}
   169  		}
   170  
   171  		match, err := matchList(name, e.Exclude)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		if match {
   176  			continue
   177  		}
   178  
   179  		pname := p.Name() + "." + name
   180  		if rname := p.Name() + name; restricted[rname] {
   181  			// Restricted symbol, locally provided by stdlib wrapper.
   182  			pname = rname
   183  		}
   184  
   185  		switch o := o.(type) {
   186  		case *types.Const:
   187  			if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 {
   188  				// Convert untyped constant to right type to avoid overflow.
   189  				val[name] = Val{fixConst(pname, o.Val(), imports), false}
   190  			} else {
   191  				val[name] = Val{pname, false}
   192  			}
   193  		case *types.Func:
   194  			val[name] = Val{pname, false}
   195  		case *types.Var:
   196  			val[name] = Val{pname, true}
   197  		case *types.TypeName:
   198  			typ[name] = pname
   199  			if t, ok := o.Type().Underlying().(*types.Interface); ok {
   200  				var methods []Method
   201  				for i := 0; i < t.NumMethods(); i++ {
   202  					f := t.Method(i)
   203  					if !f.Exported() {
   204  						continue
   205  					}
   206  
   207  					sign := f.Type().(*types.Signature)
   208  					args := make([]string, sign.Params().Len())
   209  					params := make([]string, len(args))
   210  					for j := range args {
   211  						v := sign.Params().At(j)
   212  						if args[j] = v.Name(); args[j] == "" {
   213  							args[j] = fmt.Sprintf("a%d", j)
   214  						}
   215  						// process interface method variadic parameter
   216  						if sign.Variadic() && j == len(args)-1 { // check is last arg
   217  							// only replace the first "[]" to "..."
   218  							at := types.TypeString(v.Type(), qualify)[2:]
   219  							params[j] = args[j] + " ..." + at
   220  							args[j] += "..."
   221  						} else {
   222  							params[j] = args[j] + " " + types.TypeString(v.Type(), qualify)
   223  						}
   224  					}
   225  					arg := "(" + strings.Join(args, ", ") + ")"
   226  					param := "(" + strings.Join(params, ", ") + ")"
   227  
   228  					results := make([]string, sign.Results().Len())
   229  					for j := range results {
   230  						v := sign.Results().At(j)
   231  						results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify)
   232  					}
   233  					result := "(" + strings.Join(results, ", ") + ")"
   234  
   235  					ret := ""
   236  					if sign.Results().Len() > 0 {
   237  						ret = "return"
   238  					}
   239  
   240  					methods = append(methods, Method{f.Name(), param, result, arg, ret})
   241  				}
   242  				wrap[name] = Wrap{prefix + name, methods}
   243  			}
   244  		}
   245  	}
   246  
   247  	// Generate buildTags with Go version only for stdlib packages.
   248  	// Third party packages do not depend on Go compiler version by default.
   249  	var buildTags string
   250  	if isInStdlib(importPath) {
   251  		var err error
   252  		buildTags, err = genBuildTags()
   253  		if err != nil {
   254  			return nil, err
   255  		}
   256  	}
   257  
   258  	base := template.New("extract")
   259  	parse, err := base.Parse(model)
   260  	if err != nil {
   261  		return nil, fmt.Errorf("template parsing error: %w", err)
   262  	}
   263  
   264  	if importPath == "log/syslog" {
   265  		buildTags += ",!windows,!nacl,!plan9"
   266  	}
   267  
   268  	if importPath == "syscall" {
   269  		// As per https://golang.org/cmd/go/#hdr-Build_constraints,
   270  		// using GOOS=android also matches tags and files for GOOS=linux,
   271  		// so exclude it explicitly to avoid collisions (issue #843).
   272  		// Also using GOOS=illumos matches tags and files for GOOS=solaris.
   273  		switch os.Getenv("GOOS") {
   274  		case "android":
   275  			buildTags += ",!linux"
   276  		case "illumos":
   277  			buildTags += ",!solaris"
   278  		}
   279  	}
   280  
   281  	for _, t := range e.Tag {
   282  		if len(t) != 0 {
   283  			buildTags += "," + t
   284  		}
   285  	}
   286  	if len(buildTags) != 0 && buildTags[0] == ',' {
   287  		buildTags = buildTags[1:]
   288  	}
   289  
   290  	b := new(bytes.Buffer)
   291  	data := map[string]interface{}{
   292  		"Dest":       e.Dest,
   293  		"Imports":    imports,
   294  		"ImportPath": importPath,
   295  		"PkgName":    path.Join(importPath, p.Name()),
   296  		"Val":        val,
   297  		"Typ":        typ,
   298  		"Wrap":       wrap,
   299  		"BuildTags":  buildTags,
   300  		"License":    e.License,
   301  	}
   302  	err = parse.Execute(b, data)
   303  	if err != nil {
   304  		return nil, fmt.Errorf("template error: %w", err)
   305  	}
   306  
   307  	// gofmt
   308  	source, err := format.Source(b.Bytes())
   309  	if err != nil {
   310  		return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes())
   311  	}
   312  	return source, nil
   313  }
   314  
   315  // fixConst checks untyped constant value, converting it if necessary to avoid overflow.
   316  func fixConst(name string, val constant.Value, imports map[string]bool) string {
   317  	var (
   318  		tok string
   319  		str string
   320  	)
   321  	switch val.Kind() {
   322  	case constant.String:
   323  		tok = "STRING"
   324  		str = val.ExactString()
   325  	case constant.Int:
   326  		tok = "INT"
   327  		str = val.ExactString()
   328  	case constant.Float:
   329  		v := constant.Val(val) // v is *big.Rat or *big.Float
   330  		f, ok := v.(*big.Float)
   331  		if !ok {
   332  			f = new(big.Float).SetRat(v.(*big.Rat))
   333  		}
   334  
   335  		tok = "FLOAT"
   336  		str = f.Text('g', int(f.Prec()))
   337  	case constant.Complex:
   338  		// TODO: not sure how to parse this case
   339  		fallthrough
   340  	default:
   341  		return name
   342  	}
   343  
   344  	imports["go/constant"] = true
   345  	imports["go/token"] = true
   346  
   347  	return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok)
   348  }
   349  
   350  // Extractor creates a package with all the symbols from a dependency package.
   351  type Extractor struct {
   352  	Dest    string   // The name of the created package.
   353  	License string   // License text to be included in the created package, optional.
   354  	Exclude []string // Comma separated list of regexp matching symbols to exclude.
   355  	Include []string // Comma separated list of regexp matching symbols to include.
   356  	Tag     []string // Comma separated of build tags to be added to the created package.
   357  }
   358  
   359  // importPath checks whether pkgIdent is an existing directory relative to
   360  // e.WorkingDir. If yes, it returns the actual import path of the Go package
   361  // located in the directory. If it is definitely a relative path, but it does not
   362  // exist, an error is returned. Otherwise, it is assumed to be an import path, and
   363  // pkgIdent is returned.
   364  func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) {
   365  	wd, err := os.Getwd()
   366  	if err != nil {
   367  		return "", err
   368  	}
   369  
   370  	dirPath := filepath.Join(wd, pkgIdent)
   371  	_, err = os.Stat(dirPath)
   372  	if err != nil && !os.IsNotExist(err) {
   373  		return "", err
   374  	}
   375  	if err != nil {
   376  		if len(pkgIdent) > 0 && pkgIdent[0] == '.' {
   377  			// pkgIdent is definitely a relative path, not a package name, and it does not exist
   378  			return "", err
   379  		}
   380  		// pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now.
   381  		return pkgIdent, nil
   382  	}
   383  
   384  	// local import
   385  	if importPath != "" {
   386  		return importPath, nil
   387  	}
   388  
   389  	modPath := filepath.Join(dirPath, "go.mod")
   390  	_, err = os.Stat(modPath)
   391  	if os.IsNotExist(err) {
   392  		return "", errors.New("no go.mod found, and no import path specified")
   393  	}
   394  	if err != nil {
   395  		return "", err
   396  	}
   397  	f, err := os.Open(modPath)
   398  	if err != nil {
   399  		return "", err
   400  	}
   401  	defer func() {
   402  		_ = f.Close()
   403  	}()
   404  	sc := bufio.NewScanner(f)
   405  	var l string
   406  	for sc.Scan() {
   407  		l = sc.Text()
   408  		break
   409  	}
   410  	if sc.Err() != nil {
   411  		return "", err
   412  	}
   413  	parts := strings.Fields(l)
   414  	if len(parts) < 2 {
   415  		return "", errors.New(`invalid first line syntax in go.mod`)
   416  	}
   417  	if parts[0] != "module" {
   418  		return "", errors.New(`invalid first line in go.mod, no "module" found`)
   419  	}
   420  
   421  	return parts[1], nil
   422  }
   423  
   424  // Extract writes to rw a Go package with all the symbols found at pkgIdent.
   425  // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In
   426  // the latter case, Extract returns the actual import path of the package found at
   427  // pkgIdent, otherwise it just returns pkgIdent.
   428  // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not
   429  // supported yet, and the behavior is only defined for GO111MODULE=off.
   430  func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) {
   431  	ipp, err := e.importPath(pkgIdent, importPath)
   432  	if err != nil {
   433  		return "", err
   434  	}
   435  	fset := token.NewFileSet()
   436  	importer := srcimporter.New(&build.Default, fset, make(map[string]*types.Package))
   437  	pkg, err := importer.Import(pkgIdent, 0)
   438  	if err != nil {
   439  		return "", err
   440  	}
   441  
   442  	content, err := e.genContent(ipp, pkg)
   443  	if err != nil {
   444  		return "", err
   445  	}
   446  
   447  	if _, err := rw.Write(content); err != nil {
   448  		return "", err
   449  	}
   450  
   451  	return ipp, nil
   452  }
   453  
   454  // GetMinor returns the minor part of the version number.
   455  func GetMinor(part string) string {
   456  	minor := part
   457  	index := strings.Index(minor, "beta")
   458  	if index < 0 {
   459  		index = strings.Index(minor, "rc")
   460  	}
   461  	if index > 0 {
   462  		minor = minor[:index]
   463  	}
   464  
   465  	return minor
   466  }
   467  
   468  const defaultMinorVersion = 17
   469  
   470  func genBuildTags() (string, error) {
   471  	version := runtime.Version()
   472  	if strings.HasPrefix(version, "devel") {
   473  		return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version)
   474  	}
   475  	parts := strings.Split(version, ".")
   476  
   477  	minorRaw := GetMinor(parts[1])
   478  
   479  	currentGoVersion := parts[0] + "." + minorRaw
   480  
   481  	minor, err := strconv.Atoi(minorRaw)
   482  	if err != nil {
   483  		return "", fmt.Errorf("failed to parse version: %w", err)
   484  	}
   485  
   486  	// Only append an upper bound if we are not on the latest go
   487  	if minor >= defaultMinorVersion {
   488  		return currentGoVersion, nil
   489  	}
   490  
   491  	nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1)
   492  
   493  	return currentGoVersion + ",!" + nextGoVersion, nil
   494  }
   495  
   496  func isInStdlib(path string) bool { return !strings.Contains(path, ".") }