github.com/HaHadaxigua/yaegi@v1.0.1/extract/extract.go (about)

     1  /*
     2  Package extract generates wrappers of package exported symbols.
     3  */
     4  package extract
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"go/constant"
    12  	"go/format"
    13  	"go/importer"
    14  	"go/token"
    15  	"go/types"
    16  	"io"
    17  	"math/big"
    18  	"os"
    19  	"path"
    20  	"path/filepath"
    21  	"regexp"
    22  	"runtime"
    23  	"strconv"
    24  	"strings"
    25  	"text/template"
    26  )
    27  
// model is the text/template rendered by genContent to produce the wrapper
// file. It reads the data keys Dest, License, BuildTags, Imports, ImportPath,
// PkgName, Val, Typ and Wrap.
//
// The generated init function adds one entry per symbol to a Symbols map,
// which the template does not declare, so it must already exist in the
// destination package. Variables (Val.Addr) are registered by address so
// they stay settable. Types are registered as typed nil pointers. For each
// interface in Wrap, a wrapper struct is emitted with one W<Method> function
// field and one forwarding method per interface method. A String method
// returns "" when its field is nil.
const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT.

{{.License}}

{{if .BuildTags}}// +build {{.BuildTags}}{{end}}

package {{.Dest}}

import (
{{- range $key, $value := .Imports }}
	{{- if $value}}
	"{{$key}}"
	{{- end}}
{{- end}}
	"{{.ImportPath}}"
	"reflect"
)

func init() {
	Symbols["{{.PkgName}}"] = map[string]reflect.Value{
		{{- if .Val}}
		// function, constant and variable definitions
		{{range $key, $value := .Val -}}
			{{- if $value.Addr -}}
				"{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(),
			{{else -}}
				"{{$key}}": reflect.ValueOf({{$value.Name}}),
			{{end -}}
		{{end}}

		{{- end}}
		{{- if .Typ}}
		// type definitions
		{{range $key, $value := .Typ -}}
			"{{$key}}": reflect.ValueOf((*{{$value}})(nil)),
		{{end}}

		{{- end}}
		{{- if .Wrap}}
		// interface wrapper definitions
		{{range $key, $value := .Wrap -}}
			"_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)),
		{{end}}
		{{- end}}
	}
}
{{range $key, $value := .Wrap -}}
	// {{$value.Name}} is an interface wrapper for {{$key}} type
	type {{$value.Name}} struct {
		IValue interface{}
		{{range $m := $value.Method -}}
		W{{$m.Name}} func{{$m.Param}} {{$m.Result}}
		{{end}}
	}
	{{range $m := $value.Method -}}
		func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} {
			{{- if eq $m.Name "String"}}
			if W.WString == nil {
				return ""
			}
			{{end -}}
			{{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}}
		}
	{{end}}
{{end}}
`
    94  
    95  // Val stores the value name and addressable status of symbols.
    96  type Val struct {
    97  	Name string // "package.name"
    98  	Addr bool   // true if symbol is a Var
    99  }
   100  
   101  // Method stores information for generating interface wrapper method.
   102  type Method struct {
   103  	Name, Param, Result, Arg, Ret string
   104  }
   105  
   106  // Wrap stores information for generating interface wrapper.
   107  type Wrap struct {
   108  	Name   string
   109  	Method []Method
   110  }
   111  
   112  // restricted map defines symbols for which a special implementation is provided.
   113  var restricted = map[string]bool{
   114  	"osExit":        true,
   115  	"osFindProcess": true,
   116  	"logFatal":      true,
   117  	"logFatalf":     true,
   118  	"logFatalln":    true,
   119  	"logLogger":     true,
   120  	"logNew":        true,
   121  }
   122  
   123  func matchList(name string, list []string) (match bool, err error) {
   124  	for _, re := range list {
   125  		match, err = regexp.MatchString(re, name)
   126  		if err != nil || match {
   127  			return
   128  		}
   129  	}
   130  	return
   131  }
   132  
   133  type PackageStruct struct {
   134  	Typ     map[string]string
   135  	Val     map[string]Val
   136  	Wrap    map[string]Wrap
   137  	Imports map[string]bool
   138  }
   139  
   140  func (e *Extractor) genStructure(importPath string, p *types.Package) (*PackageStruct, error) {
   141  	prefix := "_" + importPath + "_"
   142  	prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix)
   143  
   144  	typ := map[string]string{}
   145  	val := map[string]Val{}
   146  	wrap := map[string]Wrap{}
   147  	imports := map[string]bool{}
   148  	sc := p.Scope()
   149  
   150  	for _, pkg := range p.Imports() {
   151  		imports[pkg.Path()] = false
   152  	}
   153  	qualify := func(pkg *types.Package) string {
   154  		if pkg.Path() != importPath {
   155  			imports[pkg.Path()] = true
   156  		}
   157  		return pkg.Name()
   158  	}
   159  
   160  	for _, name := range sc.Names() {
   161  		o := sc.Lookup(name)
   162  		if !o.Exported() {
   163  			continue
   164  		}
   165  
   166  		if len(e.Include) > 0 {
   167  			match, err := matchList(name, e.Include)
   168  			if err != nil {
   169  				return nil, err
   170  			}
   171  			if !match {
   172  				// Explicitly defined include expressions force non matching symbols to be skipped.
   173  				continue
   174  			}
   175  		}
   176  
   177  		match, err := matchList(name, e.Exclude)
   178  		if err != nil {
   179  			return nil, err
   180  		}
   181  		if match {
   182  			continue
   183  		}
   184  
   185  		pname := p.Name() + "." + name
   186  		if rname := p.Name() + name; restricted[rname] {
   187  			// Restricted symbol, locally provided by stdlib wrapper.
   188  			pname = rname
   189  		}
   190  
   191  		switch o := o.(type) {
   192  		case *types.Const:
   193  			if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 {
   194  				// Convert untyped constant to right type to avoid overflow.
   195  				val[name] = Val{fixConst(pname, o.Val(), imports), false}
   196  			} else {
   197  				val[name] = Val{pname, false}
   198  			}
   199  		case *types.Func:
   200  			val[name] = Val{pname, false}
   201  		case *types.Var:
   202  			val[name] = Val{pname, true}
   203  		case *types.TypeName:
   204  			typ[name] = pname
   205  			if t, ok := o.Type().Underlying().(*types.Interface); ok {
   206  				var methods []Method
   207  				for i := 0; i < t.NumMethods(); i++ {
   208  					f := t.Method(i)
   209  					if !f.Exported() {
   210  						continue
   211  					}
   212  
   213  					sign := f.Type().(*types.Signature)
   214  					args := make([]string, sign.Params().Len())
   215  					params := make([]string, len(args))
   216  					for j := range args {
   217  						v := sign.Params().At(j)
   218  						if args[j] = v.Name(); args[j] == "" {
   219  							args[j] = fmt.Sprintf("a%d", j)
   220  						}
   221  						// process interface method variadic parameter
   222  						if sign.Variadic() && j == len(args)-1 { // check is last arg
   223  							// only replace the first "[]" to "..."
   224  							at := types.TypeString(v.Type(), qualify)[2:]
   225  							params[j] = args[j] + " ..." + at
   226  							args[j] += "..."
   227  						} else {
   228  							params[j] = args[j] + " " + types.TypeString(v.Type(), qualify)
   229  						}
   230  					}
   231  					arg := "(" + strings.Join(args, ", ") + ")"
   232  					param := "(" + strings.Join(params, ", ") + ")"
   233  
   234  					results := make([]string, sign.Results().Len())
   235  					for j := range results {
   236  						v := sign.Results().At(j)
   237  						results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify)
   238  					}
   239  					result := "(" + strings.Join(results, ", ") + ")"
   240  
   241  					ret := ""
   242  					if sign.Results().Len() > 0 {
   243  						ret = "return"
   244  					}
   245  
   246  					methods = append(methods, Method{f.Name(), param, result, arg, ret})
   247  				}
   248  				wrap[name] = Wrap{prefix + name, methods}
   249  			}
   250  		}
   251  	}
   252  
   253  	// Generate buildTags with Go version only for stdlib packages.
   254  	// Third party packages do not depend on Go compiler version by default.
   255  	var buildTags string
   256  	if isInStdlib(importPath) {
   257  		var err error
   258  		buildTags, err = genBuildTags()
   259  		if err != nil {
   260  			return nil, err
   261  		}
   262  	}
   263  
   264  	if importPath == "log/syslog" {
   265  		buildTags += ",!windows,!nacl,!plan9"
   266  	}
   267  
   268  	if importPath == "syscall" {
   269  		// As per https://golang.org/cmd/go/#hdr-Build_constraints,
   270  		// using GOOS=android also matches tags and files for GOOS=linux,
   271  		// so exclude it explicitly to avoid collisions (issue #843).
   272  		// Also using GOOS=illumos matches tags and files for GOOS=solaris.
   273  		switch os.Getenv("GOOS") {
   274  		case "android":
   275  			buildTags += ",!linux"
   276  		case "illumos":
   277  			buildTags += ",!solaris"
   278  		}
   279  	}
   280  
   281  	for _, t := range e.Tag {
   282  		if len(t) != 0 {
   283  			buildTags += "," + t
   284  		}
   285  	}
   286  	if len(buildTags) != 0 && buildTags[0] == ',' {
   287  		buildTags = buildTags[1:]
   288  	}
   289  
   290  	return &PackageStruct{
   291  		Typ:     typ,
   292  		Val:     val,
   293  		Wrap:    wrap,
   294  		Imports: imports,
   295  	}, nil
   296  }
   297  
   298  func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) {
   299  	prefix := "_" + importPath + "_"
   300  	prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix)
   301  
   302  	typ := map[string]string{}
   303  	val := map[string]Val{}
   304  	wrap := map[string]Wrap{}
   305  	imports := map[string]bool{}
   306  	sc := p.Scope()
   307  
   308  	for _, pkg := range p.Imports() {
   309  		imports[pkg.Path()] = false
   310  	}
   311  	qualify := func(pkg *types.Package) string {
   312  		if pkg.Path() != importPath {
   313  			imports[pkg.Path()] = true
   314  		}
   315  		return pkg.Name()
   316  	}
   317  
   318  	for _, name := range sc.Names() {
   319  		o := sc.Lookup(name)
   320  		if !o.Exported() {
   321  			continue
   322  		}
   323  
   324  		if len(e.Include) > 0 {
   325  			match, err := matchList(name, e.Include)
   326  			if err != nil {
   327  				return nil, err
   328  			}
   329  			if !match {
   330  				// Explicitly defined include expressions force non matching symbols to be skipped.
   331  				continue
   332  			}
   333  		}
   334  
   335  		match, err := matchList(name, e.Exclude)
   336  		if err != nil {
   337  			return nil, err
   338  		}
   339  		if match {
   340  			continue
   341  		}
   342  
   343  		pname := p.Name() + "." + name
   344  		if rname := p.Name() + name; restricted[rname] {
   345  			// Restricted symbol, locally provided by stdlib wrapper.
   346  			pname = rname
   347  		}
   348  
   349  		switch o := o.(type) {
   350  		case *types.Const:
   351  			if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 {
   352  				// Convert untyped constant to right type to avoid overflow.
   353  				val[name] = Val{fixConst(pname, o.Val(), imports), false}
   354  			} else {
   355  				val[name] = Val{pname, false}
   356  			}
   357  		case *types.Func:
   358  			val[name] = Val{pname, false}
   359  		case *types.Var:
   360  			val[name] = Val{pname, true}
   361  		case *types.TypeName:
   362  			typ[name] = pname
   363  			if t, ok := o.Type().Underlying().(*types.Interface); ok {
   364  				var methods []Method
   365  				for i := 0; i < t.NumMethods(); i++ {
   366  					f := t.Method(i)
   367  					if !f.Exported() {
   368  						continue
   369  					}
   370  
   371  					sign := f.Type().(*types.Signature)
   372  					args := make([]string, sign.Params().Len())
   373  					params := make([]string, len(args))
   374  					for j := range args {
   375  						v := sign.Params().At(j)
   376  						if args[j] = v.Name(); args[j] == "" {
   377  							args[j] = fmt.Sprintf("a%d", j)
   378  						}
   379  						// process interface method variadic parameter
   380  						if sign.Variadic() && j == len(args)-1 { // check is last arg
   381  							// only replace the first "[]" to "..."
   382  							at := types.TypeString(v.Type(), qualify)[2:]
   383  							params[j] = args[j] + " ..." + at
   384  							args[j] += "..."
   385  						} else {
   386  							params[j] = args[j] + " " + types.TypeString(v.Type(), qualify)
   387  						}
   388  					}
   389  					arg := "(" + strings.Join(args, ", ") + ")"
   390  					param := "(" + strings.Join(params, ", ") + ")"
   391  
   392  					results := make([]string, sign.Results().Len())
   393  					for j := range results {
   394  						v := sign.Results().At(j)
   395  						results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify)
   396  					}
   397  					result := "(" + strings.Join(results, ", ") + ")"
   398  
   399  					ret := ""
   400  					if sign.Results().Len() > 0 {
   401  						ret = "return"
   402  					}
   403  
   404  					methods = append(methods, Method{f.Name(), param, result, arg, ret})
   405  				}
   406  				wrap[name] = Wrap{prefix + name, methods}
   407  			}
   408  		}
   409  	}
   410  
   411  	// Generate buildTags with Go version only for stdlib packages.
   412  	// Third party packages do not depend on Go compiler version by default.
   413  	var buildTags string
   414  	if isInStdlib(importPath) {
   415  		var err error
   416  		buildTags, err = genBuildTags()
   417  		if err != nil {
   418  			return nil, err
   419  		}
   420  	}
   421  
   422  	base := template.New("extract")
   423  	parse, err := base.Parse(model)
   424  	if err != nil {
   425  		return nil, fmt.Errorf("template parsing error: %w", err)
   426  	}
   427  
   428  	if importPath == "log/syslog" {
   429  		buildTags += ",!windows,!nacl,!plan9"
   430  	}
   431  
   432  	if importPath == "syscall" {
   433  		// As per https://golang.org/cmd/go/#hdr-Build_constraints,
   434  		// using GOOS=android also matches tags and files for GOOS=linux,
   435  		// so exclude it explicitly to avoid collisions (issue #843).
   436  		// Also using GOOS=illumos matches tags and files for GOOS=solaris.
   437  		switch os.Getenv("GOOS") {
   438  		case "android":
   439  			buildTags += ",!linux"
   440  		case "illumos":
   441  			buildTags += ",!solaris"
   442  		}
   443  	}
   444  
   445  	for _, t := range e.Tag {
   446  		if len(t) != 0 {
   447  			buildTags += "," + t
   448  		}
   449  	}
   450  	if len(buildTags) != 0 && buildTags[0] == ',' {
   451  		buildTags = buildTags[1:]
   452  	}
   453  
   454  	b := new(bytes.Buffer)
   455  	data := map[string]interface{}{
   456  		"Dest":       e.Dest,
   457  		"Imports":    imports,
   458  		"ImportPath": importPath,
   459  		"PkgName":    path.Join(importPath, p.Name()),
   460  		"Val":        val,
   461  		"Typ":        typ,
   462  		"Wrap":       wrap,
   463  		"BuildTags":  buildTags,
   464  		"License":    e.License,
   465  	}
   466  	err = parse.Execute(b, data)
   467  	if err != nil {
   468  		return nil, fmt.Errorf("template error: %w", err)
   469  	}
   470  
   471  	// gofmt
   472  	source, err := format.Source(b.Bytes())
   473  	if err != nil {
   474  		return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes())
   475  	}
   476  	return source, nil
   477  }
   478  
   479  // fixConst checks untyped constant value, converting it if necessary to avoid overflow.
   480  func fixConst(name string, val constant.Value, imports map[string]bool) string {
   481  	var (
   482  		tok string
   483  		str string
   484  	)
   485  	switch val.Kind() {
   486  	case constant.String:
   487  		tok = "STRING"
   488  		str = val.ExactString()
   489  	case constant.Int:
   490  		tok = "INT"
   491  		str = val.ExactString()
   492  	case constant.Float:
   493  		v := constant.Val(val) // v is *big.Rat or *big.Float
   494  		f, ok := v.(*big.Float)
   495  		if !ok {
   496  			f = new(big.Float).SetRat(v.(*big.Rat))
   497  		}
   498  
   499  		tok = "FLOAT"
   500  		str = f.Text('g', int(f.Prec()))
   501  	case constant.Complex:
   502  		// TODO: not sure how to parse this case
   503  		fallthrough
   504  	default:
   505  		return name
   506  	}
   507  
   508  	imports["go/constant"] = true
   509  	imports["go/token"] = true
   510  
   511  	return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok)
   512  }
   513  
   514  // Extractor creates a package with all the symbols from a dependency package.
   515  type Extractor struct {
   516  	Dest    string   // The name of the created package.
   517  	License string   // License text to be included in the created package, optional.
   518  	Exclude []string // Comma separated list of regexp matching symbols to exclude.
   519  	Include []string // Comma separated list of regexp matching symbols to include.
   520  	Tag     []string // Comma separated of build tags to be added to the created package.
   521  }
   522  
   523  // importPath checks whether pkgIdent is an existing directory relative to
   524  // e.WorkingDir. If yes, it returns the actual import path of the Go package
   525  // located in the directory. If it is definitely a relative path, but it does not
   526  // exist, an error is returned. Otherwise, it is assumed to be an import path, and
   527  // pkgIdent is returned.
   528  func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) {
   529  	wd, err := os.Getwd()
   530  	if err != nil {
   531  		return "", err
   532  	}
   533  
   534  	dirPath := filepath.Join(wd, pkgIdent)
   535  	_, err = os.Stat(dirPath)
   536  	if err != nil && !os.IsNotExist(err) {
   537  		return "", err
   538  	}
   539  	if err != nil {
   540  		if len(pkgIdent) > 0 && pkgIdent[0] == '.' {
   541  			// pkgIdent is definitely a relative path, not a package name, and it does not exist
   542  			return "", err
   543  		}
   544  		// pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now.
   545  		return pkgIdent, nil
   546  	}
   547  
   548  	// local import
   549  	if importPath != "" {
   550  		return importPath, nil
   551  	}
   552  
   553  	modPath := filepath.Join(dirPath, "go.mod")
   554  	_, err = os.Stat(modPath)
   555  	if os.IsNotExist(err) {
   556  		return "", errors.New("no go.mod found, and no import path specified")
   557  	}
   558  	if err != nil {
   559  		return "", err
   560  	}
   561  	f, err := os.Open(modPath)
   562  	if err != nil {
   563  		return "", err
   564  	}
   565  	defer func() {
   566  		_ = f.Close()
   567  	}()
   568  	sc := bufio.NewScanner(f)
   569  	var l string
   570  	for sc.Scan() {
   571  		l = sc.Text()
   572  		break
   573  	}
   574  	if sc.Err() != nil {
   575  		return "", err
   576  	}
   577  	parts := strings.Fields(l)
   578  	if len(parts) < 2 {
   579  		return "", errors.New(`invalid first line syntax in go.mod`)
   580  	}
   581  	if parts[0] != "module" {
   582  		return "", errors.New(`invalid first line in go.mod, no "module" found`)
   583  	}
   584  
   585  	return parts[1], nil
   586  }
   587  
   588  // Extract writes to rw a Go package with all the symbols found at pkgIdent.
   589  // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In
   590  // the latter case, Extract returns the actual import path of the package found at
   591  // pkgIdent, otherwise it just returns pkgIdent.
   592  // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not
   593  // supported yet, and the behavior is only defined for GO111MODULE=off.
   594  func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) {
   595  	ipp, err := e.importPath(pkgIdent, importPath)
   596  	if err != nil {
   597  		return "", err
   598  	}
   599  
   600  	pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent)
   601  	if err != nil {
   602  		return "", err
   603  	}
   604  
   605  	content, err := e.genContent(ipp, pkg)
   606  	if err != nil {
   607  		return "", err
   608  	}
   609  
   610  	if _, err := rw.Write(content); err != nil {
   611  		return "", err
   612  	}
   613  
   614  	return ipp, nil
   615  }
   616  
   617  func (e *Extractor) ExtractStruct(pkgIdent, importPath string) (*PackageStruct, error) {
   618  	ipp, err := e.importPath(pkgIdent, importPath)
   619  	if err != nil {
   620  		return nil, err
   621  	}
   622  
   623  	pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent)
   624  	if err != nil {
   625  		return nil, err
   626  	}
   627  
   628  	return e.genStructure(ipp, pkg)
   629  }
   630  
   631  // GetMinor returns the minor part of the version number.
   632  func GetMinor(part string) string {
   633  	minor := part
   634  	index := strings.Index(minor, "beta")
   635  	if index < 0 {
   636  		index = strings.Index(minor, "rc")
   637  	}
   638  	if index > 0 {
   639  		minor = minor[:index]
   640  	}
   641  
   642  	return minor
   643  }
   644  
   645  const defaultMinorVersion = 17
   646  
   647  func genBuildTags() (string, error) {
   648  	version := runtime.Version()
   649  	if strings.HasPrefix(version, "devel") {
   650  		return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version)
   651  	}
   652  	parts := strings.Split(version, ".")
   653  
   654  	minorRaw := GetMinor(parts[1])
   655  
   656  	currentGoVersion := parts[0] + "." + minorRaw
   657  
   658  	minor, err := strconv.Atoi(minorRaw)
   659  	if err != nil {
   660  		return "", fmt.Errorf("failed to parse version: %w", err)
   661  	}
   662  
   663  	// Only append an upper bound if we are not on the latest go
   664  	if minor >= defaultMinorVersion {
   665  		return currentGoVersion, nil
   666  	}
   667  
   668  	nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1)
   669  
   670  	return currentGoVersion + ",!" + nextGoVersion, nil
   671  }
   672  
   673  func isInStdlib(path string) bool { return !strings.Contains(path, ".") }