golang.org/x/sys@v0.9.0/windows/mkwinsyscall/mkwinsyscall.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 /* 6 mkwinsyscall generates windows system call bodies 7 8 It parses all files specified on command line containing function 9 prototypes (like syscall_windows.go) and prints system call bodies 10 to standard output. 11 12 The prototypes are marked by lines beginning with "//sys" and read 13 like func declarations if //sys is replaced by func, but: 14 15 - The parameter lists must give a name for each argument. This 16 includes return parameters. 17 18 - The parameter lists must give a type for each argument: 19 the (x, y, z int) shorthand is not allowed. 20 21 - If the return parameter is an error number, it must be named err. 22 23 - If go func name needs to be different from its winapi dll name, 24 the winapi name could be specified at the end, after "=" sign, like 25 //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA 26 27 - Each function that returns err needs to supply a condition, that 28 return value of winapi will be tested against to detect failure. 29 This would set err to windows "last-error", otherwise it will be nil. 30 The value can be provided at end of //sys declaration, like 31 //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA 32 and is [failretval==0] by default. 33 34 - If the function name ends in a "?", then the function not existing is non- 35 fatal, and an error will be returned instead of panicking. 36 37 Usage: 38 39 mkwinsyscall [flags] [path ...] 40 41 The flags are: 42 43 -output 44 Specify output file name (outputs to console if blank). 45 -trace 46 Generate print statement after every syscall. 47 */ 48 package main 49 50 import ( 51 "bufio" 52 "bytes" 53 "errors" 54 "flag" 55 "fmt" 56 "go/format" 57 "go/parser" 58 "go/token" 59 "io" 60 "io/ioutil" 61 "log" 62 "os" 63 "path/filepath" 64 "runtime" 65 "sort" 66 "strconv" 67 "strings" 68 "text/template" 69 ) 70 71 var ( 72 filename = flag.String("output", "", "output file name (standard output if omitted)") 73 printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") 74 systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") 75 ) 76 77 func trim(s string) string { 78 return strings.Trim(s, " \t") 79 } 80 81 var packageName string 82 83 func packagename() string { 84 return packageName 85 } 86 87 func windowsdot() string { 88 if packageName == "windows" { 89 return "" 90 } 91 return "windows." 92 } 93 94 func syscalldot() string { 95 if packageName == "syscall" { 96 return "" 97 } 98 return "syscall." 99 } 100 101 // Param is function parameter 102 type Param struct { 103 Name string 104 Type string 105 fn *Fn 106 tmpVarIdx int 107 } 108 109 // tmpVar returns temp variable name that will be used to represent p during syscall. 110 func (p *Param) tmpVar() string { 111 if p.tmpVarIdx < 0 { 112 p.tmpVarIdx = p.fn.curTmpVarIdx 113 p.fn.curTmpVarIdx++ 114 } 115 return fmt.Sprintf("_p%d", p.tmpVarIdx) 116 } 117 118 // BoolTmpVarCode returns source code for bool temp variable. 119 func (p *Param) BoolTmpVarCode() string { 120 const code = `var %[1]s uint32 121 if %[2]s { 122 %[1]s = 1 123 }` 124 return fmt.Sprintf(code, p.tmpVar(), p.Name) 125 } 126 127 // BoolPointerTmpVarCode returns source code for bool temp variable. 128 func (p *Param) BoolPointerTmpVarCode() string { 129 const code = `var %[1]s uint32 130 if *%[2]s { 131 %[1]s = 1 132 }` 133 return fmt.Sprintf(code, p.tmpVar(), p.Name) 134 } 135 136 // SliceTmpVarCode returns source code for slice temp variable. 137 func (p *Param) SliceTmpVarCode() string { 138 const code = `var %s *%s 139 if len(%s) > 0 { 140 %s = &%s[0] 141 }` 142 tmp := p.tmpVar() 143 return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) 144 } 145 146 // StringTmpVarCode returns source code for string temp variable. 147 func (p *Param) StringTmpVarCode() string { 148 errvar := p.fn.Rets.ErrorVarName() 149 if errvar == "" { 150 errvar = "_" 151 } 152 tmp := p.tmpVar() 153 const code = `var %s %s 154 %s, %s = %s(%s)` 155 s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) 156 if errvar == "-" { 157 return s 158 } 159 const morecode = ` 160 if %s != nil { 161 return 162 }` 163 return s + fmt.Sprintf(morecode, errvar) 164 } 165 166 // TmpVarCode returns source code for temp variable. 167 func (p *Param) TmpVarCode() string { 168 switch { 169 case p.Type == "bool": 170 return p.BoolTmpVarCode() 171 case p.Type == "*bool": 172 return p.BoolPointerTmpVarCode() 173 case strings.HasPrefix(p.Type, "[]"): 174 return p.SliceTmpVarCode() 175 default: 176 return "" 177 } 178 } 179 180 // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. 181 func (p *Param) TmpVarReadbackCode() string { 182 switch { 183 case p.Type == "*bool": 184 return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) 185 default: 186 return "" 187 } 188 } 189 190 // TmpVarHelperCode returns source code for helper's temp variable. 191 func (p *Param) TmpVarHelperCode() string { 192 if p.Type != "string" { 193 return "" 194 } 195 return p.StringTmpVarCode() 196 } 197 198 // SyscallArgList returns source code fragments representing p parameter 199 // in syscall. Slices are translated into 2 syscall parameters: pointer to 200 // the first element and length. 201 func (p *Param) SyscallArgList() []string { 202 t := p.HelperType() 203 var s string 204 switch { 205 case t == "*bool": 206 s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) 207 case t[0] == '*': 208 s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) 209 case t == "bool": 210 s = p.tmpVar() 211 case strings.HasPrefix(t, "[]"): 212 return []string{ 213 fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), 214 fmt.Sprintf("uintptr(len(%s))", p.Name), 215 } 216 default: 217 s = p.Name 218 } 219 return []string{fmt.Sprintf("uintptr(%s)", s)} 220 } 221 222 // IsError determines if p parameter is used to return error. 223 func (p *Param) IsError() bool { 224 return p.Name == "err" && p.Type == "error" 225 } 226 227 // HelperType returns type of parameter p used in helper function. 228 func (p *Param) HelperType() string { 229 if p.Type == "string" { 230 return p.fn.StrconvType() 231 } 232 return p.Type 233 } 234 235 // join concatenates parameters ps into a string with sep separator. 236 // Each parameter is converted into string by applying fn to it 237 // before conversion. 238 func join(ps []*Param, fn func(*Param) string, sep string) string { 239 if len(ps) == 0 { 240 return "" 241 } 242 a := make([]string, 0) 243 for _, p := range ps { 244 a = append(a, fn(p)) 245 } 246 return strings.Join(a, sep) 247 } 248 249 // Rets describes function return parameters. 250 type Rets struct { 251 Name string 252 Type string 253 ReturnsError bool 254 FailCond string 255 fnMaybeAbsent bool 256 } 257 258 // ErrorVarName returns error variable name for r. 259 func (r *Rets) ErrorVarName() string { 260 if r.ReturnsError { 261 return "err" 262 } 263 if r.Type == "error" { 264 return r.Name 265 } 266 return "" 267 } 268 269 // ToParams converts r into slice of *Param. 270 func (r *Rets) ToParams() []*Param { 271 ps := make([]*Param, 0) 272 if len(r.Name) > 0 { 273 ps = append(ps, &Param{Name: r.Name, Type: r.Type}) 274 } 275 if r.ReturnsError { 276 ps = append(ps, &Param{Name: "err", Type: "error"}) 277 } 278 return ps 279 } 280 281 // List returns source code of syscall return parameters. 282 func (r *Rets) List() string { 283 s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") 284 if len(s) > 0 { 285 s = "(" + s + ")" 286 } else if r.fnMaybeAbsent { 287 s = "(err error)" 288 } 289 return s 290 } 291 292 // PrintList returns source code of trace printing part correspondent 293 // to syscall return values. 294 func (r *Rets) PrintList() string { 295 return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 296 } 297 298 // SetReturnValuesCode returns source code that accepts syscall return values. 299 func (r *Rets) SetReturnValuesCode() string { 300 if r.Name == "" && !r.ReturnsError { 301 return "" 302 } 303 retvar := "r0" 304 if r.Name == "" { 305 retvar = "r1" 306 } 307 errvar := "_" 308 if r.ReturnsError { 309 errvar = "e1" 310 } 311 return fmt.Sprintf("%s, _, %s := ", retvar, errvar) 312 } 313 314 func (r *Rets) useLongHandleErrorCode(retvar string) string { 315 const code = `if %s { 316 err = errnoErr(e1) 317 }` 318 cond := retvar + " == 0" 319 if r.FailCond != "" { 320 cond = strings.Replace(r.FailCond, "failretval", retvar, 1) 321 } 322 return fmt.Sprintf(code, cond) 323 } 324 325 // SetErrorCode returns source code that sets return parameters. 326 func (r *Rets) SetErrorCode() string { 327 const code = `if r0 != 0 { 328 %s = %sErrno(r0) 329 }` 330 const ntstatus = `if r0 != 0 { 331 ntstatus = %sNTStatus(r0) 332 }` 333 if r.Name == "" && !r.ReturnsError { 334 return "" 335 } 336 if r.Name == "" { 337 return r.useLongHandleErrorCode("r1") 338 } 339 if r.Type == "error" && r.Name == "ntstatus" { 340 return fmt.Sprintf(ntstatus, windowsdot()) 341 } 342 if r.Type == "error" { 343 return fmt.Sprintf(code, r.Name, syscalldot()) 344 } 345 s := "" 346 switch { 347 case r.Type[0] == '*': 348 s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) 349 case r.Type == "bool": 350 s = fmt.Sprintf("%s = r0 != 0", r.Name) 351 default: 352 s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) 353 } 354 if !r.ReturnsError { 355 return s 356 } 357 return s + "\n\t" + r.useLongHandleErrorCode(r.Name) 358 } 359 360 // Fn describes syscall function. 361 type Fn struct { 362 Name string 363 Params []*Param 364 Rets *Rets 365 PrintTrace bool 366 dllname string 367 dllfuncname string 368 src string 369 // TODO: get rid of this field and just use parameter index instead 370 curTmpVarIdx int // insure tmp variables have uniq names 371 } 372 373 // extractParams parses s to extract function parameters. 374 func extractParams(s string, f *Fn) ([]*Param, error) { 375 s = trim(s) 376 if s == "" { 377 return nil, nil 378 } 379 a := strings.Split(s, ",") 380 ps := make([]*Param, len(a)) 381 for i := range ps { 382 s2 := trim(a[i]) 383 b := strings.Split(s2, " ") 384 if len(b) != 2 { 385 b = strings.Split(s2, "\t") 386 if len(b) != 2 { 387 return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") 388 } 389 } 390 ps[i] = &Param{ 391 Name: trim(b[0]), 392 Type: trim(b[1]), 393 fn: f, 394 tmpVarIdx: -1, 395 } 396 } 397 return ps, nil 398 } 399 400 // extractSection extracts text out of string s starting after start 401 // and ending just before end. found return value will indicate success, 402 // and prefix, body and suffix will contain correspondent parts of string s. 403 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { 404 s = trim(s) 405 if strings.HasPrefix(s, string(start)) { 406 // no prefix 407 body = s[1:] 408 } else { 409 a := strings.SplitN(s, string(start), 2) 410 if len(a) != 2 { 411 return "", "", s, false 412 } 413 prefix = a[0] 414 body = a[1] 415 } 416 a := strings.SplitN(body, string(end), 2) 417 if len(a) != 2 { 418 return "", "", "", false 419 } 420 return prefix, a[0], a[1], true 421 } 422 423 // newFn parses string s and return created function Fn. 424 func newFn(s string) (*Fn, error) { 425 s = trim(s) 426 f := &Fn{ 427 Rets: &Rets{}, 428 src: s, 429 PrintTrace: *printTraceFlag, 430 } 431 // function name and args 432 prefix, body, s, found := extractSection(s, '(', ')') 433 if !found || prefix == "" { 434 return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") 435 } 436 f.Name = prefix 437 var err error 438 f.Params, err = extractParams(body, f) 439 if err != nil { 440 return nil, err 441 } 442 // return values 443 _, body, s, found = extractSection(s, '(', ')') 444 if found { 445 r, err := extractParams(body, f) 446 if err != nil { 447 return nil, err 448 } 449 switch len(r) { 450 case 0: 451 case 1: 452 if r[0].IsError() { 453 f.Rets.ReturnsError = true 454 } else { 455 f.Rets.Name = r[0].Name 456 f.Rets.Type = r[0].Type 457 } 458 case 2: 459 if !r[1].IsError() { 460 return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") 461 } 462 f.Rets.ReturnsError = true 463 f.Rets.Name = r[0].Name 464 f.Rets.Type = r[0].Type 465 default: 466 return nil, errors.New("Too many return values in \"" + f.src + "\"") 467 } 468 } 469 // fail condition 470 _, body, s, found = extractSection(s, '[', ']') 471 if found { 472 f.Rets.FailCond = body 473 } 474 // dll and dll function names 475 s = trim(s) 476 if s == "" { 477 return f, nil 478 } 479 if !strings.HasPrefix(s, "=") { 480 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") 481 } 482 s = trim(s[1:]) 483 if i := strings.LastIndex(s, "."); i >= 0 { 484 f.dllname = s[:i] 485 f.dllfuncname = s[i+1:] 486 } else { 487 f.dllfuncname = s 488 } 489 if f.dllfuncname == "" { 490 return nil, fmt.Errorf("function name is not specified in %q", s) 491 } 492 if n := f.dllfuncname; strings.HasSuffix(n, "?") { 493 f.dllfuncname = n[:len(n)-1] 494 f.Rets.fnMaybeAbsent = true 495 } 496 return f, nil 497 } 498 499 // DLLName returns DLL name for function f. 500 func (f *Fn) DLLName() string { 501 if f.dllname == "" { 502 return "kernel32" 503 } 504 return f.dllname 505 } 506 507 // DLLVar returns a valid Go identifier that represents DLLName. 508 func (f *Fn) DLLVar() string { 509 id := strings.Map(func(r rune) rune { 510 switch r { 511 case '.', '-': 512 return '_' 513 default: 514 return r 515 } 516 }, f.DLLName()) 517 if !token.IsIdentifier(id) { 518 panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName())) 519 } 520 return id 521 } 522 523 // DLLFuncName returns DLL function name for function f. 524 func (f *Fn) DLLFuncName() string { 525 if f.dllfuncname == "" { 526 return f.Name 527 } 528 return f.dllfuncname 529 } 530 531 // ParamList returns source code for function f parameters. 532 func (f *Fn) ParamList() string { 533 return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") 534 } 535 536 // HelperParamList returns source code for helper function f parameters. 537 func (f *Fn) HelperParamList() string { 538 return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") 539 } 540 541 // ParamPrintList returns source code of trace printing part correspondent 542 // to syscall input parameters. 543 func (f *Fn) ParamPrintList() string { 544 return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 545 } 546 547 // ParamCount return number of syscall parameters for function f. 548 func (f *Fn) ParamCount() int { 549 n := 0 550 for _, p := range f.Params { 551 n += len(p.SyscallArgList()) 552 } 553 return n 554 } 555 556 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... 557 // to use. It returns parameter count for correspondent SyscallX function. 558 func (f *Fn) SyscallParamCount() int { 559 n := f.ParamCount() 560 switch { 561 case n <= 3: 562 return 3 563 case n <= 6: 564 return 6 565 case n <= 9: 566 return 9 567 case n <= 12: 568 return 12 569 case n <= 15: 570 return 15 571 default: 572 panic("too many arguments to system call") 573 } 574 } 575 576 // Syscall determines which SyscallX function to use for function f. 577 func (f *Fn) Syscall() string { 578 c := f.SyscallParamCount() 579 if c == 3 { 580 return syscalldot() + "Syscall" 581 } 582 return syscalldot() + "Syscall" + strconv.Itoa(c) 583 } 584 585 // SyscallParamList returns source code for SyscallX parameters for function f. 586 func (f *Fn) SyscallParamList() string { 587 a := make([]string, 0) 588 for _, p := range f.Params { 589 a = append(a, p.SyscallArgList()...) 590 } 591 for len(a) < f.SyscallParamCount() { 592 a = append(a, "0") 593 } 594 return strings.Join(a, ", ") 595 } 596 597 // HelperCallParamList returns source code of call into function f helper. 598 func (f *Fn) HelperCallParamList() string { 599 a := make([]string, 0, len(f.Params)) 600 for _, p := range f.Params { 601 s := p.Name 602 if p.Type == "string" { 603 s = p.tmpVar() 604 } 605 a = append(a, s) 606 } 607 return strings.Join(a, ", ") 608 } 609 610 // MaybeAbsent returns source code for handling functions that are possibly unavailable. 611 func (p *Fn) MaybeAbsent() string { 612 if !p.Rets.fnMaybeAbsent { 613 return "" 614 } 615 const code = `%[1]s = proc%[2]s.Find() 616 if %[1]s != nil { 617 return 618 }` 619 errorVar := p.Rets.ErrorVarName() 620 if errorVar == "" { 621 errorVar = "err" 622 } 623 return fmt.Sprintf(code, errorVar, p.DLLFuncName()) 624 } 625 626 // IsUTF16 is true, if f is W (utf16) function. It is false 627 // for all A (ascii) functions. 628 func (f *Fn) IsUTF16() bool { 629 s := f.DLLFuncName() 630 return s[len(s)-1] == 'W' 631 } 632 633 // StrconvFunc returns name of Go string to OS string function for f. 634 func (f *Fn) StrconvFunc() string { 635 if f.IsUTF16() { 636 return syscalldot() + "UTF16PtrFromString" 637 } 638 return syscalldot() + "BytePtrFromString" 639 } 640 641 // StrconvType returns Go type name used for OS string for f. 642 func (f *Fn) StrconvType() string { 643 if f.IsUTF16() { 644 return "*uint16" 645 } 646 return "*byte" 647 } 648 649 // HasStringParam is true, if f has at least one string parameter. 650 // Otherwise it is false. 651 func (f *Fn) HasStringParam() bool { 652 for _, p := range f.Params { 653 if p.Type == "string" { 654 return true 655 } 656 } 657 return false 658 } 659 660 // HelperName returns name of function f helper. 661 func (f *Fn) HelperName() string { 662 if !f.HasStringParam() { 663 return f.Name 664 } 665 return "_" + f.Name 666 } 667 668 // DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when 669 // naming a variable that refers to the DLL. 670 type DLL struct { 671 Name string 672 Var string 673 } 674 675 // Source files and functions. 676 type Source struct { 677 Funcs []*Fn 678 DLLFuncNames []*Fn 679 Files []string 680 StdLibImports []string 681 ExternalImports []string 682 } 683 684 func (src *Source) Import(pkg string) { 685 src.StdLibImports = append(src.StdLibImports, pkg) 686 sort.Strings(src.StdLibImports) 687 } 688 689 func (src *Source) ExternalImport(pkg string) { 690 src.ExternalImports = append(src.ExternalImports, pkg) 691 sort.Strings(src.ExternalImports) 692 } 693 694 // ParseFiles parses files listed in fs and extracts all syscall 695 // functions listed in sys comments. It returns source files 696 // and functions collection *Source if successful. 697 func ParseFiles(fs []string) (*Source, error) { 698 src := &Source{ 699 Funcs: make([]*Fn, 0), 700 Files: make([]string, 0), 701 StdLibImports: []string{ 702 "unsafe", 703 }, 704 ExternalImports: make([]string, 0), 705 } 706 for _, file := range fs { 707 if err := src.ParseFile(file); err != nil { 708 return nil, err 709 } 710 } 711 src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs)) 712 uniq := make(map[string]bool, len(src.Funcs)) 713 for _, fn := range src.Funcs { 714 name := fn.DLLFuncName() 715 if !uniq[name] { 716 src.DLLFuncNames = append(src.DLLFuncNames, fn) 717 uniq[name] = true 718 } 719 } 720 return src, nil 721 } 722 723 // DLLs return dll names for a source set src. 724 func (src *Source) DLLs() []DLL { 725 uniq := make(map[string]bool) 726 r := make([]DLL, 0) 727 for _, f := range src.Funcs { 728 id := f.DLLVar() 729 if _, found := uniq[id]; !found { 730 uniq[id] = true 731 r = append(r, DLL{f.DLLName(), id}) 732 } 733 } 734 sort.Slice(r, func(i, j int) bool { 735 return r[i].Var < r[j].Var 736 }) 737 return r 738 } 739 740 // ParseFile adds additional file path to a source set src. 741 func (src *Source) ParseFile(path string) error { 742 file, err := os.Open(path) 743 if err != nil { 744 return err 745 } 746 defer file.Close() 747 748 s := bufio.NewScanner(file) 749 for s.Scan() { 750 t := trim(s.Text()) 751 if len(t) < 7 { 752 continue 753 } 754 if !strings.HasPrefix(t, "//sys") { 755 continue 756 } 757 t = t[5:] 758 if !(t[0] == ' ' || t[0] == '\t') { 759 continue 760 } 761 f, err := newFn(t[1:]) 762 if err != nil { 763 return err 764 } 765 src.Funcs = append(src.Funcs, f) 766 } 767 if err := s.Err(); err != nil { 768 return err 769 } 770 src.Files = append(src.Files, path) 771 sort.Slice(src.Funcs, func(i, j int) bool { 772 fi, fj := src.Funcs[i], src.Funcs[j] 773 if fi.DLLName() == fj.DLLName() { 774 return fi.DLLFuncName() < fj.DLLFuncName() 775 } 776 return fi.DLLName() < fj.DLLName() 777 }) 778 779 // get package name 780 fset := token.NewFileSet() 781 _, err = file.Seek(0, 0) 782 if err != nil { 783 return err 784 } 785 pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) 786 if err != nil { 787 return err 788 } 789 packageName = pkg.Name.Name 790 791 return nil 792 } 793 794 // IsStdRepo reports whether src is part of standard library. 795 func (src *Source) IsStdRepo() (bool, error) { 796 if len(src.Files) == 0 { 797 return false, errors.New("no input files provided") 798 } 799 abspath, err := filepath.Abs(src.Files[0]) 800 if err != nil { 801 return false, err 802 } 803 goroot := runtime.GOROOT() 804 if runtime.GOOS == "windows" { 805 abspath = strings.ToLower(abspath) 806 goroot = strings.ToLower(goroot) 807 } 808 sep := string(os.PathSeparator) 809 if !strings.HasSuffix(goroot, sep) { 810 goroot += sep 811 } 812 return strings.HasPrefix(abspath, goroot), nil 813 } 814 815 // Generate output source file from a source set src. 816 func (src *Source) Generate(w io.Writer) error { 817 const ( 818 pkgStd = iota // any package in std library 819 pkgXSysWindows // x/sys/windows package 820 pkgOther 821 ) 822 isStdRepo, err := src.IsStdRepo() 823 if err != nil { 824 return err 825 } 826 var pkgtype int 827 switch { 828 case isStdRepo: 829 pkgtype = pkgStd 830 case packageName == "windows": 831 // TODO: this needs better logic than just using package name 832 pkgtype = pkgXSysWindows 833 default: 834 pkgtype = pkgOther 835 } 836 if *systemDLL { 837 switch pkgtype { 838 case pkgStd: 839 src.Import("internal/syscall/windows/sysdll") 840 case pkgXSysWindows: 841 default: 842 src.ExternalImport("golang.org/x/sys/windows") 843 } 844 } 845 if packageName != "syscall" { 846 src.Import("syscall") 847 } 848 funcMap := template.FuncMap{ 849 "packagename": packagename, 850 "syscalldot": syscalldot, 851 "newlazydll": func(dll string) string { 852 arg := "\"" + dll + ".dll\"" 853 if !*systemDLL { 854 return syscalldot() + "NewLazyDLL(" + arg + ")" 855 } 856 switch pkgtype { 857 case pkgStd: 858 return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" 859 case pkgXSysWindows: 860 return "NewLazySystemDLL(" + arg + ")" 861 default: 862 return "windows.NewLazySystemDLL(" + arg + ")" 863 } 864 }, 865 } 866 t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) 867 err = t.Execute(w, src) 868 if err != nil { 869 return errors.New("Failed to execute template: " + err.Error()) 870 } 871 return nil 872 } 873 874 func writeTempSourceFile(data []byte) (string, error) { 875 f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go") 876 if err != nil { 877 return "", err 878 } 879 _, err = f.Write(data) 880 if closeErr := f.Close(); err == nil { 881 err = closeErr 882 } 883 if err != nil { 884 os.Remove(f.Name()) // best effort 885 return "", err 886 } 887 return f.Name(), nil 888 } 889 890 func usage() { 891 fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n") 892 flag.PrintDefaults() 893 os.Exit(1) 894 } 895 896 func main() { 897 flag.Usage = usage 898 flag.Parse() 899 if len(flag.Args()) <= 0 { 900 fmt.Fprintf(os.Stderr, "no files to parse provided\n") 901 usage() 902 } 903 904 src, err := ParseFiles(flag.Args()) 905 if err != nil { 906 log.Fatal(err) 907 } 908 909 var buf bytes.Buffer 910 if err := src.Generate(&buf); err != nil { 911 log.Fatal(err) 912 } 913 914 data, err := format.Source(buf.Bytes()) 915 if err != nil { 916 log.Printf("failed to format source: %v", err) 917 f, err := writeTempSourceFile(buf.Bytes()) 918 if err != nil { 919 log.Fatalf("failed to write unformatted source to file: %v", err) 920 } 921 log.Fatalf("for diagnosis, wrote unformatted source to %v", f) 922 } 923 if *filename == "" { 924 _, err = os.Stdout.Write(data) 925 } else { 926 err = ioutil.WriteFile(*filename, data, 0644) 927 } 928 if err != nil { 929 log.Fatal(err) 930 } 931 } 932 933 // TODO: use println instead to print in the following template 934 const srcTemplate = ` 935 936 {{define "main"}}// Code generated by 'go generate'; DO NOT EDIT. 937 938 package {{packagename}} 939 940 import ( 941 {{range .StdLibImports}}"{{.}}" 942 {{end}} 943 944 {{range .ExternalImports}}"{{.}}" 945 {{end}} 946 ) 947 948 var _ unsafe.Pointer 949 950 // Do the interface allocations only once for common 951 // Errno values. 952 const ( 953 errnoERROR_IO_PENDING = 997 954 ) 955 956 var ( 957 errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) 958 errERROR_EINVAL error = {{syscalldot}}EINVAL 959 ) 960 961 // errnoErr returns common boxed Errno values, to prevent 962 // allocations at runtime. 963 func errnoErr(e {{syscalldot}}Errno) error { 964 switch e { 965 case 0: 966 return errERROR_EINVAL 967 case errnoERROR_IO_PENDING: 968 return errERROR_IO_PENDING 969 } 970 // TODO: add more here, after collecting data on the common 971 // error values see on Windows. (perhaps when running 972 // all.bat?) 973 return e 974 } 975 976 var ( 977 {{template "dlls" .}} 978 {{template "funcnames" .}}) 979 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} 980 {{end}} 981 982 {{/* help functions */}} 983 984 {{define "dlls"}}{{range .DLLs}} mod{{.Var}} = {{newlazydll .Name}} 985 {{end}}{{end}} 986 987 {{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}") 988 {{end}}{{end}} 989 990 {{define "helperbody"}} 991 func {{.Name}}({{.ParamList}}) {{template "results" .}}{ 992 {{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) 993 } 994 {{end}} 995 996 {{define "funcbody"}} 997 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ 998 {{template "maybeabsent" .}} {{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvarsreadback" .}} 999 {{template "seterror" .}}{{template "printtrace" .}} return 1000 } 1001 {{end}} 1002 1003 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} 1004 {{end}}{{end}}{{end}} 1005 1006 {{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}} 1007 {{end}}{{end}} 1008 1009 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} 1010 {{end}}{{end}}{{end}} 1011 1012 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} 1013 1014 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} 1015 1016 {{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}} 1017 {{.TmpVarReadbackCode}}{{end}}{{end}}{{end}} 1018 1019 {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} 1020 {{end}}{{end}} 1021 1022 {{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") 1023 {{end}}{{end}} 1024 1025 `