github.com/HaHadaxigua/yaegi@v1.0.1/extract/extract.go (about) 1 /* 2 Package extract generates wrappers of package exported symbols. 3 */ 4 package extract 5 6 import ( 7 "bufio" 8 "bytes" 9 "errors" 10 "fmt" 11 "go/constant" 12 "go/format" 13 "go/importer" 14 "go/token" 15 "go/types" 16 "io" 17 "math/big" 18 "os" 19 "path" 20 "path/filepath" 21 "regexp" 22 "runtime" 23 "strconv" 24 "strings" 25 "text/template" 26 ) 27 28 const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT. 29 30 {{.License}} 31 32 {{if .BuildTags}}// +build {{.BuildTags}}{{end}} 33 34 package {{.Dest}} 35 36 import ( 37 {{- range $key, $value := .Imports }} 38 {{- if $value}} 39 "{{$key}}" 40 {{- end}} 41 {{- end}} 42 "{{.ImportPath}}" 43 "reflect" 44 ) 45 46 func init() { 47 Symbols["{{.PkgName}}"] = map[string]reflect.Value{ 48 {{- if .Val}} 49 // function, constant and variable definitions 50 {{range $key, $value := .Val -}} 51 {{- if $value.Addr -}} 52 "{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(), 53 {{else -}} 54 "{{$key}}": reflect.ValueOf({{$value.Name}}), 55 {{end -}} 56 {{end}} 57 58 {{- end}} 59 {{- if .Typ}} 60 // type definitions 61 {{range $key, $value := .Typ -}} 62 "{{$key}}": reflect.ValueOf((*{{$value}})(nil)), 63 {{end}} 64 65 {{- end}} 66 {{- if .Wrap}} 67 // interface wrapper definitions 68 {{range $key, $value := .Wrap -}} 69 "_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)), 70 {{end}} 71 {{- end}} 72 } 73 } 74 {{range $key, $value := .Wrap -}} 75 // {{$value.Name}} is an interface wrapper for {{$key}} type 76 type {{$value.Name}} struct { 77 IValue interface{} 78 {{range $m := $value.Method -}} 79 W{{$m.Name}} func{{$m.Param}} {{$m.Result}} 80 {{end}} 81 } 82 {{range $m := $value.Method -}} 83 func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} { 84 {{- if eq $m.Name "String"}} 85 if W.WString == nil { 86 return "" 87 } 88 {{end -}} 89 {{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}} 90 } 91 {{end}} 92 {{end}} 93 ` 94 95 // Val stores 
the value name and addressable status of symbols. 96 type Val struct { 97 Name string // "package.name" 98 Addr bool // true if symbol is a Var 99 } 100 101 // Method stores information for generating interface wrapper method. 102 type Method struct { 103 Name, Param, Result, Arg, Ret string 104 } 105 106 // Wrap stores information for generating interface wrapper. 107 type Wrap struct { 108 Name string 109 Method []Method 110 } 111 112 // restricted map defines symbols for which a special implementation is provided. 113 var restricted = map[string]bool{ 114 "osExit": true, 115 "osFindProcess": true, 116 "logFatal": true, 117 "logFatalf": true, 118 "logFatalln": true, 119 "logLogger": true, 120 "logNew": true, 121 } 122 123 func matchList(name string, list []string) (match bool, err error) { 124 for _, re := range list { 125 match, err = regexp.MatchString(re, name) 126 if err != nil || match { 127 return 128 } 129 } 130 return 131 } 132 133 type PackageStruct struct { 134 Typ map[string]string 135 Val map[string]Val 136 Wrap map[string]Wrap 137 Imports map[string]bool 138 } 139 140 func (e *Extractor) genStructure(importPath string, p *types.Package) (*PackageStruct, error) { 141 prefix := "_" + importPath + "_" 142 prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix) 143 144 typ := map[string]string{} 145 val := map[string]Val{} 146 wrap := map[string]Wrap{} 147 imports := map[string]bool{} 148 sc := p.Scope() 149 150 for _, pkg := range p.Imports() { 151 imports[pkg.Path()] = false 152 } 153 qualify := func(pkg *types.Package) string { 154 if pkg.Path() != importPath { 155 imports[pkg.Path()] = true 156 } 157 return pkg.Name() 158 } 159 160 for _, name := range sc.Names() { 161 o := sc.Lookup(name) 162 if !o.Exported() { 163 continue 164 } 165 166 if len(e.Include) > 0 { 167 match, err := matchList(name, e.Include) 168 if err != nil { 169 return nil, err 170 } 171 if !match { 172 // Explicitly defined include expressions force non matching 
symbols to be skipped. 173 continue 174 } 175 } 176 177 match, err := matchList(name, e.Exclude) 178 if err != nil { 179 return nil, err 180 } 181 if match { 182 continue 183 } 184 185 pname := p.Name() + "." + name 186 if rname := p.Name() + name; restricted[rname] { 187 // Restricted symbol, locally provided by stdlib wrapper. 188 pname = rname 189 } 190 191 switch o := o.(type) { 192 case *types.Const: 193 if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 { 194 // Convert untyped constant to right type to avoid overflow. 195 val[name] = Val{fixConst(pname, o.Val(), imports), false} 196 } else { 197 val[name] = Val{pname, false} 198 } 199 case *types.Func: 200 val[name] = Val{pname, false} 201 case *types.Var: 202 val[name] = Val{pname, true} 203 case *types.TypeName: 204 typ[name] = pname 205 if t, ok := o.Type().Underlying().(*types.Interface); ok { 206 var methods []Method 207 for i := 0; i < t.NumMethods(); i++ { 208 f := t.Method(i) 209 if !f.Exported() { 210 continue 211 } 212 213 sign := f.Type().(*types.Signature) 214 args := make([]string, sign.Params().Len()) 215 params := make([]string, len(args)) 216 for j := range args { 217 v := sign.Params().At(j) 218 if args[j] = v.Name(); args[j] == "" { 219 args[j] = fmt.Sprintf("a%d", j) 220 } 221 // process interface method variadic parameter 222 if sign.Variadic() && j == len(args)-1 { // check is last arg 223 // only replace the first "[]" to "..." 224 at := types.TypeString(v.Type(), qualify)[2:] 225 params[j] = args[j] + " ..." + at 226 args[j] += "..." 
227 } else { 228 params[j] = args[j] + " " + types.TypeString(v.Type(), qualify) 229 } 230 } 231 arg := "(" + strings.Join(args, ", ") + ")" 232 param := "(" + strings.Join(params, ", ") + ")" 233 234 results := make([]string, sign.Results().Len()) 235 for j := range results { 236 v := sign.Results().At(j) 237 results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify) 238 } 239 result := "(" + strings.Join(results, ", ") + ")" 240 241 ret := "" 242 if sign.Results().Len() > 0 { 243 ret = "return" 244 } 245 246 methods = append(methods, Method{f.Name(), param, result, arg, ret}) 247 } 248 wrap[name] = Wrap{prefix + name, methods} 249 } 250 } 251 } 252 253 // Generate buildTags with Go version only for stdlib packages. 254 // Third party packages do not depend on Go compiler version by default. 255 var buildTags string 256 if isInStdlib(importPath) { 257 var err error 258 buildTags, err = genBuildTags() 259 if err != nil { 260 return nil, err 261 } 262 } 263 264 if importPath == "log/syslog" { 265 buildTags += ",!windows,!nacl,!plan9" 266 } 267 268 if importPath == "syscall" { 269 // As per https://golang.org/cmd/go/#hdr-Build_constraints, 270 // using GOOS=android also matches tags and files for GOOS=linux, 271 // so exclude it explicitly to avoid collisions (issue #843). 272 // Also using GOOS=illumos matches tags and files for GOOS=solaris. 
273 switch os.Getenv("GOOS") { 274 case "android": 275 buildTags += ",!linux" 276 case "illumos": 277 buildTags += ",!solaris" 278 } 279 } 280 281 for _, t := range e.Tag { 282 if len(t) != 0 { 283 buildTags += "," + t 284 } 285 } 286 if len(buildTags) != 0 && buildTags[0] == ',' { 287 buildTags = buildTags[1:] 288 } 289 290 return &PackageStruct{ 291 Typ: typ, 292 Val: val, 293 Wrap: wrap, 294 Imports: imports, 295 }, nil 296 } 297 298 func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) { 299 prefix := "_" + importPath + "_" 300 prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix) 301 302 typ := map[string]string{} 303 val := map[string]Val{} 304 wrap := map[string]Wrap{} 305 imports := map[string]bool{} 306 sc := p.Scope() 307 308 for _, pkg := range p.Imports() { 309 imports[pkg.Path()] = false 310 } 311 qualify := func(pkg *types.Package) string { 312 if pkg.Path() != importPath { 313 imports[pkg.Path()] = true 314 } 315 return pkg.Name() 316 } 317 318 for _, name := range sc.Names() { 319 o := sc.Lookup(name) 320 if !o.Exported() { 321 continue 322 } 323 324 if len(e.Include) > 0 { 325 match, err := matchList(name, e.Include) 326 if err != nil { 327 return nil, err 328 } 329 if !match { 330 // Explicitly defined include expressions force non matching symbols to be skipped. 331 continue 332 } 333 } 334 335 match, err := matchList(name, e.Exclude) 336 if err != nil { 337 return nil, err 338 } 339 if match { 340 continue 341 } 342 343 pname := p.Name() + "." + name 344 if rname := p.Name() + name; restricted[rname] { 345 // Restricted symbol, locally provided by stdlib wrapper. 346 pname = rname 347 } 348 349 switch o := o.(type) { 350 case *types.Const: 351 if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 { 352 // Convert untyped constant to right type to avoid overflow. 
353 val[name] = Val{fixConst(pname, o.Val(), imports), false} 354 } else { 355 val[name] = Val{pname, false} 356 } 357 case *types.Func: 358 val[name] = Val{pname, false} 359 case *types.Var: 360 val[name] = Val{pname, true} 361 case *types.TypeName: 362 typ[name] = pname 363 if t, ok := o.Type().Underlying().(*types.Interface); ok { 364 var methods []Method 365 for i := 0; i < t.NumMethods(); i++ { 366 f := t.Method(i) 367 if !f.Exported() { 368 continue 369 } 370 371 sign := f.Type().(*types.Signature) 372 args := make([]string, sign.Params().Len()) 373 params := make([]string, len(args)) 374 for j := range args { 375 v := sign.Params().At(j) 376 if args[j] = v.Name(); args[j] == "" { 377 args[j] = fmt.Sprintf("a%d", j) 378 } 379 // process interface method variadic parameter 380 if sign.Variadic() && j == len(args)-1 { // check is last arg 381 // only replace the first "[]" to "..." 382 at := types.TypeString(v.Type(), qualify)[2:] 383 params[j] = args[j] + " ..." + at 384 args[j] += "..." 385 } else { 386 params[j] = args[j] + " " + types.TypeString(v.Type(), qualify) 387 } 388 } 389 arg := "(" + strings.Join(args, ", ") + ")" 390 param := "(" + strings.Join(params, ", ") + ")" 391 392 results := make([]string, sign.Results().Len()) 393 for j := range results { 394 v := sign.Results().At(j) 395 results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify) 396 } 397 result := "(" + strings.Join(results, ", ") + ")" 398 399 ret := "" 400 if sign.Results().Len() > 0 { 401 ret = "return" 402 } 403 404 methods = append(methods, Method{f.Name(), param, result, arg, ret}) 405 } 406 wrap[name] = Wrap{prefix + name, methods} 407 } 408 } 409 } 410 411 // Generate buildTags with Go version only for stdlib packages. 412 // Third party packages do not depend on Go compiler version by default. 
413 var buildTags string 414 if isInStdlib(importPath) { 415 var err error 416 buildTags, err = genBuildTags() 417 if err != nil { 418 return nil, err 419 } 420 } 421 422 base := template.New("extract") 423 parse, err := base.Parse(model) 424 if err != nil { 425 return nil, fmt.Errorf("template parsing error: %w", err) 426 } 427 428 if importPath == "log/syslog" { 429 buildTags += ",!windows,!nacl,!plan9" 430 } 431 432 if importPath == "syscall" { 433 // As per https://golang.org/cmd/go/#hdr-Build_constraints, 434 // using GOOS=android also matches tags and files for GOOS=linux, 435 // so exclude it explicitly to avoid collisions (issue #843). 436 // Also using GOOS=illumos matches tags and files for GOOS=solaris. 437 switch os.Getenv("GOOS") { 438 case "android": 439 buildTags += ",!linux" 440 case "illumos": 441 buildTags += ",!solaris" 442 } 443 } 444 445 for _, t := range e.Tag { 446 if len(t) != 0 { 447 buildTags += "," + t 448 } 449 } 450 if len(buildTags) != 0 && buildTags[0] == ',' { 451 buildTags = buildTags[1:] 452 } 453 454 b := new(bytes.Buffer) 455 data := map[string]interface{}{ 456 "Dest": e.Dest, 457 "Imports": imports, 458 "ImportPath": importPath, 459 "PkgName": path.Join(importPath, p.Name()), 460 "Val": val, 461 "Typ": typ, 462 "Wrap": wrap, 463 "BuildTags": buildTags, 464 "License": e.License, 465 } 466 err = parse.Execute(b, data) 467 if err != nil { 468 return nil, fmt.Errorf("template error: %w", err) 469 } 470 471 // gofmt 472 source, err := format.Source(b.Bytes()) 473 if err != nil { 474 return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes()) 475 } 476 return source, nil 477 } 478 479 // fixConst checks untyped constant value, converting it if necessary to avoid overflow. 
func fixConst(name string, val constant.Value, imports map[string]bool) string {
	// The value is re-created in the generated code from its literal text
	// with constant.MakeFromLiteral; tok is the name of the token kind of
	// that literal and str its text.
	var tok, str string

	switch val.Kind() {
	case constant.String:
		tok, str = "STRING", val.ExactString()
	case constant.Int:
		tok, str = "INT", val.ExactString()
	case constant.Float:
		// constant.Val yields a *big.Float or a *big.Rat for a float constant.
		var f *big.Float
		switch v := constant.Val(val).(type) {
		case *big.Float:
			f = v
		default:
			f = new(big.Float).SetRat(v.(*big.Rat))
		}
		tok, str = "FLOAT", f.Text('g', int(f.Prec()))
	default:
		// Complex constants (TODO: not sure how to parse this case) and any
		// other kind are left as a plain reference to the symbol.
		return name
	}

	// The generated expression needs both packages to be imported.
	imports["go/constant"] = true
	imports["go/token"] = true

	return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok)
}

// Extractor creates a package with all the symbols from a dependency package.
type Extractor struct {
	Dest    string   // The name of the created package.
	License string   // License text to be included in the created package, optional.
	Exclude []string // List of regexp matching symbols to exclude.
	Include []string // List of regexp matching symbols to include.
	Tag     []string // List of build tags to be added to the created package.
}

// importPath checks whether pkgIdent is an existing directory relative to
// the current working directory. If yes, it returns the actual import path of
// the Go package located in the directory: the importPath argument when it is
// not empty, otherwise the module path declared in the go.mod file of that
// directory. If pkgIdent is definitely a relative path, but it does not exist,
// an error is returned. Otherwise, it is assumed to be an import path, and
// pkgIdent is returned.
528 func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) { 529 wd, err := os.Getwd() 530 if err != nil { 531 return "", err 532 } 533 534 dirPath := filepath.Join(wd, pkgIdent) 535 _, err = os.Stat(dirPath) 536 if err != nil && !os.IsNotExist(err) { 537 return "", err 538 } 539 if err != nil { 540 if len(pkgIdent) > 0 && pkgIdent[0] == '.' { 541 // pkgIdent is definitely a relative path, not a package name, and it does not exist 542 return "", err 543 } 544 // pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now. 545 return pkgIdent, nil 546 } 547 548 // local import 549 if importPath != "" { 550 return importPath, nil 551 } 552 553 modPath := filepath.Join(dirPath, "go.mod") 554 _, err = os.Stat(modPath) 555 if os.IsNotExist(err) { 556 return "", errors.New("no go.mod found, and no import path specified") 557 } 558 if err != nil { 559 return "", err 560 } 561 f, err := os.Open(modPath) 562 if err != nil { 563 return "", err 564 } 565 defer func() { 566 _ = f.Close() 567 }() 568 sc := bufio.NewScanner(f) 569 var l string 570 for sc.Scan() { 571 l = sc.Text() 572 break 573 } 574 if sc.Err() != nil { 575 return "", err 576 } 577 parts := strings.Fields(l) 578 if len(parts) < 2 { 579 return "", errors.New(`invalid first line syntax in go.mod`) 580 } 581 if parts[0] != "module" { 582 return "", errors.New(`invalid first line in go.mod, no "module" found`) 583 } 584 585 return parts[1], nil 586 } 587 588 // Extract writes to rw a Go package with all the symbols found at pkgIdent. 589 // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In 590 // the latter case, Extract returns the actual import path of the package found at 591 // pkgIdent, otherwise it just returns pkgIdent. 592 // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not 593 // supported yet, and the behavior is only defined for GO111MODULE=off. 
594 func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) { 595 ipp, err := e.importPath(pkgIdent, importPath) 596 if err != nil { 597 return "", err 598 } 599 600 pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent) 601 if err != nil { 602 return "", err 603 } 604 605 content, err := e.genContent(ipp, pkg) 606 if err != nil { 607 return "", err 608 } 609 610 if _, err := rw.Write(content); err != nil { 611 return "", err 612 } 613 614 return ipp, nil 615 } 616 617 func (e *Extractor) ExtractStruct(pkgIdent, importPath string) (*PackageStruct, error) { 618 ipp, err := e.importPath(pkgIdent, importPath) 619 if err != nil { 620 return nil, err 621 } 622 623 pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent) 624 if err != nil { 625 return nil, err 626 } 627 628 return e.genStructure(ipp, pkg) 629 } 630 631 // GetMinor returns the minor part of the version number. 632 func GetMinor(part string) string { 633 minor := part 634 index := strings.Index(minor, "beta") 635 if index < 0 { 636 index = strings.Index(minor, "rc") 637 } 638 if index > 0 { 639 minor = minor[:index] 640 } 641 642 return minor 643 } 644 645 const defaultMinorVersion = 17 646 647 func genBuildTags() (string, error) { 648 version := runtime.Version() 649 if strings.HasPrefix(version, "devel") { 650 return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version) 651 } 652 parts := strings.Split(version, ".") 653 654 minorRaw := GetMinor(parts[1]) 655 656 currentGoVersion := parts[0] + "." + minorRaw 657 658 minor, err := strconv.Atoi(minorRaw) 659 if err != nil { 660 return "", fmt.Errorf("failed to parse version: %w", err) 661 } 662 663 // Only append an upper bound if we are not on the latest go 664 if minor >= defaultMinorVersion { 665 return currentGoVersion, nil 666 } 667 668 nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1) 669 670 return currentGoVersion + ",!" 
+ nextGoVersion, nil 671 } 672 673 func isInStdlib(path string) bool { return !strings.Contains(path, ".") }