github.com/traefik/yaegi@v0.15.1/extract/extract.go (about)

/*
Package extract generates wrappers of package exported symbols.

The generated Go source registers the exported constants, functions,
variables and types of a package as reflect.Value entries in a Symbols map
(expected to be declared by the destination package), and defines wrapper
structs through which the package's interface types can be implemented with
function fields.
*/
package extract
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"go/constant"
    12  	"go/format"
    13  	"go/importer"
    14  	"go/token"
    15  	"go/types"
    16  	"io"
    17  	"math/big"
    18  	"os"
    19  	"path"
    20  	"path/filepath"
    21  	"regexp"
    22  	"runtime"
    23  	"strconv"
    24  	"strings"
    25  	"text/template"
    26  )
    27  
// model is the text/template rendered by genContent to produce a symbol file.
//
// It reads the following fields from the template data:
//   - ImportPath: import path of the extracted package (also imported as-is);
//   - License: optional text placed at the top of the file;
//   - BuildTags: build constraint, omitted when empty;
//   - Dest: name of the generated package;
//   - Imports: import paths, only those mapped to true are emitted;
//   - PkgName: key of the package in the Symbols map;
//   - Val: values to register, by address when Addr is set;
//   - Typ: types to register as typed nil pointers;
//   - Wrap: interface wrappers, each a struct with one W<Method> function field
//     per method and methods forwarding to those fields. A String method
//     returns "" when its WString field is nil.
const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT.

{{.License}}

{{if .BuildTags}}// +build {{.BuildTags}}{{end}}

package {{.Dest}}

import (
{{- range $key, $value := .Imports }}
	{{- if $value}}
	"{{$key}}"
	{{- end}}
{{- end}}
	"{{.ImportPath}}"
	"reflect"
)

func init() {
	Symbols["{{.PkgName}}"] = map[string]reflect.Value{
		{{- if .Val}}
		// function, constant and variable definitions
		{{range $key, $value := .Val -}}
			{{- if $value.Addr -}}
				"{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(),
			{{else -}}
				"{{$key}}": reflect.ValueOf({{$value.Name}}),
			{{end -}}
		{{end}}

		{{- end}}
		{{- if .Typ}}
		// type definitions
		{{range $key, $value := .Typ -}}
			"{{$key}}": reflect.ValueOf((*{{$value}})(nil)),
		{{end}}

		{{- end}}
		{{- if .Wrap}}
		// interface wrapper definitions
		{{range $key, $value := .Wrap -}}
			"_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)),
		{{end}}
		{{- end}}
	}
}
{{range $key, $value := .Wrap -}}
	// {{$value.Name}} is an interface wrapper for {{$key}} type
	type {{$value.Name}} struct {
		IValue interface{}
		{{range $m := $value.Method -}}
		W{{$m.Name}} func{{$m.Param}} {{$m.Result}}
		{{end}}
	}
	{{range $m := $value.Method -}}
		func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} {
			{{- if eq $m.Name "String"}}
			if W.WString == nil {
				return ""
			}
			{{end -}}
			{{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}}
		}
	{{end}}
{{end}}
`
    94  
    95  // Val stores the value name and addressable status of symbols.
    96  type Val struct {
    97  	Name string // "package.name"
    98  	Addr bool   // true if symbol is a Var
    99  }
   100  
   101  // Method stores information for generating interface wrapper method.
   102  type Method struct {
   103  	Name, Param, Result, Arg, Ret string
   104  }
   105  
   106  // Wrap stores information for generating interface wrapper.
   107  type Wrap struct {
   108  	Name   string
   109  	Method []Method
   110  }
   111  
   112  // restricted map defines symbols for which a special implementation is provided.
   113  var restricted = map[string]bool{
   114  	"osExit":        true,
   115  	"osFindProcess": true,
   116  	"logFatal":      true,
   117  	"logFatalf":     true,
   118  	"logFatalln":    true,
   119  	"logLogger":     true,
   120  	"logNew":        true,
   121  }
   122  
   123  func matchList(name string, list []string) (match bool, err error) {
   124  	for _, re := range list {
   125  		match, err = regexp.MatchString(re, name)
   126  		if err != nil || match {
   127  			return
   128  		}
   129  	}
   130  	return
   131  }
   132  
   133  func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) {
   134  	prefix := "_" + importPath + "_"
   135  	prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix)
   136  
   137  	typ := map[string]string{}
   138  	val := map[string]Val{}
   139  	wrap := map[string]Wrap{}
   140  	imports := map[string]bool{}
   141  	sc := p.Scope()
   142  
   143  	for _, pkg := range p.Imports() {
   144  		imports[pkg.Path()] = false
   145  	}
   146  	qualify := func(pkg *types.Package) string {
   147  		if pkg.Path() != importPath {
   148  			imports[pkg.Path()] = true
   149  		}
   150  		return pkg.Name()
   151  	}
   152  
   153  	for _, name := range sc.Names() {
   154  		o := sc.Lookup(name)
   155  		if !o.Exported() {
   156  			continue
   157  		}
   158  
   159  		if len(e.Include) > 0 {
   160  			match, err := matchList(name, e.Include)
   161  			if err != nil {
   162  				return nil, err
   163  			}
   164  			if !match {
   165  				// Explicitly defined include expressions force non matching symbols to be skipped.
   166  				continue
   167  			}
   168  		}
   169  
   170  		match, err := matchList(name, e.Exclude)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  		if match {
   175  			continue
   176  		}
   177  
   178  		pname := p.Name() + "." + name
   179  		if rname := p.Name() + name; restricted[rname] {
   180  			// Restricted symbol, locally provided by stdlib wrapper.
   181  			pname = rname
   182  		}
   183  
   184  		switch o := o.(type) {
   185  		case *types.Const:
   186  			if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 {
   187  				// Convert untyped constant to right type to avoid overflow.
   188  				val[name] = Val{fixConst(pname, o.Val(), imports), false}
   189  			} else {
   190  				val[name] = Val{pname, false}
   191  			}
   192  		case *types.Func:
   193  			val[name] = Val{pname, false}
   194  		case *types.Var:
   195  			val[name] = Val{pname, true}
   196  		case *types.TypeName:
   197  			// Skip type if it is generic.
   198  			if t, ok := o.Type().(*types.Named); ok && t.TypeParams().Len() > 0 {
   199  				continue
   200  			}
   201  
   202  			typ[name] = pname
   203  			if t, ok := o.Type().Underlying().(*types.Interface); ok {
   204  				var methods []Method
   205  				for i := 0; i < t.NumMethods(); i++ {
   206  					f := t.Method(i)
   207  					if !f.Exported() {
   208  						continue
   209  					}
   210  
   211  					sign := f.Type().(*types.Signature)
   212  					args := make([]string, sign.Params().Len())
   213  					params := make([]string, len(args))
   214  					for j := range args {
   215  						v := sign.Params().At(j)
   216  						if args[j] = v.Name(); args[j] == "" {
   217  							args[j] = fmt.Sprintf("a%d", j)
   218  						}
   219  						// process interface method variadic parameter
   220  						if sign.Variadic() && j == len(args)-1 { // check is last arg
   221  							// only replace the first "[]" to "..."
   222  							at := types.TypeString(v.Type(), qualify)[2:]
   223  							params[j] = args[j] + " ..." + at
   224  							args[j] += "..."
   225  						} else {
   226  							params[j] = args[j] + " " + types.TypeString(v.Type(), qualify)
   227  						}
   228  					}
   229  					arg := "(" + strings.Join(args, ", ") + ")"
   230  					param := "(" + strings.Join(params, ", ") + ")"
   231  
   232  					results := make([]string, sign.Results().Len())
   233  					for j := range results {
   234  						v := sign.Results().At(j)
   235  						results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify)
   236  					}
   237  					result := "(" + strings.Join(results, ", ") + ")"
   238  
   239  					ret := ""
   240  					if sign.Results().Len() > 0 {
   241  						ret = "return"
   242  					}
   243  
   244  					methods = append(methods, Method{f.Name(), param, result, arg, ret})
   245  				}
   246  				wrap[name] = Wrap{prefix + name, methods}
   247  			}
   248  		}
   249  	}
   250  
   251  	// Generate buildTags with Go version only for stdlib packages.
   252  	// Third party packages do not depend on Go compiler version by default.
   253  	var buildTags string
   254  	if isInStdlib(importPath) {
   255  		var err error
   256  		buildTags, err = genBuildTags()
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  	}
   261  
   262  	base := template.New("extract")
   263  	parse, err := base.Parse(model)
   264  	if err != nil {
   265  		return nil, fmt.Errorf("template parsing error: %w", err)
   266  	}
   267  
   268  	if importPath == "log/syslog" {
   269  		buildTags += ",!windows,!nacl,!plan9"
   270  	}
   271  
   272  	if importPath == "syscall" {
   273  		// As per https://golang.org/cmd/go/#hdr-Build_constraints,
   274  		// using GOOS=android also matches tags and files for GOOS=linux,
   275  		// so exclude it explicitly to avoid collisions (issue #843).
   276  		// Also using GOOS=illumos matches tags and files for GOOS=solaris.
   277  		switch os.Getenv("GOOS") {
   278  		case "android":
   279  			buildTags += ",!linux"
   280  		case "illumos":
   281  			buildTags += ",!solaris"
   282  		}
   283  	}
   284  
   285  	for _, t := range e.Tag {
   286  		if len(t) != 0 {
   287  			buildTags += "," + t
   288  		}
   289  	}
   290  	if len(buildTags) != 0 && buildTags[0] == ',' {
   291  		buildTags = buildTags[1:]
   292  	}
   293  
   294  	b := new(bytes.Buffer)
   295  	data := map[string]interface{}{
   296  		"Dest":       e.Dest,
   297  		"Imports":    imports,
   298  		"ImportPath": importPath,
   299  		"PkgName":    path.Join(importPath, p.Name()),
   300  		"Val":        val,
   301  		"Typ":        typ,
   302  		"Wrap":       wrap,
   303  		"BuildTags":  buildTags,
   304  		"License":    e.License,
   305  	}
   306  	err = parse.Execute(b, data)
   307  	if err != nil {
   308  		return nil, fmt.Errorf("template error: %w", err)
   309  	}
   310  
   311  	// gofmt
   312  	source, err := format.Source(b.Bytes())
   313  	if err != nil {
   314  		return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes())
   315  	}
   316  	return source, nil
   317  }
   318  
   319  // fixConst checks untyped constant value, converting it if necessary to avoid overflow.
   320  func fixConst(name string, val constant.Value, imports map[string]bool) string {
   321  	var (
   322  		tok string
   323  		str string
   324  	)
   325  	switch val.Kind() {
   326  	case constant.String:
   327  		tok = "STRING"
   328  		str = val.ExactString()
   329  	case constant.Int:
   330  		tok = "INT"
   331  		str = val.ExactString()
   332  	case constant.Float:
   333  		v := constant.Val(val) // v is *big.Rat or *big.Float
   334  		f, ok := v.(*big.Float)
   335  		if !ok {
   336  			f = new(big.Float).SetRat(v.(*big.Rat))
   337  		}
   338  
   339  		tok = "FLOAT"
   340  		str = f.Text('g', int(f.Prec()))
   341  	case constant.Complex:
   342  		// TODO: not sure how to parse this case
   343  		fallthrough
   344  	default:
   345  		return name
   346  	}
   347  
   348  	imports["go/constant"] = true
   349  	imports["go/token"] = true
   350  
   351  	return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok)
   352  }
   353  
   354  // Extractor creates a package with all the symbols from a dependency package.
   355  type Extractor struct {
   356  	Dest    string   // The name of the created package.
   357  	License string   // License text to be included in the created package, optional.
   358  	Exclude []string // Comma separated list of regexp matching symbols to exclude.
   359  	Include []string // Comma separated list of regexp matching symbols to include.
   360  	Tag     []string // Comma separated of build tags to be added to the created package.
   361  }
   362  
   363  // importPath checks whether pkgIdent is an existing directory relative to
   364  // e.WorkingDir. If yes, it returns the actual import path of the Go package
   365  // located in the directory. If it is definitely a relative path, but it does not
   366  // exist, an error is returned. Otherwise, it is assumed to be an import path, and
   367  // pkgIdent is returned.
   368  func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) {
   369  	wd, err := os.Getwd()
   370  	if err != nil {
   371  		return "", err
   372  	}
   373  
   374  	dirPath := filepath.Join(wd, pkgIdent)
   375  	_, err = os.Stat(dirPath)
   376  	if err != nil && !os.IsNotExist(err) {
   377  		return "", err
   378  	}
   379  	if err != nil {
   380  		if len(pkgIdent) > 0 && pkgIdent[0] == '.' {
   381  			// pkgIdent is definitely a relative path, not a package name, and it does not exist
   382  			return "", err
   383  		}
   384  		// pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now.
   385  		return pkgIdent, nil
   386  	}
   387  
   388  	// local import
   389  	if importPath != "" {
   390  		return importPath, nil
   391  	}
   392  
   393  	modPath := filepath.Join(dirPath, "go.mod")
   394  	_, err = os.Stat(modPath)
   395  	if os.IsNotExist(err) {
   396  		return "", errors.New("no go.mod found, and no import path specified")
   397  	}
   398  	if err != nil {
   399  		return "", err
   400  	}
   401  	f, err := os.Open(modPath)
   402  	if err != nil {
   403  		return "", err
   404  	}
   405  	defer func() {
   406  		_ = f.Close()
   407  	}()
   408  	sc := bufio.NewScanner(f)
   409  	var l string
   410  	for sc.Scan() {
   411  		l = sc.Text()
   412  		break
   413  	}
   414  	if sc.Err() != nil {
   415  		return "", err
   416  	}
   417  	parts := strings.Fields(l)
   418  	if len(parts) < 2 {
   419  		return "", errors.New(`invalid first line syntax in go.mod`)
   420  	}
   421  	if parts[0] != "module" {
   422  		return "", errors.New(`invalid first line in go.mod, no "module" found`)
   423  	}
   424  
   425  	return parts[1], nil
   426  }
   427  
   428  // Extract writes to rw a Go package with all the symbols found at pkgIdent.
   429  // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In
   430  // the latter case, Extract returns the actual import path of the package found at
   431  // pkgIdent, otherwise it just returns pkgIdent.
   432  // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not
   433  // supported yet, and the behavior is only defined for GO111MODULE=off.
   434  func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) {
   435  	ipp, err := e.importPath(pkgIdent, importPath)
   436  	if err != nil {
   437  		return "", err
   438  	}
   439  
   440  	pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent)
   441  	if err != nil {
   442  		return "", err
   443  	}
   444  
   445  	content, err := e.genContent(ipp, pkg)
   446  	if err != nil {
   447  		return "", err
   448  	}
   449  
   450  	if _, err := rw.Write(content); err != nil {
   451  		return "", err
   452  	}
   453  
   454  	return ipp, nil
   455  }
   456  
   457  // GetMinor returns the minor part of the version number.
   458  func GetMinor(part string) string {
   459  	minor := part
   460  	index := strings.Index(minor, "beta")
   461  	if index < 0 {
   462  		index = strings.Index(minor, "rc")
   463  	}
   464  	if index > 0 {
   465  		minor = minor[:index]
   466  	}
   467  
   468  	return minor
   469  }
   470  
   471  const defaultMinorVersion = 20
   472  
   473  func genBuildTags() (string, error) {
   474  	version := runtime.Version()
   475  	if strings.HasPrefix(version, "devel") {
   476  		return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version)
   477  	}
   478  	parts := strings.Split(version, ".")
   479  
   480  	minorRaw := GetMinor(parts[1])
   481  
   482  	currentGoVersion := parts[0] + "." + minorRaw
   483  
   484  	minor, err := strconv.Atoi(minorRaw)
   485  	if err != nil {
   486  		return "", fmt.Errorf("failed to parse version: %w", err)
   487  	}
   488  
   489  	// Only append an upper bound if we are not on the latest go
   490  	if minor >= defaultMinorVersion {
   491  		return currentGoVersion, nil
   492  	}
   493  
   494  	nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1)
   495  
   496  	return currentGoVersion + ",!" + nextGoVersion, nil
   497  }
   498  
   499  func isInStdlib(path string) bool { return !strings.Contains(path, ".") }