github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/tools/mkwinsyscall/mkwinsyscall.go (about) 1 //go:build windows 2 3 // Copyright 2013 The Go Authors. All rights reserved. 4 // Use of this source code is governed by a BSD-style 5 // license that can be found in the LICENSE file. 6 7 package main 8 9 import ( 10 "bufio" 11 "bytes" 12 "errors" 13 "flag" 14 "fmt" 15 "go/format" 16 "go/parser" 17 "go/token" 18 "io" 19 "log" 20 "os" 21 "path/filepath" 22 "runtime" 23 "sort" 24 "strconv" 25 "strings" 26 "text/template" 27 28 "golang.org/x/sys/windows" 29 ) 30 31 const ( 32 pkgSyscall = "syscall" 33 pkgWindows = "windows" 34 35 // common types. 36 37 tBool = "bool" 38 tBoolPtr = "*bool" 39 tError = "error" 40 tString = "string" 41 42 // error variable names. 43 44 varErr = "err" 45 varErrNTStatus = "ntStatus" 46 varErrHR = "hr" 47 ) 48 49 var ( 50 filename = flag.String("output", "", "output file name (standard output if omitted)") 51 printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") 52 systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") 53 winio = flag.Bool("winio", false, `import this package ("github.com/Serizao/go-winio")`) 54 utf16 = flag.Bool("utf16", true, "encode string arguments as UTF-16 for syscalls not ending in 'A' or 'W'") 55 sortdecls = flag.Bool("sort", true, "sort DLL and function declarations") 56 ) 57 58 func trim(s string) string { 59 return strings.Trim(s, " \t") 60 } 61 62 func endsIn(s string, c byte) bool { 63 return len(s) >= 1 && s[len(s)-1] == c 64 } 65 66 var packageName string 67 68 func packagename() string { 69 return packageName 70 } 71 72 func windowsdot() string { 73 if packageName == pkgWindows { 74 return "" 75 } 76 return pkgWindows + "." 77 } 78 79 func syscalldot() string { 80 if packageName == pkgSyscall { 81 return "" 82 } 83 return pkgSyscall + "." 84 } 85 86 // Param is function parameter. 87 type Param struct { 88 Name string 89 Type string 90 fn *Fn 91 tmpVarIdx int 92 } 93 94 // tmpVar returns temp variable name that will be used to represent p during syscall. 95 func (p *Param) tmpVar() string { 96 if p.tmpVarIdx < 0 { 97 p.tmpVarIdx = p.fn.curTmpVarIdx 98 p.fn.curTmpVarIdx++ 99 } 100 return fmt.Sprintf("_p%d", p.tmpVarIdx) 101 } 102 103 // BoolTmpVarCode returns source code for bool temp variable. 104 func (p *Param) BoolTmpVarCode() string { 105 const code = `var %[1]s uint32 106 if %[2]s { 107 %[1]s = 1 108 }` 109 return fmt.Sprintf(code, p.tmpVar(), p.Name) 110 } 111 112 // BoolPointerTmpVarCode returns source code for bool temp variable. 113 func (p *Param) BoolPointerTmpVarCode() string { 114 const code = `var %[1]s uint32 115 if *%[2]s { 116 %[1]s = 1 117 }` 118 return fmt.Sprintf(code, p.tmpVar(), p.Name) 119 } 120 121 // SliceTmpVarCode returns source code for slice temp variable. 122 func (p *Param) SliceTmpVarCode() string { 123 const code = `var %s *%s 124 if len(%s) > 0 { 125 %s = &%s[0] 126 }` 127 tmp := p.tmpVar() 128 return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) 129 } 130 131 // StringTmpVarCode returns source code for string temp variable. 132 func (p *Param) StringTmpVarCode() string { 133 errvar := p.fn.Rets.ErrorVarName() 134 if errvar == "" { 135 errvar = "_" 136 } 137 tmp := p.tmpVar() 138 const code = `var %s %s 139 %s, %s = %s(%s)` 140 s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) 141 if errvar == "-" { 142 return s 143 } 144 const morecode = ` 145 if %s != nil { 146 return 147 }` 148 return s + fmt.Sprintf(morecode, errvar) 149 } 150 151 // TmpVarCode returns source code for temp variable. 152 func (p *Param) TmpVarCode() string { 153 switch { 154 case p.Type == tBool: 155 return p.BoolTmpVarCode() 156 case p.Type == tBoolPtr: 157 return p.BoolPointerTmpVarCode() 158 case strings.HasPrefix(p.Type, "[]"): 159 return p.SliceTmpVarCode() 160 default: 161 return "" 162 } 163 } 164 165 // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. 166 func (p *Param) TmpVarReadbackCode() string { 167 switch { 168 case p.Type == tBoolPtr: 169 return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) 170 default: 171 return "" 172 } 173 } 174 175 // TmpVarHelperCode returns source code for helper's temp variable. 176 func (p *Param) TmpVarHelperCode() string { 177 if p.Type != "string" { 178 return "" 179 } 180 return p.StringTmpVarCode() 181 } 182 183 // SyscallArgList returns source code fragments representing p parameter 184 // in syscall. Slices are translated into 2 syscall parameters: pointer to 185 // the first element and length. 186 func (p *Param) SyscallArgList() []string { 187 t := p.HelperType() 188 var s string 189 switch { 190 case t == tBoolPtr: 191 s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) 192 case t[0] == '*': 193 s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) 194 case t == tBool: 195 s = p.tmpVar() 196 case strings.HasPrefix(t, "[]"): 197 return []string{ 198 fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), 199 fmt.Sprintf("uintptr(len(%s))", p.Name), 200 } 201 default: 202 s = p.Name 203 } 204 return []string{fmt.Sprintf("uintptr(%s)", s)} 205 } 206 207 // IsError determines if p parameter is used to return error. 208 func (p *Param) IsError() bool { 209 return p.Name == varErr && p.Type == tError 210 } 211 212 // HelperType returns type of parameter p used in helper function. 213 func (p *Param) HelperType() string { 214 if p.Type == tString { 215 return p.fn.StrconvType() 216 } 217 return p.Type 218 } 219 220 // join concatenates parameters ps into a string with sep separator. 221 // Each parameter is converted into string by applying fn to it 222 // before conversion. 223 func join(ps []*Param, fn func(*Param) string, sep string) string { 224 if len(ps) == 0 { 225 return "" 226 } 227 a := make([]string, 0) 228 for _, p := range ps { 229 a = append(a, fn(p)) 230 } 231 return strings.Join(a, sep) 232 } 233 234 // Rets describes function return parameters. 235 type Rets struct { 236 Name string 237 Type string 238 ReturnsError bool 239 FailCond string 240 fnMaybeAbsent bool 241 } 242 243 // ErrorVarName returns error variable name for r. 244 func (r *Rets) ErrorVarName() string { 245 if r.ReturnsError { 246 return varErr 247 } 248 if r.Type == tError { 249 return r.Name 250 } 251 return "" 252 } 253 254 // ToParams converts r into slice of *Param. 255 func (r *Rets) ToParams() []*Param { 256 ps := make([]*Param, 0) 257 if len(r.Name) > 0 { 258 ps = append(ps, &Param{Name: r.Name, Type: r.Type}) 259 } 260 if r.ReturnsError { 261 ps = append(ps, &Param{Name: varErr, Type: tError}) 262 } 263 return ps 264 } 265 266 // List returns source code of syscall return parameters. 267 func (r *Rets) List() string { 268 s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") 269 if len(s) > 0 { 270 s = "(" + s + ")" 271 } else if r.fnMaybeAbsent { 272 s = "(err error)" 273 } 274 return s 275 } 276 277 // PrintList returns source code of trace printing part correspondent 278 // to syscall return values. 279 func (r *Rets) PrintList() string { 280 return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 281 } 282 283 // SetReturnValuesCode returns source code that accepts syscall return values. 284 func (r *Rets) SetReturnValuesCode() string { 285 if r.Name == "" && !r.ReturnsError { 286 return "" 287 } 288 retvar := "r0" 289 if r.Name == "" { 290 retvar = "r1" 291 } 292 errvar := "_" 293 if r.ReturnsError { 294 errvar = "e1" 295 } 296 return fmt.Sprintf("%s, _, %s := ", retvar, errvar) 297 } 298 299 func (r *Rets) useLongHandleErrorCode(retvar string) string { 300 const code = `if %s { 301 err = errnoErr(e1) 302 }` 303 cond := retvar + " == 0" 304 if r.FailCond != "" { 305 cond = strings.Replace(r.FailCond, "failretval", retvar, 1) 306 } 307 return fmt.Sprintf(code, cond) 308 } 309 310 // SetErrorCode returns source code that sets return parameters. 311 func (r *Rets) SetErrorCode() string { 312 const code = `if r0 != 0 { 313 %s = %sErrno(r0) 314 }` 315 const ntStatus = `if r0 != 0 { 316 %s = %sNTStatus(r0) 317 }` 318 const hrCode = `if int32(r0) < 0 { 319 if r0&0x1fff0000 == 0x00070000 { 320 r0 &= 0xffff 321 } 322 %s = %sErrno(r0) 323 }` 324 325 if r.Name == "" && !r.ReturnsError { 326 return "" 327 } 328 if r.Name == "" { 329 return r.useLongHandleErrorCode("r1") 330 } 331 if r.Type == tError { 332 switch r.Name { 333 case varErrNTStatus, strings.ToLower(varErrNTStatus): // allow ntstatus to work 334 return fmt.Sprintf(ntStatus, r.Name, windowsdot()) 335 case varErrHR: 336 return fmt.Sprintf(hrCode, r.Name, syscalldot()) 337 default: 338 return fmt.Sprintf(code, r.Name, syscalldot()) 339 } 340 } 341 342 var s string 343 switch { 344 case r.Type[0] == '*': 345 s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) 346 case r.Type == tBool: 347 s = fmt.Sprintf("%s = r0 != 0", r.Name) 348 default: 349 s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) 350 } 351 if !r.ReturnsError { 352 return s 353 } 354 return s + "\n\t" + r.useLongHandleErrorCode(r.Name) 355 } 356 357 // Fn describes syscall function. 358 type Fn struct { 359 Name string 360 Params []*Param 361 Rets *Rets 362 PrintTrace bool 363 dllname string 364 dllfuncname string 365 src string 366 // TODO: get rid of this field and just use parameter index instead 367 curTmpVarIdx int // insure tmp variables have uniq names 368 } 369 370 // extractParams parses s to extract function parameters. 371 func extractParams(s string, f *Fn) ([]*Param, error) { 372 s = trim(s) 373 if s == "" { 374 return nil, nil 375 } 376 a := strings.Split(s, ",") 377 ps := make([]*Param, len(a)) 378 for i := range ps { 379 s2 := trim(a[i]) 380 b := strings.Split(s2, " ") 381 if len(b) != 2 { 382 b = strings.Split(s2, "\t") 383 if len(b) != 2 { 384 return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") 385 } 386 } 387 ps[i] = &Param{ 388 Name: trim(b[0]), 389 Type: trim(b[1]), 390 fn: f, 391 tmpVarIdx: -1, 392 } 393 } 394 return ps, nil 395 } 396 397 // extractSection extracts text out of string s starting after start 398 // and ending just before end. found return value will indicate success, 399 // and prefix, body and suffix will contain correspondent parts of string s. 400 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { 401 s = trim(s) 402 if strings.HasPrefix(s, string(start)) { 403 // no prefix 404 body = s[1:] 405 } else { 406 a := strings.SplitN(s, string(start), 2) 407 if len(a) != 2 { 408 return "", "", s, false 409 } 410 prefix = a[0] 411 body = a[1] 412 } 413 a := strings.SplitN(body, string(end), 2) 414 if len(a) != 2 { 415 return "", "", "", false 416 } 417 return prefix, a[0], a[1], true 418 } 419 420 // newFn parses string s and return created function Fn. 421 func newFn(s string) (*Fn, error) { 422 s = trim(s) 423 f := &Fn{ 424 Rets: &Rets{}, 425 src: s, 426 PrintTrace: *printTraceFlag, 427 } 428 // function name and args 429 prefix, body, s, found := extractSection(s, '(', ')') 430 if !found || prefix == "" { 431 return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") 432 } 433 f.Name = prefix 434 var err error 435 f.Params, err = extractParams(body, f) 436 if err != nil { 437 return nil, err 438 } 439 // return values 440 _, body, s, found = extractSection(s, '(', ')') 441 if found { 442 r, err := extractParams(body, f) 443 if err != nil { 444 return nil, err 445 } 446 switch len(r) { 447 case 0: 448 case 1: 449 if r[0].IsError() { 450 f.Rets.ReturnsError = true 451 } else { 452 f.Rets.Name = r[0].Name 453 f.Rets.Type = r[0].Type 454 } 455 case 2: 456 if !r[1].IsError() { 457 return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") 458 } 459 f.Rets.ReturnsError = true 460 f.Rets.Name = r[0].Name 461 f.Rets.Type = r[0].Type 462 default: 463 return nil, errors.New("Too many return values in \"" + f.src + "\"") 464 } 465 } 466 // fail condition 467 _, body, s, found = extractSection(s, '[', ']') 468 if found { 469 f.Rets.FailCond = body 470 } 471 // dll and dll function names 472 s = trim(s) 473 if s == "" { 474 return f, nil 475 } 476 if !strings.HasPrefix(s, "=") { 477 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") 478 } 479 s = trim(s[1:]) 480 if i := strings.LastIndex(s, "."); i >= 0 { 481 f.dllname = s[:i] 482 f.dllfuncname = s[i+1:] 483 } else { 484 f.dllfuncname = s 485 } 486 if f.dllfuncname == "" { 487 return nil, fmt.Errorf("function name is not specified in %q", s) 488 } 489 if n := f.dllfuncname; endsIn(n, '?') { 490 f.dllfuncname = n[:len(n)-1] 491 f.Rets.fnMaybeAbsent = true 492 } 493 return f, nil 494 } 495 496 // DLLName returns DLL name for function f. 497 func (f *Fn) DLLName() string { 498 if f.dllname == "" { 499 return "kernel32" 500 } 501 return f.dllname 502 } 503 504 // DLLVar returns a valid Go identifier that represents DLLName. 505 func (f *Fn) DLLVar() string { 506 id := strings.Map(func(r rune) rune { 507 switch r { 508 case '.', '-': 509 return '_' 510 default: 511 return r 512 } 513 }, f.DLLName()) 514 if !token.IsIdentifier(id) { 515 panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName())) 516 } 517 return id 518 } 519 520 // DLLFuncName returns DLL function name for function f. 521 func (f *Fn) DLLFuncName() string { 522 if f.dllfuncname == "" { 523 return f.Name 524 } 525 return f.dllfuncname 526 } 527 528 // ParamList returns source code for function f parameters. 529 func (f *Fn) ParamList() string { 530 return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") 531 } 532 533 // HelperParamList returns source code for helper function f parameters. 534 func (f *Fn) HelperParamList() string { 535 return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") 536 } 537 538 // ParamPrintList returns source code of trace printing part correspondent 539 // to syscall input parameters. 540 func (f *Fn) ParamPrintList() string { 541 return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 542 } 543 544 // ParamCount return number of syscall parameters for function f. 545 func (f *Fn) ParamCount() int { 546 n := 0 547 for _, p := range f.Params { 548 n += len(p.SyscallArgList()) 549 } 550 return n 551 } 552 553 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... 554 // to use. It returns parameter count for correspondent SyscallX function. 555 func (f *Fn) SyscallParamCount() int { 556 n := f.ParamCount() 557 switch { 558 case n <= 3: 559 return 3 560 case n <= 6: 561 return 6 562 case n <= 9: 563 return 9 564 case n <= 12: 565 return 12 566 case n <= 15: 567 return 15 568 default: 569 panic("too many arguments to system call") 570 } 571 } 572 573 // Syscall determines which SyscallX function to use for function f. 574 func (f *Fn) Syscall() string { 575 c := f.SyscallParamCount() 576 if c == 3 { 577 return syscalldot() + "Syscall" 578 } 579 return syscalldot() + "Syscall" + strconv.Itoa(c) 580 } 581 582 // SyscallParamList returns source code for SyscallX parameters for function f. 583 func (f *Fn) SyscallParamList() string { 584 a := make([]string, 0) 585 for _, p := range f.Params { 586 a = append(a, p.SyscallArgList()...) 587 } 588 for len(a) < f.SyscallParamCount() { 589 a = append(a, "0") 590 } 591 return strings.Join(a, ", ") 592 } 593 594 // HelperCallParamList returns source code of call into function f helper. 595 func (f *Fn) HelperCallParamList() string { 596 a := make([]string, 0, len(f.Params)) 597 for _, p := range f.Params { 598 s := p.Name 599 if p.Type == tString { 600 s = p.tmpVar() 601 } 602 a = append(a, s) 603 } 604 return strings.Join(a, ", ") 605 } 606 607 // MaybeAbsent returns source code for handling functions that are possibly unavailable. 608 func (f *Fn) MaybeAbsent() string { 609 if !f.Rets.fnMaybeAbsent { 610 return "" 611 } 612 const code = `%[1]s = proc%[2]s.Find() 613 if %[1]s != nil { 614 return 615 }` 616 errorVar := f.Rets.ErrorVarName() 617 if errorVar == "" { 618 errorVar = varErr 619 } 620 return fmt.Sprintf(code, errorVar, f.DLLFuncName()) 621 } 622 623 // IsUTF16 is true, if f is W (UTF-16) function and false for all A (ASCII) functions. 624 // Functions ending in neither will default to UTF-16, unless the `-utf16` flag is set 625 // to `false`. 626 func (f *Fn) IsUTF16() bool { 627 s := f.DLLFuncName() 628 return endsIn(s, 'W') || (*utf16 && !endsIn(s, 'A')) 629 } 630 631 // StrconvFunc returns name of Go string to OS string function for f. 632 func (f *Fn) StrconvFunc() string { 633 if f.IsUTF16() { 634 return syscalldot() + "UTF16PtrFromString" 635 } 636 return syscalldot() + "BytePtrFromString" 637 } 638 639 // StrconvType returns Go type name used for OS string for f. 640 func (f *Fn) StrconvType() string { 641 if f.IsUTF16() { 642 return "*uint16" 643 } 644 return "*byte" 645 } 646 647 // HasStringParam is true, if f has at least one string parameter. 648 // Otherwise it is false. 649 func (f *Fn) HasStringParam() bool { 650 for _, p := range f.Params { 651 if p.Type == tString { 652 return true 653 } 654 } 655 return false 656 } 657 658 // HelperName returns name of function f helper. 659 func (f *Fn) HelperName() string { 660 if !f.HasStringParam() { 661 return f.Name 662 } 663 return "_" + f.Name 664 } 665 666 // DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when 667 // naming a variable that refers to the DLL. 668 type DLL struct { 669 Name string 670 Var string 671 } 672 673 // Source files and functions. 674 type Source struct { 675 Funcs []*Fn 676 DLLFuncNames []*Fn 677 Files []string 678 StdLibImports []string 679 ExternalImports []string 680 } 681 682 func (src *Source) Import(pkg string) { 683 src.StdLibImports = append(src.StdLibImports, pkg) 684 sort.Strings(src.StdLibImports) 685 } 686 687 func (src *Source) ExternalImport(pkg string) { 688 src.ExternalImports = append(src.ExternalImports, pkg) 689 sort.Strings(src.ExternalImports) 690 } 691 692 // ParseFiles parses files listed in fs and extracts all syscall 693 // functions listed in sys comments. It returns source files 694 // and functions collection *Source if successful. 695 func ParseFiles(fs []string) (*Source, error) { 696 src := &Source{ 697 Funcs: make([]*Fn, 0), 698 Files: make([]string, 0), 699 StdLibImports: []string{ 700 "unsafe", 701 }, 702 ExternalImports: make([]string, 0), 703 } 704 for _, file := range fs { 705 if err := src.ParseFile(file); err != nil { 706 return nil, err 707 } 708 } 709 src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs)) 710 uniq := make(map[string]bool, len(src.Funcs)) 711 for _, fn := range src.Funcs { 712 name := fn.DLLFuncName() 713 if !uniq[name] { 714 src.DLLFuncNames = append(src.DLLFuncNames, fn) 715 uniq[name] = true 716 } 717 } 718 return src, nil 719 } 720 721 // DLLs return dll names for a source set src. 722 func (src *Source) DLLs() []DLL { 723 uniq := make(map[string]bool) 724 r := make([]DLL, 0) 725 for _, f := range src.Funcs { 726 id := f.DLLVar() 727 if _, found := uniq[id]; !found { 728 uniq[id] = true 729 r = append(r, DLL{f.DLLName(), id}) 730 } 731 } 732 if *sortdecls { 733 sort.Slice(r, func(i, j int) bool { 734 return r[i].Var < r[j].Var 735 }) 736 } 737 return r 738 } 739 740 // ParseFile adds additional file (or files, if path is a glob pattern) path to a source set src. 741 func (src *Source) ParseFile(path string) error { 742 file, err := os.Open(path) 743 if err == nil { 744 defer file.Close() 745 return src.parseFile(file) 746 } else if !(errors.Is(err, os.ErrNotExist) || errors.Is(err, windows.ERROR_INVALID_NAME)) { 747 return err 748 } 749 750 paths, err := filepath.Glob(path) 751 if err != nil { 752 return err 753 } 754 755 for _, path := range paths { 756 file, err := os.Open(path) 757 if err != nil { 758 return err 759 } 760 err = src.parseFile(file) 761 file.Close() 762 if err != nil { 763 return err 764 } 765 } 766 767 return nil 768 } 769 770 func (src *Source) parseFile(file *os.File) error { 771 s := bufio.NewScanner(file) 772 for s.Scan() { 773 t := trim(s.Text()) 774 if len(t) < 7 { 775 continue 776 } 777 if !strings.HasPrefix(t, "//sys") { 778 continue 779 } 780 t = t[5:] 781 if !(t[0] == ' ' || t[0] == '\t') { 782 continue 783 } 784 f, err := newFn(t[1:]) 785 if err != nil { 786 return err 787 } 788 src.Funcs = append(src.Funcs, f) 789 } 790 if err := s.Err(); err != nil { 791 return err 792 } 793 src.Files = append(src.Files, file.Name()) 794 if *sortdecls { 795 sort.Slice(src.Funcs, func(i, j int) bool { 796 fi, fj := src.Funcs[i], src.Funcs[j] 797 if fi.DLLName() == fj.DLLName() { 798 return fi.DLLFuncName() < fj.DLLFuncName() 799 } 800 return fi.DLLName() < fj.DLLName() 801 }) 802 } 803 804 // get package name 805 fset := token.NewFileSet() 806 _, err := file.Seek(0, 0) 807 if err != nil { 808 return err 809 } 810 pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) 811 if err != nil { 812 return err 813 } 814 packageName = pkg.Name.Name 815 816 return nil 817 } 818 819 // IsStdRepo reports whether src is part of standard library. 820 func (src *Source) IsStdRepo() (bool, error) { 821 if len(src.Files) == 0 { 822 return false, errors.New("no input files provided") 823 } 824 abspath, err := filepath.Abs(src.Files[0]) 825 if err != nil { 826 return false, err 827 } 828 goroot := runtime.GOROOT() 829 if runtime.GOOS == "windows" { 830 abspath = strings.ToLower(abspath) 831 goroot = strings.ToLower(goroot) 832 } 833 sep := string(os.PathSeparator) 834 if !strings.HasSuffix(goroot, sep) { 835 goroot += sep 836 } 837 return strings.HasPrefix(abspath, goroot), nil 838 } 839 840 // Generate output source file from a source set src. 841 func (src *Source) Generate(w io.Writer) error { 842 const ( 843 pkgStd = iota // any package in std library 844 pkgXSysWindows // x/sys/windows package 845 pkgOther 846 ) 847 isStdRepo, err := src.IsStdRepo() 848 if err != nil { 849 return err 850 } 851 var pkgtype int 852 switch { 853 case isStdRepo: 854 pkgtype = pkgStd 855 case packageName == "windows": 856 // TODO: this needs better logic than just using package name 857 pkgtype = pkgXSysWindows 858 default: 859 pkgtype = pkgOther 860 } 861 if *systemDLL { 862 switch pkgtype { 863 case pkgStd: 864 src.Import("internal/syscall/windows/sysdll") 865 case pkgXSysWindows: 866 default: 867 src.ExternalImport("golang.org/x/sys/windows") 868 } 869 } 870 if *winio { 871 src.ExternalImport("github.com/Serizao/go-winio") 872 } 873 if packageName != "syscall" { 874 src.Import("syscall") 875 } 876 funcMap := template.FuncMap{ 877 "packagename": packagename, 878 "syscalldot": syscalldot, 879 "newlazydll": func(dll string) string { 880 arg := "\"" + dll + ".dll\"" 881 if !*systemDLL { 882 return syscalldot() + "NewLazyDLL(" + arg + ")" 883 } 884 if strings.HasPrefix(dll, "api_") || strings.HasPrefix(dll, "ext_") { 885 arg = strings.Replace(arg, "_", "-", -1) 886 } 887 switch pkgtype { 888 case pkgStd: 889 return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" 890 case pkgXSysWindows: 891 return "NewLazySystemDLL(" + arg + ")" 892 default: 893 return "windows.NewLazySystemDLL(" + arg + ")" 894 } 895 }, 896 } 897 t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) 898 err = t.Execute(w, src) 899 if err != nil { 900 return errors.New("Failed to execute template: " + err.Error()) 901 } 902 return nil 903 } 904 905 func writeTempSourceFile(data []byte) (string, error) { 906 f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go") 907 if err != nil { 908 return "", err 909 } 910 _, err = f.Write(data) 911 if closeErr := f.Close(); err == nil { 912 err = closeErr 913 } 914 if err != nil { 915 os.Remove(f.Name()) // best effort 916 return "", err 917 } 918 return f.Name(), nil 919 } 920 921 func usage() { 922 fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n") 923 flag.PrintDefaults() 924 os.Exit(1) 925 } 926 927 func main() { 928 flag.Usage = usage 929 flag.Parse() 930 if len(flag.Args()) <= 0 { 931 fmt.Fprintf(os.Stderr, "no files to parse provided\n") 932 usage() 933 } 934 935 src, err := ParseFiles(flag.Args()) 936 if err != nil { 937 log.Fatal(err) 938 } 939 940 var buf bytes.Buffer 941 if err := src.Generate(&buf); err != nil { 942 log.Fatal(err) 943 } 944 945 data, err := format.Source(buf.Bytes()) 946 if err != nil { 947 log.Printf("failed to format source: %v", err) 948 f, err := writeTempSourceFile(buf.Bytes()) 949 if err != nil { 950 log.Fatalf("failed to write unformatted source to file: %v", err) 951 } 952 log.Fatalf("for diagnosis, wrote unformatted source to %v", f) 953 } 954 if *filename == "" { 955 _, err = os.Stdout.Write(data) 956 } else { 957 //nolint:gosec // G306: code file, no need for wants 0600 958 err = os.WriteFile(*filename, data, 0644) 959 } 960 if err != nil { 961 log.Fatal(err) 962 } 963 } 964 965 // TODO: use println instead to print in the following template 966 967 const srcTemplate = ` 968 {{define "main"}} //go:build windows 969 970 // Code generated by 'go generate' using "github.com/Serizao/go-winio/tools/mkwinsyscall"; DO NOT EDIT. 971 972 package {{packagename}} 973 974 import ( 975 {{range .StdLibImports}}"{{.}}" 976 {{end}} 977 978 {{range .ExternalImports}}"{{.}}" 979 {{end}} 980 ) 981 982 var _ unsafe.Pointer 983 984 // Do the interface allocations only once for common 985 // Errno values. 986 const ( 987 errnoERROR_IO_PENDING = 997 988 ) 989 990 var ( 991 errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) 992 errERROR_EINVAL error = {{syscalldot}}EINVAL 993 ) 994 995 // errnoErr returns common boxed Errno values, to prevent 996 // allocations at runtime. 997 func errnoErr(e {{syscalldot}}Errno) error { 998 switch e { 999 case 0: 1000 return errERROR_EINVAL 1001 case errnoERROR_IO_PENDING: 1002 return errERROR_IO_PENDING 1003 } 1004 // TODO: add more here, after collecting data on the common 1005 // error values see on Windows. (perhaps when running 1006 // all.bat?) 1007 return e 1008 } 1009 1010 var ( 1011 {{template "dlls" .}} 1012 {{template "funcnames" .}}) 1013 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} 1014 {{end}} 1015 1016 {{/* help functions */}} 1017 1018 {{define "dlls"}}{{range .DLLs}} mod{{.Var}} = {{newlazydll .Name}} 1019 {{end}}{{end}} 1020 1021 {{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}") 1022 {{end}}{{end}} 1023 1024 {{define "helperbody"}} 1025 func {{.Name}}({{.ParamList}}) {{template "results" .}}{ 1026 {{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) 1027 } 1028 {{end}} 1029 1030 {{define "funcbody"}} 1031 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ 1032 {{template "maybeabsent" .}} {{template "tmpvars" .}} {{template "syscall" .}} {{template "tmpvarsreadback" .}} 1033 {{template "seterror" .}}{{template "printtrace" .}} return 1034 } 1035 {{end}} 1036 1037 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} 1038 {{end}}{{end}}{{end}} 1039 1040 {{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}} 1041 {{end}}{{end}} 1042 1043 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} 1044 {{end}}{{end}}{{end}} 1045 1046 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} 1047 1048 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} 1049 1050 {{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}} 1051 {{.TmpVarReadbackCode}}{{end}}{{end}}{{end}} 1052 1053 {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} 1054 {{end}}{{end}} 1055 1056 {{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") 1057 {{end}}{{end}} 1058 1059 `