github.com/switchupcb/yaegi@v0.10.2/extract/extract.go (about) 1 /* 2 Package extract generates wrappers of package exported symbols. 3 */ 4 package extract 5 6 import ( 7 "bufio" 8 "bytes" 9 "errors" 10 "fmt" 11 "go/constant" 12 "go/format" 13 "go/importer" 14 "go/token" 15 "go/types" 16 "io" 17 "math/big" 18 "os" 19 "path" 20 "path/filepath" 21 "regexp" 22 "runtime" 23 "strconv" 24 "strings" 25 "text/template" 26 ) 27 28 const model = `// Code generated by 'yaegi extract {{.ImportPath}}'. DO NOT EDIT. 29 30 {{.License}} 31 32 {{if .BuildTags}}// +build {{.BuildTags}}{{end}} 33 34 package {{.Dest}} 35 36 import ( 37 {{- range $key, $value := .Imports }} 38 {{- if $value}} 39 "{{$key}}" 40 {{- end}} 41 {{- end}} 42 "{{.ImportPath}}" 43 "reflect" 44 ) 45 46 func init() { 47 Symbols["{{.PkgName}}"] = map[string]reflect.Value{ 48 {{- if .Val}} 49 // function, constant and variable definitions 50 {{range $key, $value := .Val -}} 51 {{- if $value.Addr -}} 52 "{{$key}}": reflect.ValueOf(&{{$value.Name}}).Elem(), 53 {{else -}} 54 "{{$key}}": reflect.ValueOf({{$value.Name}}), 55 {{end -}} 56 {{end}} 57 58 {{- end}} 59 {{- if .Typ}} 60 // type definitions 61 {{range $key, $value := .Typ -}} 62 "{{$key}}": reflect.ValueOf((*{{$value}})(nil)), 63 {{end}} 64 65 {{- end}} 66 {{- if .Wrap}} 67 // interface wrapper definitions 68 {{range $key, $value := .Wrap -}} 69 "_{{$key}}": reflect.ValueOf((*{{$value.Name}})(nil)), 70 {{end}} 71 {{- end}} 72 } 73 } 74 {{range $key, $value := .Wrap -}} 75 // {{$value.Name}} is an interface wrapper for {{$key}} type 76 type {{$value.Name}} struct { 77 IValue interface{} 78 {{range $m := $value.Method -}} 79 W{{$m.Name}} func{{$m.Param}} {{$m.Result}} 80 {{end}} 81 } 82 {{range $m := $value.Method -}} 83 func (W {{$value.Name}}) {{$m.Name}}{{$m.Param}} {{$m.Result}} { 84 {{- if eq $m.Name "String"}} 85 if W.WString == nil { 86 return "" 87 } 88 {{end -}} 89 {{$m.Ret}} W.W{{$m.Name}}{{$m.Arg}} 90 } 91 {{end}} 92 {{end}} 93 ` 94 95 // Val stores 
the value name and addressable status of symbols. 96 type Val struct { 97 Name string // "package.name" 98 Addr bool // true if symbol is a Var 99 } 100 101 // Method stores information for generating interface wrapper method. 102 type Method struct { 103 Name, Param, Result, Arg, Ret string 104 } 105 106 // Wrap stores information for generating interface wrapper. 107 type Wrap struct { 108 Name string 109 Method []Method 110 } 111 112 // restricted map defines symbols for which a special implementation is provided. 113 var restricted = map[string]bool{ 114 "osExit": true, 115 "osFindProcess": true, 116 "logFatal": true, 117 "logFatalf": true, 118 "logFatalln": true, 119 "logLogger": true, 120 "logNew": true, 121 } 122 123 func matchList(name string, list []string) (match bool, err error) { 124 for _, re := range list { 125 match, err = regexp.MatchString(re, name) 126 if err != nil || match { 127 return 128 } 129 } 130 return 131 } 132 133 func (e *Extractor) genContent(importPath string, p *types.Package) ([]byte, error) { 134 prefix := "_" + importPath + "_" 135 prefix = strings.NewReplacer("/", "_", "-", "_", ".", "_").Replace(prefix) 136 137 typ := map[string]string{} 138 val := map[string]Val{} 139 wrap := map[string]Wrap{} 140 imports := map[string]bool{} 141 sc := p.Scope() 142 143 for _, pkg := range p.Imports() { 144 imports[pkg.Path()] = false 145 } 146 qualify := func(pkg *types.Package) string { 147 if pkg.Path() != importPath { 148 imports[pkg.Path()] = true 149 } 150 return pkg.Name() 151 } 152 153 for _, name := range sc.Names() { 154 o := sc.Lookup(name) 155 if !o.Exported() { 156 continue 157 } 158 159 if len(e.Include) > 0 { 160 match, err := matchList(name, e.Include) 161 if err != nil { 162 return nil, err 163 } 164 if !match { 165 // Explicitly defined include expressions force non matching symbols to be skipped. 
166 continue 167 } 168 } 169 170 match, err := matchList(name, e.Exclude) 171 if err != nil { 172 return nil, err 173 } 174 if match { 175 continue 176 } 177 178 pname := p.Name() + "." + name 179 if rname := p.Name() + name; restricted[rname] { 180 // Restricted symbol, locally provided by stdlib wrapper. 181 pname = rname 182 } 183 184 switch o := o.(type) { 185 case *types.Const: 186 if b, ok := o.Type().(*types.Basic); ok && (b.Info()&types.IsUntyped) != 0 { 187 // Convert untyped constant to right type to avoid overflow. 188 val[name] = Val{fixConst(pname, o.Val(), imports), false} 189 } else { 190 val[name] = Val{pname, false} 191 } 192 case *types.Func: 193 val[name] = Val{pname, false} 194 case *types.Var: 195 val[name] = Val{pname, true} 196 case *types.TypeName: 197 typ[name] = pname 198 if t, ok := o.Type().Underlying().(*types.Interface); ok { 199 var methods []Method 200 for i := 0; i < t.NumMethods(); i++ { 201 f := t.Method(i) 202 if !f.Exported() { 203 continue 204 } 205 206 sign := f.Type().(*types.Signature) 207 args := make([]string, sign.Params().Len()) 208 params := make([]string, len(args)) 209 for j := range args { 210 v := sign.Params().At(j) 211 if args[j] = v.Name(); args[j] == "" { 212 args[j] = fmt.Sprintf("a%d", j) 213 } 214 // process interface method variadic parameter 215 if sign.Variadic() && j == len(args)-1 { // check is last arg 216 // only replace the first "[]" to "..." 217 at := types.TypeString(v.Type(), qualify)[2:] 218 params[j] = args[j] + " ..." + at 219 args[j] += "..." 
220 } else { 221 params[j] = args[j] + " " + types.TypeString(v.Type(), qualify) 222 } 223 } 224 arg := "(" + strings.Join(args, ", ") + ")" 225 param := "(" + strings.Join(params, ", ") + ")" 226 227 results := make([]string, sign.Results().Len()) 228 for j := range results { 229 v := sign.Results().At(j) 230 results[j] = v.Name() + " " + types.TypeString(v.Type(), qualify) 231 } 232 result := "(" + strings.Join(results, ", ") + ")" 233 234 ret := "" 235 if sign.Results().Len() > 0 { 236 ret = "return" 237 } 238 239 methods = append(methods, Method{f.Name(), param, result, arg, ret}) 240 } 241 wrap[name] = Wrap{prefix + name, methods} 242 } 243 } 244 } 245 246 // Generate buildTags with Go version only for stdlib packages. 247 // Third party packages do not depend on Go compiler version by default. 248 var buildTags string 249 if isInStdlib(importPath) { 250 var err error 251 buildTags, err = genBuildTags() 252 if err != nil { 253 return nil, err 254 } 255 } 256 257 base := template.New("extract") 258 parse, err := base.Parse(model) 259 if err != nil { 260 return nil, fmt.Errorf("template parsing error: %w", err) 261 } 262 263 if importPath == "log/syslog" { 264 buildTags += ",!windows,!nacl,!plan9" 265 } 266 267 if importPath == "syscall" { 268 // As per https://golang.org/cmd/go/#hdr-Build_constraints, 269 // using GOOS=android also matches tags and files for GOOS=linux, 270 // so exclude it explicitly to avoid collisions (issue #843). 271 // Also using GOOS=illumos matches tags and files for GOOS=solaris. 
272 switch os.Getenv("GOOS") { 273 case "android": 274 buildTags += ",!linux" 275 case "illumos": 276 buildTags += ",!solaris" 277 } 278 } 279 280 for _, t := range e.Tag { 281 if len(t) != 0 { 282 buildTags += "," + t 283 } 284 } 285 if len(buildTags) != 0 && buildTags[0] == ',' { 286 buildTags = buildTags[1:] 287 } 288 289 b := new(bytes.Buffer) 290 data := map[string]interface{}{ 291 "Dest": e.Dest, 292 "Imports": imports, 293 "ImportPath": importPath, 294 "PkgName": path.Join(importPath, p.Name()), 295 "Val": val, 296 "Typ": typ, 297 "Wrap": wrap, 298 "BuildTags": buildTags, 299 "License": e.License, 300 } 301 err = parse.Execute(b, data) 302 if err != nil { 303 return nil, fmt.Errorf("template error: %w", err) 304 } 305 306 // gofmt 307 source, err := format.Source(b.Bytes()) 308 if err != nil { 309 return nil, fmt.Errorf("failed to format source: %w: %s", err, b.Bytes()) 310 } 311 return source, nil 312 } 313 314 // fixConst checks untyped constant value, converting it if necessary to avoid overflow. 315 func fixConst(name string, val constant.Value, imports map[string]bool) string { 316 var ( 317 tok string 318 str string 319 ) 320 switch val.Kind() { 321 case constant.String: 322 tok = "STRING" 323 str = val.ExactString() 324 case constant.Int: 325 tok = "INT" 326 str = val.ExactString() 327 case constant.Float: 328 v := constant.Val(val) // v is *big.Rat or *big.Float 329 f, ok := v.(*big.Float) 330 if !ok { 331 f = new(big.Float).SetRat(v.(*big.Rat)) 332 } 333 334 tok = "FLOAT" 335 str = f.Text('g', int(f.Prec())) 336 case constant.Complex: 337 // TODO: not sure how to parse this case 338 fallthrough 339 default: 340 return name 341 } 342 343 imports["go/constant"] = true 344 imports["go/token"] = true 345 346 return fmt.Sprintf("constant.MakeFromLiteral(%q, token.%s, 0)", str, tok) 347 } 348 349 // Extractor creates a package with all the symbols from a dependency package. 350 type Extractor struct { 351 Dest string // The name of the created package. 
352 License string // License text to be included in the created package, optional. 353 Exclude []string // Comma separated list of regexp matching symbols to exclude. 354 Include []string // Comma separated list of regexp matching symbols to include. 355 Tag []string // Comma separated of build tags to be added to the created package. 356 } 357 358 // importPath checks whether pkgIdent is an existing directory relative to 359 // e.WorkingDir. If yes, it returns the actual import path of the Go package 360 // located in the directory. If it is definitely a relative path, but it does not 361 // exist, an error is returned. Otherwise, it is assumed to be an import path, and 362 // pkgIdent is returned. 363 func (e *Extractor) importPath(pkgIdent, importPath string) (string, error) { 364 wd, err := os.Getwd() 365 if err != nil { 366 return "", err 367 } 368 369 dirPath := filepath.Join(wd, pkgIdent) 370 _, err = os.Stat(dirPath) 371 if err != nil && !os.IsNotExist(err) { 372 return "", err 373 } 374 if err != nil { 375 if len(pkgIdent) > 0 && pkgIdent[0] == '.' { 376 // pkgIdent is definitely a relative path, not a package name, and it does not exist 377 return "", err 378 } 379 // pkgIdent might be a valid stdlib package name. So we leave that responsibility to the caller now. 
380 return pkgIdent, nil 381 } 382 383 // local import 384 if importPath != "" { 385 return importPath, nil 386 } 387 388 modPath := filepath.Join(dirPath, "go.mod") 389 _, err = os.Stat(modPath) 390 if os.IsNotExist(err) { 391 return "", errors.New("no go.mod found, and no import path specified") 392 } 393 if err != nil { 394 return "", err 395 } 396 f, err := os.Open(modPath) 397 if err != nil { 398 return "", err 399 } 400 defer func() { 401 _ = f.Close() 402 }() 403 sc := bufio.NewScanner(f) 404 var l string 405 for sc.Scan() { 406 l = sc.Text() 407 break 408 } 409 if sc.Err() != nil { 410 return "", err 411 } 412 parts := strings.Fields(l) 413 if len(parts) < 2 { 414 return "", errors.New(`invalid first line syntax in go.mod`) 415 } 416 if parts[0] != "module" { 417 return "", errors.New(`invalid first line in go.mod, no "module" found`) 418 } 419 420 return parts[1], nil 421 } 422 423 // Extract writes to rw a Go package with all the symbols found at pkgIdent. 424 // pkgIdent can be an import path, or a local path, relative to e.WorkingDir. In 425 // the latter case, Extract returns the actual import path of the package found at 426 // pkgIdent, otherwise it just returns pkgIdent. 427 // If pkgIdent is an import path, it is looked up in GOPATH. Vendoring is not 428 // supported yet, and the behavior is only defined for GO111MODULE=off. 429 func (e *Extractor) Extract(pkgIdent, importPath string, rw io.Writer) (string, error) { 430 ipp, err := e.importPath(pkgIdent, importPath) 431 if err != nil { 432 return "", err 433 } 434 435 pkg, err := importer.ForCompiler(token.NewFileSet(), "source", nil).Import(pkgIdent) 436 if err != nil { 437 return "", err 438 } 439 440 content, err := e.genContent(ipp, pkg) 441 if err != nil { 442 return "", err 443 } 444 445 if _, err := rw.Write(content); err != nil { 446 return "", err 447 } 448 449 return ipp, nil 450 } 451 452 // GetMinor returns the minor part of the version number. 
// GetMinor returns the minor part of the version number.
// A pre-release suffix starting with "beta" or "rc" is dropped, so "18beta1"
// and "18rc2" both yield "18"; other inputs are returned unchanged.
func GetMinor(part string) string {
	minor := part
	index := strings.Index(minor, "beta")
	if index < 0 {
		index = strings.Index(minor, "rc")
	}
	if index > 0 {
		minor = minor[:index]
	}

	return minor
}

// defaultMinorVersion is the newest Go minor version known to the extractor:
// from this release onwards, the generated build constraint has no upper
// version bound.
const defaultMinorVersion = 17

// genBuildTags returns the build constraint matching the running Go release,
// e.g. "go1.16,!go1.17" on Go 1.16, or only "go1.N" when N is at least
// defaultMinorVersion. Development builds of Go are rejected with an error.
func genBuildTags() (string, error) {
	version := runtime.Version()
	if strings.HasPrefix(version, "devel") {
		return "", fmt.Errorf("extracting only supported with stable releases of Go, not %v", version)
	}
	// Keep only the first word: some toolchains append extra information
	// after a space (e.g. enabled experiments), which would break Atoi below.
	if i := strings.IndexByte(version, ' '); i >= 0 {
		version = version[:i]
	}
	parts := strings.Split(version, ".")
	if len(parts) < 2 {
		// Without this guard, parts[1] below panics on an unexpected format.
		return "", fmt.Errorf("failed to parse version: no minor number in %q", version)
	}

	minorRaw := GetMinor(parts[1])

	currentGoVersion := parts[0] + "." + minorRaw

	minor, err := strconv.Atoi(minorRaw)
	if err != nil {
		return "", fmt.Errorf("failed to parse version: %w", err)
	}

	// Only append an upper bound if we are not on the latest go
	if minor >= defaultMinorVersion {
		return currentGoVersion, nil
	}

	nextGoVersion := parts[0] + "." + strconv.Itoa(minor+1)

	return currentGoVersion + ",!" + nextGoVersion, nil
}

// isInStdlib reports whether path looks like a standard library import path,
// using the heuristic that such paths contain no dot.
func isInStdlib(path string) bool { return !strings.Contains(path, ".") }