// Source: github.com/traefik/yaegi@v0.15.1/extract/extract.go

/*
Package extract generates wrappers of package exported symbols.
*/
package extract

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"go/constant"
	"go/format"
	"go/importer"
	"go/token"
	"go/types"
	"io"
	"math/big"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"text/template"
)

// model is the text/template source of the generated wrapper file.
// genContent executes it with a map holding the keys Dest, Imports,
// ImportPath, PkgName, Val, Typ, Wrap, BuildTags and License, then passes the
// result through go/format, so the indentation used inside the template is
// not significant.
const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT.

{{.License}}

{{if .BuildTags}}// +build {{.BuildTags}}{{end}}

package {{.Dest}}

import (
{{- range $key, $value := .Imports }}
	{{- if $value}}
	"{{$key}}"
	{{- end}}
{{- end}}
	"{{.ImportPath}}"
	"reflect"
)

func init() {
	Symbols["{{.PkgName}}"] = map[string]reflect.Value{
		{{- if .Val}}
		// function, constant and variable definitions
		{{range $key, $value := .Val -}}
			{{- if $value.Addr -}}
				"{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(),
			{{else -}}
				"{{$key}}": reflect.ValueOf({{$value.Name}}),
			{{end -}}
		{{end}}

		{{- end}}
		{{- if .Typ}}
		// type definitions
		{{range $key, $value := .Typ -}}
			"{{$key}}": reflect.ValueOf((*{{$value}})(nil)),
		{{end}}

		{{- end}}
		{{- if .Wrap}}
		// interface wrapper definitions
		{{range $key, $value := .Wrap -}}
			"_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)),
		{{end}}
		{{- end}}
	}
}
{{range $key, $value := .Wrap -}}
	// {{$value.Name}} is an interface wrapper for {{$key}} type
	type {{$value.Name}} struct {
		IValue interface{}
		{{range $m := $value.Method -}}
		W{{$m.Name}} func{{$m.Param}} {{$m.Result}}
		{{end}}
	}
	{{range $m := $value.Method -}}
		func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} {
			{{- if eq $m.Name "String"}}
			if W.WString == nil {
				return ""
			}
			{{end -}}
			{{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}}
		}
	{{end}}
{{end}}
`

// Val stores the
value name and addressable status of symbols. 96 type Val struct { 97 Name string // "package.name" 98 Addr bool // true if symbol is a Var 99 } 100 101 // Method stores information for generating interface wrapper method. 102 type Method struct { 103 Name, Param, Result, Arg, Ret string 104 } 105 106 // Wrap stores information for generating interface wrapper. 107 type Wrap struct { 108 Name string 109 Method []Method 110 } 111 112 // restricted map defines symbols for which a special implementation is provided. 113 var restricted = map[string]bool{ 114 "osExit": true, 115 "osFindProcess": true, 116 "logFatal": true, 117 "logFatalf": true, 118 "logFatalln": true, 119 "logLogger": true, 120 "logNew": true, 121 } 122 123 func matchList(name string, list []string) (match bool, err error) { 124 for _, re := range list { 125 match, err = regexp.MatchString(re, name) 126 if err != nil || match { 127 return 128 } 129 } 130 return 131 } 132 133 func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) { 134 prefix := "_" + importPath + "_" 135 prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix) 136 137 typ := map[string]string{} 138 val := map[string]Val{} 139 wrap := map[string]Wrap{} 140 imports := map[string]bool{} 141 sc := p.Scope() 142 143 for _, pkg := range p.Imports() { 144 imports[pkg.Path()] = false 145 } 146 qualify := func(pkg *types.Package) string { 147 if pkg.Path() != importPath { 148 imports[pkg.Path()] = true 149 } 150 return pkg.Name() 151 } 152 153 for _, name := range sc.Names() { 154 o := sc.Lookup(name) 155 if !o.Exported() { 156 continue 157 } 158 159 if len(e.Include) > 0 { 160 match, err := matchList(name, e.Include) 161 if err != nil { 162 return nil, err 163 } 164 if !match { 165 // Explicitly defined include expressions force non matching symbols to be skipped. 
166 continue 167 } 168 } 169 170 match, err := matchList(name, e.Exclude) 171 if err != nil { 172 return nil, err 173 } 174 if match { 175 continue 176 } 177 178 pname := p.Name() + "." + name 179 if rname := p.Name() + name; restricted[rname] { 180 // Restricted symbol, locally provided by stdlib wrapper. 181 pname = rname 182 } 183 184 switch o := o.(type) { 185 case *types.Const: 186 if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 { 187 // Convert untyped constant to right type to avoid overflow. 188 val[name] = Val{fixConst(pname, o.Val(), imports), false} 189 } else { 190 val[name] = Val{pname, false} 191 } 192 case *types.Func: 193 val[name] = Val{pname, false} 194 case *types.Var: 195 val[name] = Val{pname, true} 196 case *types.TypeName: 197 // Skip type if it is generic. 198 if t, ok := o.Type().(*types.Named); ok && t.TypeParams().Len() > 0 { 199 continue 200 } 201 202 typ[name] = pname 203 if t, ok := o.Type().Underlying().(*types.Interface); ok { 204 var methods []Method 205 for i := 0; i < t.NumMethods(); i++ { 206 f := t.Method(i) 207 if !f.Exported() { 208 continue 209 } 210 211 sign := f.Type().(*types.Signature) 212 args := make([]string, sign.Params().Len()) 213 params := make([]string, len(args)) 214 for j := range args { 215 v := sign.Params().At(j) 216 if args[j] = v.Name(); args[j] == "" { 217 args[j] = fmt.Sprintf("a%d", j) 218 } 219 // process interface method variadic parameter 220 if sign.Variadic() && j == len(args)-1 { // check is last arg 221 // only replace the first "[]" to "..." 222 at := types.TypeString(v.Type(), qualify)[2:] 223 params[j] = args[j] + " ..." + at 224 args[j] += "..." 
225 } else { 226 params[j] = args[j] + " " + types.TypeString(v.Type(), qualify) 227 } 228 } 229 arg := "(" + strings.Join(args, ", ") + ")" 230 param := "(" + strings.Join(params, ", ") + ")" 231 232 results := make([]string, sign.Results().Len()) 233 for j := range results { 234 v := sign.Results().At(j) 235 results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify) 236 } 237 result := "(" + strings.Join(results, ", ") + ")" 238 239 ret := "" 240 if sign.Results().Len() > 0 { 241 ret = "return" 242 } 243 244 methods = append(methods, Method{f.Name(), param, result, arg, ret}) 245 } 246 wrap[name] = Wrap{prefix + name, methods} 247 } 248 } 249 } 250 251 // Generate buildTags with Go version only for stdlib packages. 252 // Third party packages do not depend on Go compiler version by default. 253 var buildTags string 254 if isInStdlib(importPath) { 255 var err error 256 buildTags, err = genBuildTags() 257 if err != nil { 258 return nil, err 259 } 260 } 261 262 base := template.New("extract") 263 parse, err := base.Parse(model) 264 if err != nil { 265 return nil, fmt.Errorf("template parsing error: %w", err) 266 } 267 268 if importPath == "log/syslog" { 269 buildTags += ",!windows,!nacl,!plan9" 270 } 271 272 if importPath == "syscall" { 273 // As per https://golang.org/cmd/go/#hdr-Build_constraints, 274 // using GOOS=android also matches tags and files for GOOS=linux, 275 // so exclude it explicitly to avoid collisions (issue #843). 276 // Also using GOOS=illumos matches tags and files for GOOS=solaris. 
277 switch os.Getenv("GOOS") { 278 case "android": 279 buildTags += ",!linux" 280 case "illumos": 281 buildTags += ",!solaris" 282 } 283 } 284 285 for _, t := range e.Tag { 286 if len(t) != 0 { 287 buildTags += "," + t 288 } 289 } 290 if len(buildTags) != 0 && buildTags[0] == ',' { 291 buildTags = buildTags[1:] 292 } 293 294 b := new(bytes.Buffer) 295 data := map[string]interface{}{ 296 "Dest": e.Dest, 297 "Imports": imports, 298 "ImportPath": importPath, 299 "PkgName": path.Join(importPath, p.Name()), 300 "Val": val, 301 "Typ": typ, 302 "Wrap": wrap, 303 "BuildTags": buildTags, 304 "License": e.License, 305 } 306 err = parse.Execute(b, data) 307 if err != nil { 308 return nil, fmt.Errorf("template error: %w", err) 309 } 310 311 // gofmt 312 source, err := format.Source(b.Bytes()) 313 if err != nil { 314 return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes()) 315 } 316 return source, nil 317 } 318 319 // fixConst checks untyped constant value, converting it if necessary to avoid overflow. 320 func fixConst(name string, val constant.Value, imports map[string]bool) string { 321 var ( 322 tok string 323 str string 324 ) 325 switch val.Kind() { 326 case constant.String: 327 tok = "STRING" 328 str = val.ExactString() 329 case constant.Int: 330 tok = "INT" 331 str = val.ExactString() 332 case constant.Float: 333 v := constant.Val(val) // v is *big.Rat or *big.Float 334 f, ok := v.(*big.Float) 335 if !ok { 336 f = new(big.Float).SetRat(v.(*big.Rat)) 337 } 338 339 tok = "FLOAT" 340 str = f.Text('g', int(f.Prec())) 341 case constant.Complex: 342 // TODO: not sure how to parse this case 343 fallthrough 344 default: 345 return name 346 } 347 348 imports["go/constant"] = true 349 imports["go/token"] = true 350 351 return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok) 352 } 353 354 // Extractor creates a package with all the symbols from a dependency package. 355 type Extractor struct { 356 Dest string // The name of the created package. 
357 License string // License text to be included in the created package, optional. 358 Exclude []string // Comma separated list of regexp matching symbols to exclude. 359 Include []string // Comma separated list of regexp matching symbols to include. 360 Tag []string // Comma separated of build tags to be added to the created package. 361 } 362 363 // importPath checks whether pkgIdent is an existing directory relative to 364 // e.WorkingDir. If yes, it returns the actual import path of the Go package 365 // located in the directory. If it is definitely a relative path, but it does not 366 // exist, an error is returned. Otherwise, it is assumed to be an import path, and 367 // pkgIdent is returned. 368 func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) { 369 wd, err := os.Getwd() 370 if err != nil { 371 return "", err 372 } 373 374 dirPath := filepath.Join(wd, pkgIdent) 375 _, err = os.Stat(dirPath) 376 if err != nil && !os.IsNotExist(err) { 377 return "", err 378 } 379 if err != nil { 380 if len(pkgIdent) > 0 && pkgIdent[0] == '.' { 381 // pkgIdent is definitely a relative path, not a package name, and it does not exist 382 return "", err 383 } 384 // pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now. 
385 return pkgIdent, nil 386 } 387 388 // local import 389 if importPath != "" { 390 return importPath, nil 391 } 392 393 modPath := filepath.Join(dirPath, "go.mod") 394 _, err = os.Stat(modPath) 395 if os.IsNotExist(err) { 396 return "", errors.New("no go.mod found, and no import path specified") 397 } 398 if err != nil { 399 return "", err 400 } 401 f, err := os.Open(modPath) 402 if err != nil { 403 return "", err 404 } 405 defer func() { 406 _ = f.Close() 407 }() 408 sc := bufio.NewScanner(f) 409 var l string 410 for sc.Scan() { 411 l = sc.Text() 412 break 413 } 414 if sc.Err() != nil { 415 return "", err 416 } 417 parts := strings.Fields(l) 418 if len(parts) < 2 { 419 return "", errors.New(`invalid first line syntax in go.mod`) 420 } 421 if parts[0] != "module" { 422 return "", errors.New(`invalid first line in go.mod, no "module" found`) 423 } 424 425 return parts[1], nil 426 } 427 428 // Extract writes to rw a Go package with all the symbols found at pkgIdent. 429 // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In 430 // the latter case, Extract returns the actual import path of the package found at 431 // pkgIdent, otherwise it just returns pkgIdent. 432 // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not 433 // supported yet, and the behavior is only defined for GO111MODULE=off. 434 func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) { 435 ipp, err := e.importPath(pkgIdent, importPath) 436 if err != nil { 437 return "", err 438 } 439 440 pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent) 441 if err != nil { 442 return "", err 443 } 444 445 content, err := e.genContent(ipp, pkg) 446 if err != nil { 447 return "", err 448 } 449 450 if _, err := rw.Write(content); err != nil { 451 return "", err 452 } 453 454 return ipp, nil 455 } 456 457 // GetMinor returns the minor part of the version number. 
// Any pre-release suffix starting with "beta" or "rc" is dropped, so that
// "21rc1" yields "21". The input is returned unchanged when it has no such
// suffix, or when it starts with it.
func GetMinor(part string) string {
	minor := part
	index := strings.Index(minor, "beta")
	if index < 0 {
		index = strings.Index(minor, "rc")
	}
	if index > 0 {
		minor = minor[:index]
	}

	return minor
}

// defaultMinorVersion is the minor Go version from which genBuildTags stops
// adding an upper bound to the generated build tags.
const defaultMinorVersion = 20

// genBuildTags returns the build tags restricting a generated file to the Go
// release running the extraction, as reported by runtime.Version.
func genBuildTags() (string, error) {
	return buildTagsForVersion(runtime.Version())
}

// buildTagsForVersion returns the build tags for a Go version string such as
// "go1.19.3": the "major.minor" tag, followed by the negated tag of the next
// minor release when minor is below defaultMinorVersion. It returns an error
// for development versions and for strings it cannot parse.
func buildTagsForVersion(version string) (string, error) {
	if strings.HasPrefix(version, "devel") {
		return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version)
	}
	parts := strings.Split(version, ".")
	if len(parts) < 2 {
		// Without a "." there is no minor part to read: fail instead of
		// panicking on parts[1].
		return "", fmt.Errorf("failed to parse version: no minor part in %q", version)
	}

	minorRaw := GetMinor(parts[1])

	currentGoVersion := parts[0] + "." + minorRaw

	minor, err := strconv.Atoi(minorRaw)
	if err != nil {
		return "", fmt.Errorf("failed to parse version: %w", err)
	}

	// Only append an upper bound if we are not on the latest go
	if minor >= defaultMinorVersion {
		return currentGoVersion, nil
	}

	nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1)

	return currentGoVersion + ",!" + nextGoVersion, nil
}

// isInStdlib reports whether path looks like a standard library import path,
// i.e. it contains no dot.
func isInStdlib(path string) bool { return !strings.Contains(path, ".") }