golang.org/x/sys@v0.20.1-0.20240517151509-673e0f94c16d/windows/mkwinsyscall/mkwinsyscall.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 /* 6 mkwinsyscall generates windows system call bodies 7 8 It parses all files specified on command line containing function 9 prototypes (like syscall_windows.go) and prints system call bodies 10 to standard output. 11 12 The prototypes are marked by lines beginning with "//sys" and read 13 like func declarations if //sys is replaced by func, but: 14 15 - The parameter lists must give a name for each argument. This 16 includes return parameters. 17 18 - The parameter lists must give a type for each argument: 19 the (x, y, z int) shorthand is not allowed. 20 21 - If the return parameter is an error number, it must be named err. 22 23 - If go func name needs to be different from its winapi dll name, 24 the winapi name could be specified at the end, after "=" sign, like 25 //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA 26 27 - Each function that returns err needs to supply a condition, that 28 return value of winapi will be tested against to detect failure. 29 This would set err to windows "last-error", otherwise it will be nil. 30 The value can be provided at end of //sys declaration, like 31 //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA 32 and is [failretval==0] by default. 33 34 - If the function name ends in a "?", then the function not existing is non- 35 fatal, and an error will be returned instead of panicking. 36 37 Usage: 38 39 mkwinsyscall [flags] [path ...] 40 41 The flags are: 42 43 -output 44 Specify output file name (outputs to console if blank). 45 -trace 46 Generate print statement after every syscall. 47 */ 48 package main 49 50 import ( 51 "bufio" 52 "bytes" 53 "errors" 54 "flag" 55 "fmt" 56 "go/format" 57 "go/parser" 58 "go/token" 59 "io" 60 "log" 61 "os" 62 "path/filepath" 63 "runtime" 64 "sort" 65 "strconv" 66 "strings" 67 "text/template" 68 ) 69 70 var ( 71 filename = flag.String("output", "", "output file name (standard output if omitted)") 72 printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") 73 systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") 74 ) 75 76 func trim(s string) string { 77 return strings.Trim(s, " \t") 78 } 79 80 var packageName string 81 82 func packagename() string { 83 return packageName 84 } 85 86 func windowsdot() string { 87 if packageName == "windows" { 88 return "" 89 } 90 return "windows." 91 } 92 93 func syscalldot() string { 94 if packageName == "syscall" { 95 return "" 96 } 97 return "syscall." 98 } 99 100 // Param is function parameter 101 type Param struct { 102 Name string 103 Type string 104 fn *Fn 105 tmpVarIdx int 106 } 107 108 // tmpVar returns temp variable name that will be used to represent p during syscall. 109 func (p *Param) tmpVar() string { 110 if p.tmpVarIdx < 0 { 111 p.tmpVarIdx = p.fn.curTmpVarIdx 112 p.fn.curTmpVarIdx++ 113 } 114 return fmt.Sprintf("_p%d", p.tmpVarIdx) 115 } 116 117 // BoolTmpVarCode returns source code for bool temp variable. 118 func (p *Param) BoolTmpVarCode() string { 119 const code = `var %[1]s uint32 120 if %[2]s { 121 %[1]s = 1 122 }` 123 return fmt.Sprintf(code, p.tmpVar(), p.Name) 124 } 125 126 // BoolPointerTmpVarCode returns source code for bool temp variable. 127 func (p *Param) BoolPointerTmpVarCode() string { 128 const code = `var %[1]s uint32 129 if *%[2]s { 130 %[1]s = 1 131 }` 132 return fmt.Sprintf(code, p.tmpVar(), p.Name) 133 } 134 135 // SliceTmpVarCode returns source code for slice temp variable. 136 func (p *Param) SliceTmpVarCode() string { 137 const code = `var %s *%s 138 if len(%s) > 0 { 139 %s = &%s[0] 140 }` 141 tmp := p.tmpVar() 142 return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) 143 } 144 145 // StringTmpVarCode returns source code for string temp variable. 146 func (p *Param) StringTmpVarCode() string { 147 errvar := p.fn.Rets.ErrorVarName() 148 if errvar == "" { 149 errvar = "_" 150 } 151 tmp := p.tmpVar() 152 const code = `var %s %s 153 %s, %s = %s(%s)` 154 s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) 155 if errvar == "-" { 156 return s 157 } 158 const morecode = ` 159 if %s != nil { 160 return 161 }` 162 return s + fmt.Sprintf(morecode, errvar) 163 } 164 165 // TmpVarCode returns source code for temp variable. 166 func (p *Param) TmpVarCode() string { 167 switch { 168 case p.Type == "bool": 169 return p.BoolTmpVarCode() 170 case p.Type == "*bool": 171 return p.BoolPointerTmpVarCode() 172 case strings.HasPrefix(p.Type, "[]"): 173 return p.SliceTmpVarCode() 174 default: 175 return "" 176 } 177 } 178 179 // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. 180 func (p *Param) TmpVarReadbackCode() string { 181 switch { 182 case p.Type == "*bool": 183 return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) 184 default: 185 return "" 186 } 187 } 188 189 // TmpVarHelperCode returns source code for helper's temp variable. 190 func (p *Param) TmpVarHelperCode() string { 191 if p.Type != "string" { 192 return "" 193 } 194 return p.StringTmpVarCode() 195 } 196 197 // SyscallArgList returns source code fragments representing p parameter 198 // in syscall. Slices are translated into 2 syscall parameters: pointer to 199 // the first element and length. 200 func (p *Param) SyscallArgList() []string { 201 t := p.HelperType() 202 var s string 203 switch { 204 case t == "*bool": 205 s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) 206 case t[0] == '*': 207 s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) 208 case t == "bool": 209 s = p.tmpVar() 210 case strings.HasPrefix(t, "[]"): 211 return []string{ 212 fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), 213 fmt.Sprintf("uintptr(len(%s))", p.Name), 214 } 215 default: 216 s = p.Name 217 } 218 return []string{fmt.Sprintf("uintptr(%s)", s)} 219 } 220 221 // IsError determines if p parameter is used to return error. 222 func (p *Param) IsError() bool { 223 return p.Name == "err" && p.Type == "error" 224 } 225 226 // HelperType returns type of parameter p used in helper function. 227 func (p *Param) HelperType() string { 228 if p.Type == "string" { 229 return p.fn.StrconvType() 230 } 231 return p.Type 232 } 233 234 // join concatenates parameters ps into a string with sep separator. 235 // Each parameter is converted into string by applying fn to it 236 // before conversion. 237 func join(ps []*Param, fn func(*Param) string, sep string) string { 238 if len(ps) == 0 { 239 return "" 240 } 241 a := make([]string, 0) 242 for _, p := range ps { 243 a = append(a, fn(p)) 244 } 245 return strings.Join(a, sep) 246 } 247 248 // Rets describes function return parameters. 249 type Rets struct { 250 Name string 251 Type string 252 ReturnsError bool 253 FailCond string 254 fnMaybeAbsent bool 255 } 256 257 // ErrorVarName returns error variable name for r. 258 func (r *Rets) ErrorVarName() string { 259 if r.ReturnsError { 260 return "err" 261 } 262 if r.Type == "error" { 263 return r.Name 264 } 265 return "" 266 } 267 268 // ToParams converts r into slice of *Param. 269 func (r *Rets) ToParams() []*Param { 270 ps := make([]*Param, 0) 271 if len(r.Name) > 0 { 272 ps = append(ps, &Param{Name: r.Name, Type: r.Type}) 273 } 274 if r.ReturnsError { 275 ps = append(ps, &Param{Name: "err", Type: "error"}) 276 } 277 return ps 278 } 279 280 // List returns source code of syscall return parameters. 281 func (r *Rets) List() string { 282 s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") 283 if len(s) > 0 { 284 s = "(" + s + ")" 285 } else if r.fnMaybeAbsent { 286 s = "(err error)" 287 } 288 return s 289 } 290 291 // PrintList returns source code of trace printing part correspondent 292 // to syscall return values. 293 func (r *Rets) PrintList() string { 294 return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 295 } 296 297 // SetReturnValuesCode returns source code that accepts syscall return values. 298 func (r *Rets) SetReturnValuesCode() string { 299 if r.Name == "" && !r.ReturnsError { 300 return "" 301 } 302 retvar := "r0" 303 if r.Name == "" { 304 retvar = "r1" 305 } 306 errvar := "_" 307 if r.ReturnsError { 308 errvar = "e1" 309 } 310 return fmt.Sprintf("%s, _, %s := ", retvar, errvar) 311 } 312 313 func (r *Rets) useLongHandleErrorCode(retvar string) string { 314 const code = `if %s { 315 err = errnoErr(e1) 316 }` 317 cond := retvar + " == 0" 318 if r.FailCond != "" { 319 cond = strings.Replace(r.FailCond, "failretval", retvar, 1) 320 } 321 return fmt.Sprintf(code, cond) 322 } 323 324 // SetErrorCode returns source code that sets return parameters. 325 func (r *Rets) SetErrorCode() string { 326 const code = `if r0 != 0 { 327 %s = %sErrno(r0) 328 }` 329 const ntstatus = `if r0 != 0 { 330 ntstatus = %sNTStatus(r0) 331 }` 332 if r.Name == "" && !r.ReturnsError { 333 return "" 334 } 335 if r.Name == "" { 336 return r.useLongHandleErrorCode("r1") 337 } 338 if r.Type == "error" && r.Name == "ntstatus" { 339 return fmt.Sprintf(ntstatus, windowsdot()) 340 } 341 if r.Type == "error" { 342 return fmt.Sprintf(code, r.Name, syscalldot()) 343 } 344 s := "" 345 switch { 346 case r.Type[0] == '*': 347 s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) 348 case r.Type == "bool": 349 s = fmt.Sprintf("%s = r0 != 0", r.Name) 350 default: 351 s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) 352 } 353 if !r.ReturnsError { 354 return s 355 } 356 return s + "\n\t" + r.useLongHandleErrorCode(r.Name) 357 } 358 359 // Fn describes syscall function. 360 type Fn struct { 361 Name string 362 Params []*Param 363 Rets *Rets 364 PrintTrace bool 365 dllname string 366 dllfuncname string 367 src string 368 // TODO: get rid of this field and just use parameter index instead 369 curTmpVarIdx int // insure tmp variables have uniq names 370 } 371 372 // extractParams parses s to extract function parameters. 373 func extractParams(s string, f *Fn) ([]*Param, error) { 374 s = trim(s) 375 if s == "" { 376 return nil, nil 377 } 378 a := strings.Split(s, ",") 379 ps := make([]*Param, len(a)) 380 for i := range ps { 381 s2 := trim(a[i]) 382 b := strings.Split(s2, " ") 383 if len(b) != 2 { 384 b = strings.Split(s2, "\t") 385 if len(b) != 2 { 386 return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") 387 } 388 } 389 ps[i] = &Param{ 390 Name: trim(b[0]), 391 Type: trim(b[1]), 392 fn: f, 393 tmpVarIdx: -1, 394 } 395 } 396 return ps, nil 397 } 398 399 // extractSection extracts text out of string s starting after start 400 // and ending just before end. found return value will indicate success, 401 // and prefix, body and suffix will contain correspondent parts of string s. 402 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { 403 s = trim(s) 404 if strings.HasPrefix(s, string(start)) { 405 // no prefix 406 body = s[1:] 407 } else { 408 a := strings.SplitN(s, string(start), 2) 409 if len(a) != 2 { 410 return "", "", s, false 411 } 412 prefix = a[0] 413 body = a[1] 414 } 415 a := strings.SplitN(body, string(end), 2) 416 if len(a) != 2 { 417 return "", "", "", false 418 } 419 return prefix, a[0], a[1], true 420 } 421 422 // newFn parses string s and return created function Fn. 423 func newFn(s string) (*Fn, error) { 424 s = trim(s) 425 f := &Fn{ 426 Rets: &Rets{}, 427 src: s, 428 PrintTrace: *printTraceFlag, 429 } 430 // function name and args 431 prefix, body, s, found := extractSection(s, '(', ')') 432 if !found || prefix == "" { 433 return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") 434 } 435 f.Name = prefix 436 var err error 437 f.Params, err = extractParams(body, f) 438 if err != nil { 439 return nil, err 440 } 441 // return values 442 _, body, s, found = extractSection(s, '(', ')') 443 if found { 444 r, err := extractParams(body, f) 445 if err != nil { 446 return nil, err 447 } 448 switch len(r) { 449 case 0: 450 case 1: 451 if r[0].IsError() { 452 f.Rets.ReturnsError = true 453 } else { 454 f.Rets.Name = r[0].Name 455 f.Rets.Type = r[0].Type 456 } 457 case 2: 458 if !r[1].IsError() { 459 return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") 460 } 461 f.Rets.ReturnsError = true 462 f.Rets.Name = r[0].Name 463 f.Rets.Type = r[0].Type 464 default: 465 return nil, errors.New("Too many return values in \"" + f.src + "\"") 466 } 467 } 468 // fail condition 469 _, body, s, found = extractSection(s, '[', ']') 470 if found { 471 f.Rets.FailCond = body 472 } 473 // dll and dll function names 474 s = trim(s) 475 if s == "" { 476 return f, nil 477 } 478 if !strings.HasPrefix(s, "=") { 479 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") 480 } 481 s = trim(s[1:]) 482 if i := strings.LastIndex(s, "."); i >= 0 { 483 f.dllname = s[:i] 484 f.dllfuncname = s[i+1:] 485 } else { 486 f.dllfuncname = s 487 } 488 if f.dllfuncname == "" { 489 return nil, fmt.Errorf("function name is not specified in %q", s) 490 } 491 if n := f.dllfuncname; strings.HasSuffix(n, "?") { 492 f.dllfuncname = n[:len(n)-1] 493 f.Rets.fnMaybeAbsent = true 494 } 495 return f, nil 496 } 497 498 // DLLName returns DLL name for function f. 499 func (f *Fn) DLLName() string { 500 if f.dllname == "" { 501 return "kernel32" 502 } 503 return f.dllname 504 } 505 506 // DLLVar returns a valid Go identifier that represents DLLName. 507 func (f *Fn) DLLVar() string { 508 id := strings.Map(func(r rune) rune { 509 switch r { 510 case '.', '-': 511 return '_' 512 default: 513 return r 514 } 515 }, f.DLLName()) 516 if !token.IsIdentifier(id) { 517 panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName())) 518 } 519 return id 520 } 521 522 // DLLFuncName returns DLL function name for function f. 523 func (f *Fn) DLLFuncName() string { 524 if f.dllfuncname == "" { 525 return f.Name 526 } 527 return f.dllfuncname 528 } 529 530 // ParamList returns source code for function f parameters. 531 func (f *Fn) ParamList() string { 532 return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") 533 } 534 535 // HelperParamList returns source code for helper function f parameters. 536 func (f *Fn) HelperParamList() string { 537 return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") 538 } 539 540 // ParamPrintList returns source code of trace printing part correspondent 541 // to syscall input parameters. 542 func (f *Fn) ParamPrintList() string { 543 return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 544 } 545 546 // ParamCount return number of syscall parameters for function f. 547 func (f *Fn) ParamCount() int { 548 n := 0 549 for _, p := range f.Params { 550 n += len(p.SyscallArgList()) 551 } 552 return n 553 } 554 555 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... 556 // to use. It returns parameter count for correspondent SyscallX function. 557 func (f *Fn) SyscallParamCount() int { 558 n := f.ParamCount() 559 switch { 560 case n <= 3: 561 return 3 562 case n <= 6: 563 return 6 564 case n <= 9: 565 return 9 566 case n <= 12: 567 return 12 568 case n <= 15: 569 return 15 570 case n <= 42: // current SyscallN limit 571 return n 572 default: 573 panic("too many arguments to system call") 574 } 575 } 576 577 // Syscall determines which SyscallX function to use for function f. 578 func (f *Fn) Syscall() string { 579 c := f.SyscallParamCount() 580 if c == 3 { 581 return syscalldot() + "Syscall" 582 } 583 if c > 15 { 584 return syscalldot() + "SyscallN" 585 } 586 return syscalldot() + "Syscall" + strconv.Itoa(c) 587 } 588 589 // SyscallParamList returns source code for SyscallX parameters for function f. 590 func (f *Fn) SyscallParamList() string { 591 a := make([]string, 0) 592 for _, p := range f.Params { 593 a = append(a, p.SyscallArgList()...) 594 } 595 for len(a) < f.SyscallParamCount() { 596 a = append(a, "0") 597 } 598 return strings.Join(a, ", ") 599 } 600 601 // HelperCallParamList returns source code of call into function f helper. 602 func (f *Fn) HelperCallParamList() string { 603 a := make([]string, 0, len(f.Params)) 604 for _, p := range f.Params { 605 s := p.Name 606 if p.Type == "string" { 607 s = p.tmpVar() 608 } 609 a = append(a, s) 610 } 611 return strings.Join(a, ", ") 612 } 613 614 // MaybeAbsent returns source code for handling functions that are possibly unavailable. 615 func (p *Fn) MaybeAbsent() string { 616 if !p.Rets.fnMaybeAbsent { 617 return "" 618 } 619 const code = `%[1]s = proc%[2]s.Find() 620 if %[1]s != nil { 621 return 622 }` 623 errorVar := p.Rets.ErrorVarName() 624 if errorVar == "" { 625 errorVar = "err" 626 } 627 return fmt.Sprintf(code, errorVar, p.DLLFuncName()) 628 } 629 630 // IsUTF16 is true, if f is W (utf16) function. It is false 631 // for all A (ascii) functions. 632 func (f *Fn) IsUTF16() bool { 633 s := f.DLLFuncName() 634 return s[len(s)-1] == 'W' 635 } 636 637 // StrconvFunc returns name of Go string to OS string function for f. 638 func (f *Fn) StrconvFunc() string { 639 if f.IsUTF16() { 640 return syscalldot() + "UTF16PtrFromString" 641 } 642 return syscalldot() + "BytePtrFromString" 643 } 644 645 // StrconvType returns Go type name used for OS string for f. 646 func (f *Fn) StrconvType() string { 647 if f.IsUTF16() { 648 return "*uint16" 649 } 650 return "*byte" 651 } 652 653 // HasStringParam is true, if f has at least one string parameter. 654 // Otherwise it is false. 655 func (f *Fn) HasStringParam() bool { 656 for _, p := range f.Params { 657 if p.Type == "string" { 658 return true 659 } 660 } 661 return false 662 } 663 664 // HelperName returns name of function f helper. 665 func (f *Fn) HelperName() string { 666 if !f.HasStringParam() { 667 return f.Name 668 } 669 return "_" + f.Name 670 } 671 672 // DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when 673 // naming a variable that refers to the DLL. 674 type DLL struct { 675 Name string 676 Var string 677 } 678 679 // Source files and functions. 680 type Source struct { 681 Funcs []*Fn 682 DLLFuncNames []*Fn 683 Files []string 684 StdLibImports []string 685 ExternalImports []string 686 } 687 688 func (src *Source) Import(pkg string) { 689 src.StdLibImports = append(src.StdLibImports, pkg) 690 sort.Strings(src.StdLibImports) 691 } 692 693 func (src *Source) ExternalImport(pkg string) { 694 src.ExternalImports = append(src.ExternalImports, pkg) 695 sort.Strings(src.ExternalImports) 696 } 697 698 // ParseFiles parses files listed in fs and extracts all syscall 699 // functions listed in sys comments. It returns source files 700 // and functions collection *Source if successful. 701 func ParseFiles(fs []string) (*Source, error) { 702 src := &Source{ 703 Funcs: make([]*Fn, 0), 704 Files: make([]string, 0), 705 StdLibImports: []string{ 706 "unsafe", 707 }, 708 ExternalImports: make([]string, 0), 709 } 710 for _, file := range fs { 711 if err := src.ParseFile(file); err != nil { 712 return nil, err 713 } 714 } 715 src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs)) 716 uniq := make(map[string]bool, len(src.Funcs)) 717 for _, fn := range src.Funcs { 718 name := fn.DLLFuncName() 719 if !uniq[name] { 720 src.DLLFuncNames = append(src.DLLFuncNames, fn) 721 uniq[name] = true 722 } 723 } 724 return src, nil 725 } 726 727 // DLLs return dll names for a source set src. 728 func (src *Source) DLLs() []DLL { 729 uniq := make(map[string]bool) 730 r := make([]DLL, 0) 731 for _, f := range src.Funcs { 732 id := f.DLLVar() 733 if _, found := uniq[id]; !found { 734 uniq[id] = true 735 r = append(r, DLL{f.DLLName(), id}) 736 } 737 } 738 sort.Slice(r, func(i, j int) bool { 739 return r[i].Var < r[j].Var 740 }) 741 return r 742 } 743 744 // ParseFile adds additional file path to a source set src. 745 func (src *Source) ParseFile(path string) error { 746 file, err := os.Open(path) 747 if err != nil { 748 return err 749 } 750 defer file.Close() 751 752 s := bufio.NewScanner(file) 753 for s.Scan() { 754 t := trim(s.Text()) 755 if len(t) < 7 { 756 continue 757 } 758 if !strings.HasPrefix(t, "//sys") { 759 continue 760 } 761 t = t[5:] 762 if !(t[0] == ' ' || t[0] == '\t') { 763 continue 764 } 765 f, err := newFn(t[1:]) 766 if err != nil { 767 return err 768 } 769 src.Funcs = append(src.Funcs, f) 770 } 771 if err := s.Err(); err != nil { 772 return err 773 } 774 src.Files = append(src.Files, path) 775 sort.Slice(src.Funcs, func(i, j int) bool { 776 fi, fj := src.Funcs[i], src.Funcs[j] 777 if fi.DLLName() == fj.DLLName() { 778 return fi.DLLFuncName() < fj.DLLFuncName() 779 } 780 return fi.DLLName() < fj.DLLName() 781 }) 782 783 // get package name 784 fset := token.NewFileSet() 785 _, err = file.Seek(0, 0) 786 if err != nil { 787 return err 788 } 789 pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) 790 if err != nil { 791 return err 792 } 793 packageName = pkg.Name.Name 794 795 return nil 796 } 797 798 // IsStdRepo reports whether src is part of standard library. 799 func (src *Source) IsStdRepo() (bool, error) { 800 if len(src.Files) == 0 { 801 return false, errors.New("no input files provided") 802 } 803 abspath, err := filepath.Abs(src.Files[0]) 804 if err != nil { 805 return false, err 806 } 807 goroot := runtime.GOROOT() 808 if runtime.GOOS == "windows" { 809 abspath = strings.ToLower(abspath) 810 goroot = strings.ToLower(goroot) 811 } 812 sep := string(os.PathSeparator) 813 if !strings.HasSuffix(goroot, sep) { 814 goroot += sep 815 } 816 return strings.HasPrefix(abspath, goroot), nil 817 } 818 819 // Generate output source file from a source set src. 820 func (src *Source) Generate(w io.Writer) error { 821 const ( 822 pkgStd = iota // any package in std library 823 pkgXSysWindows // x/sys/windows package 824 pkgOther 825 ) 826 isStdRepo, err := src.IsStdRepo() 827 if err != nil { 828 return err 829 } 830 var pkgtype int 831 switch { 832 case isStdRepo: 833 pkgtype = pkgStd 834 case packageName == "windows": 835 // TODO: this needs better logic than just using package name 836 pkgtype = pkgXSysWindows 837 default: 838 pkgtype = pkgOther 839 } 840 if *systemDLL { 841 switch pkgtype { 842 case pkgStd: 843 src.Import("internal/syscall/windows/sysdll") 844 case pkgXSysWindows: 845 default: 846 src.ExternalImport("golang.org/x/sys/windows") 847 } 848 } 849 if packageName != "syscall" { 850 src.Import("syscall") 851 } 852 funcMap := template.FuncMap{ 853 "packagename": packagename, 854 "syscalldot": syscalldot, 855 "newlazydll": func(dll string) string { 856 arg := "\"" + dll + ".dll\"" 857 if !*systemDLL { 858 return syscalldot() + "NewLazyDLL(" + arg + ")" 859 } 860 switch pkgtype { 861 case pkgStd: 862 return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" 863 case pkgXSysWindows: 864 return "NewLazySystemDLL(" + arg + ")" 865 default: 866 return "windows.NewLazySystemDLL(" + arg + ")" 867 } 868 }, 869 } 870 t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) 871 err = t.Execute(w, src) 872 if err != nil { 873 return errors.New("Failed to execute template: " + err.Error()) 874 } 875 return nil 876 } 877 878 func writeTempSourceFile(data []byte) (string, error) { 879 f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go") 880 if err != nil { 881 return "", err 882 } 883 _, err = f.Write(data) 884 if closeErr := f.Close(); err == nil { 885 err = closeErr 886 } 887 if err != nil { 888 os.Remove(f.Name()) // best effort 889 return "", err 890 } 891 return f.Name(), nil 892 } 893 894 func usage() { 895 fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n") 896 flag.PrintDefaults() 897 os.Exit(1) 898 } 899 900 func main() { 901 flag.Usage = usage 902 flag.Parse() 903 if len(flag.Args()) <= 0 { 904 fmt.Fprintf(os.Stderr, "no files to parse provided\n") 905 usage() 906 } 907 908 src, err := ParseFiles(flag.Args()) 909 if err != nil { 910 log.Fatal(err) 911 } 912 913 var buf bytes.Buffer 914 if err := src.Generate(&buf); err != nil { 915 log.Fatal(err) 916 } 917 918 data, err := format.Source(buf.Bytes()) 919 if err != nil { 920 log.Printf("failed to format source: %v", err) 921 f, err := writeTempSourceFile(buf.Bytes()) 922 if err != nil { 923 log.Fatalf("failed to write unformatted source to file: %v", err) 924 } 925 log.Fatalf("for diagnosis, wrote unformatted source to %v", f) 926 } 927 if *filename == "" { 928 _, err = os.Stdout.Write(data) 929 } else { 930 err = os.WriteFile(*filename, data, 0644) 931 } 932 if err != nil { 933 log.Fatal(err) 934 } 935 } 936 937 // TODO: use println instead to print in the following template 938 const srcTemplate = ` 939 940 {{define "main"}}// Code generated by 'go generate'; DO NOT EDIT. 941 942 package {{packagename}} 943 944 import ( 945 {{range .StdLibImports}}"{{.}}" 946 {{end}} 947 948 {{range .ExternalImports}}"{{.}}" 949 {{end}} 950 ) 951 952 var _ unsafe.Pointer 953 954 // Do the interface allocations only once for common 955 // Errno values. 956 const ( 957 errnoERROR_IO_PENDING = 997 958 ) 959 960 var ( 961 errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) 962 errERROR_EINVAL error = {{syscalldot}}EINVAL 963 ) 964 965 // errnoErr returns common boxed Errno values, to prevent 966 // allocations at runtime. 967 func errnoErr(e {{syscalldot}}Errno) error { 968 switch e { 969 case 0: 970 return errERROR_EINVAL 971 case errnoERROR_IO_PENDING: 972 return errERROR_IO_PENDING 973 } 974 // TODO: add more here, after collecting data on the common 975 // error values see on Windows. (perhaps when running 976 // all.bat?) 977 return e 978 } 979 980 var ( 981 {{template "dlls" .}} 982 {{template "funcnames" .}}) 983 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} 984 {{end}} 985 986 {{/* help functions */}} 987 988 {{define "dlls"}}{{range .DLLs}} mod{{.Var}} = {{newlazydll .Name}} 989 {{end}}{{end}} 990 991 {{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}") 992 {{end}}{{end}} 993 994 {{define "helperbody"}} 995 func {{.Name}}({{.ParamList}}) {{template "results" .}}{ 996 {{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) 997 } 998 {{end}} 999 1000 {{define "funcbody"}} 1001 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ 1002 {{template "maybeabsent" .}} {{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvarsreadback" .}} 1003 {{template "seterror" .}}{{template "printtrace" .}} return 1004 } 1005 {{end}} 1006 1007 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} 1008 {{end}}{{end}}{{end}} 1009 1010 {{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}} 1011 {{end}}{{end}} 1012 1013 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} 1014 {{end}}{{end}}{{end}} 1015 1016 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} 1017 1018 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(),{{if le .ParamCount 15}} {{.ParamCount}},{{end}} {{.SyscallParamList}}){{end}} 1019 1020 {{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}} 1021 {{.TmpVarReadbackCode}}{{end}}{{end}}{{end}} 1022 1023 {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} 1024 {{end}}{{end}} 1025 1026 {{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") 1027 {{end}}{{end}} 1028 1029 `