golang.org/x/sys@v0.9.0/windows/mkwinsyscall/mkwinsyscall.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  /*
     6  mkwinsyscall generates windows system call bodies
     7  
     8  It parses all files specified on command line containing function
     9  prototypes (like syscall_windows.go) and prints system call bodies
    10  to standard output.
    11  
    12  The prototypes are marked by lines beginning with "//sys" and read
    13  like func declarations if //sys is replaced by func, but:
    14  
    15    - The parameter lists must give a name for each argument. This
    16      includes return parameters.
    17  
    18    - The parameter lists must give a type for each argument:
    19      the (x, y, z int) shorthand is not allowed.
    20  
    21    - If the return parameter is an error number, it must be named err.
    22  
    23    - If go func name needs to be different from its winapi dll name,
    24      the winapi name could be specified at the end, after "=" sign, like
    25      //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
    26  
  - Each function that returns err may supply a condition that the
    winapi return value is tested against to detect failure. When the
    condition holds, err is set to the Windows "last-error"; otherwise it
    is nil. The condition is given at the end of the //sys declaration, like
    //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
    and is [failretval==0] by default.
    33  
  - If the DLL function name (the part after "=") ends in a "?", a missing
    function is non-fatal: an error is returned instead of panicking.
    36  
    37  Usage:
    38  
    39  	mkwinsyscall [flags] [path ...]
    40  
The flags are:

	-output
		Specify output file name (outputs to console if blank).
	-systemdll
		Whether all DLLs should be loaded from the Windows system
		directory (true by default).
	-trace
		Generate print statement after every syscall.
    47  */
    48  package main
    49  
    50  import (
    51  	"bufio"
    52  	"bytes"
    53  	"errors"
    54  	"flag"
    55  	"fmt"
    56  	"go/format"
    57  	"go/parser"
    58  	"go/token"
    59  	"io"
    60  	"io/ioutil"
    61  	"log"
    62  	"os"
    63  	"path/filepath"
    64  	"runtime"
    65  	"sort"
    66  	"strconv"
    67  	"strings"
    68  	"text/template"
    69  )
    70  
    71  var (
    72  	filename       = flag.String("output", "", "output file name (standard output if omitted)")
    73  	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
    74  	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
    75  )
    76  
    77  func trim(s string) string {
    78  	return strings.Trim(s, " \t")
    79  }
    80  
    81  var packageName string
    82  
    83  func packagename() string {
    84  	return packageName
    85  }
    86  
    87  func windowsdot() string {
    88  	if packageName == "windows" {
    89  		return ""
    90  	}
    91  	return "windows."
    92  }
    93  
    94  func syscalldot() string {
    95  	if packageName == "syscall" {
    96  		return ""
    97  	}
    98  	return "syscall."
    99  }
   100  
// Param is a single parameter (or result) of a //sys declaration.
type Param struct {
	Name      string // name as written in the declaration
	Type      string // Go type as written in the declaration, e.g. "string" or "[]byte"
	fn        *Fn    // function the parameter belongs to; nil for values built by Rets.ToParams
	tmpVarIdx int    // index N of the _pN temp variable; negative until tmpVar assigns one
}
   108  
   109  // tmpVar returns temp variable name that will be used to represent p during syscall.
   110  func (p *Param) tmpVar() string {
   111  	if p.tmpVarIdx < 0 {
   112  		p.tmpVarIdx = p.fn.curTmpVarIdx
   113  		p.fn.curTmpVarIdx++
   114  	}
   115  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
   116  }
   117  
   118  // BoolTmpVarCode returns source code for bool temp variable.
   119  func (p *Param) BoolTmpVarCode() string {
   120  	const code = `var %[1]s uint32
   121  	if %[2]s {
   122  		%[1]s = 1
   123  	}`
   124  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   125  }
   126  
   127  // BoolPointerTmpVarCode returns source code for bool temp variable.
   128  func (p *Param) BoolPointerTmpVarCode() string {
   129  	const code = `var %[1]s uint32
   130  	if *%[2]s {
   131  		%[1]s = 1
   132  	}`
   133  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   134  }
   135  
   136  // SliceTmpVarCode returns source code for slice temp variable.
   137  func (p *Param) SliceTmpVarCode() string {
   138  	const code = `var %s *%s
   139  	if len(%s) > 0 {
   140  		%s = &%s[0]
   141  	}`
   142  	tmp := p.tmpVar()
   143  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   144  }
   145  
   146  // StringTmpVarCode returns source code for string temp variable.
   147  func (p *Param) StringTmpVarCode() string {
   148  	errvar := p.fn.Rets.ErrorVarName()
   149  	if errvar == "" {
   150  		errvar = "_"
   151  	}
   152  	tmp := p.tmpVar()
   153  	const code = `var %s %s
   154  	%s, %s = %s(%s)`
   155  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   156  	if errvar == "-" {
   157  		return s
   158  	}
   159  	const morecode = `
   160  	if %s != nil {
   161  		return
   162  	}`
   163  	return s + fmt.Sprintf(morecode, errvar)
   164  }
   165  
   166  // TmpVarCode returns source code for temp variable.
   167  func (p *Param) TmpVarCode() string {
   168  	switch {
   169  	case p.Type == "bool":
   170  		return p.BoolTmpVarCode()
   171  	case p.Type == "*bool":
   172  		return p.BoolPointerTmpVarCode()
   173  	case strings.HasPrefix(p.Type, "[]"):
   174  		return p.SliceTmpVarCode()
   175  	default:
   176  		return ""
   177  	}
   178  }
   179  
   180  // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable.
   181  func (p *Param) TmpVarReadbackCode() string {
   182  	switch {
   183  	case p.Type == "*bool":
   184  		return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar())
   185  	default:
   186  		return ""
   187  	}
   188  }
   189  
   190  // TmpVarHelperCode returns source code for helper's temp variable.
   191  func (p *Param) TmpVarHelperCode() string {
   192  	if p.Type != "string" {
   193  		return ""
   194  	}
   195  	return p.StringTmpVarCode()
   196  }
   197  
   198  // SyscallArgList returns source code fragments representing p parameter
   199  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   200  // the first element and length.
   201  func (p *Param) SyscallArgList() []string {
   202  	t := p.HelperType()
   203  	var s string
   204  	switch {
   205  	case t == "*bool":
   206  		s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar())
   207  	case t[0] == '*':
   208  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   209  	case t == "bool":
   210  		s = p.tmpVar()
   211  	case strings.HasPrefix(t, "[]"):
   212  		return []string{
   213  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   214  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   215  		}
   216  	default:
   217  		s = p.Name
   218  	}
   219  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   220  }
   221  
   222  // IsError determines if p parameter is used to return error.
   223  func (p *Param) IsError() bool {
   224  	return p.Name == "err" && p.Type == "error"
   225  }
   226  
   227  // HelperType returns type of parameter p used in helper function.
   228  func (p *Param) HelperType() string {
   229  	if p.Type == "string" {
   230  		return p.fn.StrconvType()
   231  	}
   232  	return p.Type
   233  }
   234  
   235  // join concatenates parameters ps into a string with sep separator.
   236  // Each parameter is converted into string by applying fn to it
   237  // before conversion.
   238  func join(ps []*Param, fn func(*Param) string, sep string) string {
   239  	if len(ps) == 0 {
   240  		return ""
   241  	}
   242  	a := make([]string, 0)
   243  	for _, p := range ps {
   244  		a = append(a, fn(p))
   245  	}
   246  	return strings.Join(a, sep)
   247  }
   248  
// Rets describes the results of a //sys declaration.
type Rets struct {
	Name          string // name of the result other than "err error"; "" if there is none
	Type          string // Go type of that result
	ReturnsError  bool   // the declaration has an "err error" result
	FailCond      string // failure test from "[...]", with failretval as placeholder; "" means "== 0"
	fnMaybeAbsent bool   // DLL function name ended in "?": a missing function is returned as an error
}
   257  
   258  // ErrorVarName returns error variable name for r.
   259  func (r *Rets) ErrorVarName() string {
   260  	if r.ReturnsError {
   261  		return "err"
   262  	}
   263  	if r.Type == "error" {
   264  		return r.Name
   265  	}
   266  	return ""
   267  }
   268  
   269  // ToParams converts r into slice of *Param.
   270  func (r *Rets) ToParams() []*Param {
   271  	ps := make([]*Param, 0)
   272  	if len(r.Name) > 0 {
   273  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   274  	}
   275  	if r.ReturnsError {
   276  		ps = append(ps, &Param{Name: "err", Type: "error"})
   277  	}
   278  	return ps
   279  }
   280  
   281  // List returns source code of syscall return parameters.
   282  func (r *Rets) List() string {
   283  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   284  	if len(s) > 0 {
   285  		s = "(" + s + ")"
   286  	} else if r.fnMaybeAbsent {
   287  		s = "(err error)"
   288  	}
   289  	return s
   290  }
   291  
   292  // PrintList returns source code of trace printing part correspondent
   293  // to syscall return values.
   294  func (r *Rets) PrintList() string {
   295  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   296  }
   297  
   298  // SetReturnValuesCode returns source code that accepts syscall return values.
   299  func (r *Rets) SetReturnValuesCode() string {
   300  	if r.Name == "" && !r.ReturnsError {
   301  		return ""
   302  	}
   303  	retvar := "r0"
   304  	if r.Name == "" {
   305  		retvar = "r1"
   306  	}
   307  	errvar := "_"
   308  	if r.ReturnsError {
   309  		errvar = "e1"
   310  	}
   311  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   312  }
   313  
   314  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   315  	const code = `if %s {
   316  		err = errnoErr(e1)
   317  	}`
   318  	cond := retvar + " == 0"
   319  	if r.FailCond != "" {
   320  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   321  	}
   322  	return fmt.Sprintf(code, cond)
   323  }
   324  
   325  // SetErrorCode returns source code that sets return parameters.
   326  func (r *Rets) SetErrorCode() string {
   327  	const code = `if r0 != 0 {
   328  		%s = %sErrno(r0)
   329  	}`
   330  	const ntstatus = `if r0 != 0 {
   331  		ntstatus = %sNTStatus(r0)
   332  	}`
   333  	if r.Name == "" && !r.ReturnsError {
   334  		return ""
   335  	}
   336  	if r.Name == "" {
   337  		return r.useLongHandleErrorCode("r1")
   338  	}
   339  	if r.Type == "error" && r.Name == "ntstatus" {
   340  		return fmt.Sprintf(ntstatus, windowsdot())
   341  	}
   342  	if r.Type == "error" {
   343  		return fmt.Sprintf(code, r.Name, syscalldot())
   344  	}
   345  	s := ""
   346  	switch {
   347  	case r.Type[0] == '*':
   348  		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
   349  	case r.Type == "bool":
   350  		s = fmt.Sprintf("%s = r0 != 0", r.Name)
   351  	default:
   352  		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
   353  	}
   354  	if !r.ReturnsError {
   355  		return s
   356  	}
   357  	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
   358  }
   359  
// Fn describes a syscall function parsed from one //sys declaration.
type Fn struct {
	Name        string   // Go name of the function
	Params      []*Param // input parameters
	Rets        *Rets    // results
	PrintTrace  bool     // emit a trace print after the call (the -trace flag)
	dllname     string   // DLL named after "="; "" selects the default (see DLLName)
	dllfuncname string   // DLL function named after "="; "" means Name (see DLLFuncName)
	src         string   // declaration text, quoted in parse errors
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // next free temp variable index; ensures temp variables have unique names
}
   372  
   373  // extractParams parses s to extract function parameters.
   374  func extractParams(s string, f *Fn) ([]*Param, error) {
   375  	s = trim(s)
   376  	if s == "" {
   377  		return nil, nil
   378  	}
   379  	a := strings.Split(s, ",")
   380  	ps := make([]*Param, len(a))
   381  	for i := range ps {
   382  		s2 := trim(a[i])
   383  		b := strings.Split(s2, " ")
   384  		if len(b) != 2 {
   385  			b = strings.Split(s2, "\t")
   386  			if len(b) != 2 {
   387  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   388  			}
   389  		}
   390  		ps[i] = &Param{
   391  			Name:      trim(b[0]),
   392  			Type:      trim(b[1]),
   393  			fn:        f,
   394  			tmpVarIdx: -1,
   395  		}
   396  	}
   397  	return ps, nil
   398  }
   399  
   400  // extractSection extracts text out of string s starting after start
   401  // and ending just before end. found return value will indicate success,
   402  // and prefix, body and suffix will contain correspondent parts of string s.
   403  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   404  	s = trim(s)
   405  	if strings.HasPrefix(s, string(start)) {
   406  		// no prefix
   407  		body = s[1:]
   408  	} else {
   409  		a := strings.SplitN(s, string(start), 2)
   410  		if len(a) != 2 {
   411  			return "", "", s, false
   412  		}
   413  		prefix = a[0]
   414  		body = a[1]
   415  	}
   416  	a := strings.SplitN(body, string(end), 2)
   417  	if len(a) != 2 {
   418  		return "", "", "", false
   419  	}
   420  	return prefix, a[0], a[1], true
   421  }
   422  
   423  // newFn parses string s and return created function Fn.
   424  func newFn(s string) (*Fn, error) {
   425  	s = trim(s)
   426  	f := &Fn{
   427  		Rets:       &Rets{},
   428  		src:        s,
   429  		PrintTrace: *printTraceFlag,
   430  	}
   431  	// function name and args
   432  	prefix, body, s, found := extractSection(s, '(', ')')
   433  	if !found || prefix == "" {
   434  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   435  	}
   436  	f.Name = prefix
   437  	var err error
   438  	f.Params, err = extractParams(body, f)
   439  	if err != nil {
   440  		return nil, err
   441  	}
   442  	// return values
   443  	_, body, s, found = extractSection(s, '(', ')')
   444  	if found {
   445  		r, err := extractParams(body, f)
   446  		if err != nil {
   447  			return nil, err
   448  		}
   449  		switch len(r) {
   450  		case 0:
   451  		case 1:
   452  			if r[0].IsError() {
   453  				f.Rets.ReturnsError = true
   454  			} else {
   455  				f.Rets.Name = r[0].Name
   456  				f.Rets.Type = r[0].Type
   457  			}
   458  		case 2:
   459  			if !r[1].IsError() {
   460  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   461  			}
   462  			f.Rets.ReturnsError = true
   463  			f.Rets.Name = r[0].Name
   464  			f.Rets.Type = r[0].Type
   465  		default:
   466  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   467  		}
   468  	}
   469  	// fail condition
   470  	_, body, s, found = extractSection(s, '[', ']')
   471  	if found {
   472  		f.Rets.FailCond = body
   473  	}
   474  	// dll and dll function names
   475  	s = trim(s)
   476  	if s == "" {
   477  		return f, nil
   478  	}
   479  	if !strings.HasPrefix(s, "=") {
   480  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   481  	}
   482  	s = trim(s[1:])
   483  	if i := strings.LastIndex(s, "."); i >= 0 {
   484  		f.dllname = s[:i]
   485  		f.dllfuncname = s[i+1:]
   486  	} else {
   487  		f.dllfuncname = s
   488  	}
   489  	if f.dllfuncname == "" {
   490  		return nil, fmt.Errorf("function name is not specified in %q", s)
   491  	}
   492  	if n := f.dllfuncname; strings.HasSuffix(n, "?") {
   493  		f.dllfuncname = n[:len(n)-1]
   494  		f.Rets.fnMaybeAbsent = true
   495  	}
   496  	return f, nil
   497  }
   498  
   499  // DLLName returns DLL name for function f.
   500  func (f *Fn) DLLName() string {
   501  	if f.dllname == "" {
   502  		return "kernel32"
   503  	}
   504  	return f.dllname
   505  }
   506  
   507  // DLLVar returns a valid Go identifier that represents DLLName.
   508  func (f *Fn) DLLVar() string {
   509  	id := strings.Map(func(r rune) rune {
   510  		switch r {
   511  		case '.', '-':
   512  			return '_'
   513  		default:
   514  			return r
   515  		}
   516  	}, f.DLLName())
   517  	if !token.IsIdentifier(id) {
   518  		panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName()))
   519  	}
   520  	return id
   521  }
   522  
   523  // DLLFuncName returns DLL function name for function f.
   524  func (f *Fn) DLLFuncName() string {
   525  	if f.dllfuncname == "" {
   526  		return f.Name
   527  	}
   528  	return f.dllfuncname
   529  }
   530  
   531  // ParamList returns source code for function f parameters.
   532  func (f *Fn) ParamList() string {
   533  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   534  }
   535  
   536  // HelperParamList returns source code for helper function f parameters.
   537  func (f *Fn) HelperParamList() string {
   538  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
   539  }
   540  
   541  // ParamPrintList returns source code of trace printing part correspondent
   542  // to syscall input parameters.
   543  func (f *Fn) ParamPrintList() string {
   544  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   545  }
   546  
   547  // ParamCount return number of syscall parameters for function f.
   548  func (f *Fn) ParamCount() int {
   549  	n := 0
   550  	for _, p := range f.Params {
   551  		n += len(p.SyscallArgList())
   552  	}
   553  	return n
   554  }
   555  
   556  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   557  // to use. It returns parameter count for correspondent SyscallX function.
   558  func (f *Fn) SyscallParamCount() int {
   559  	n := f.ParamCount()
   560  	switch {
   561  	case n <= 3:
   562  		return 3
   563  	case n <= 6:
   564  		return 6
   565  	case n <= 9:
   566  		return 9
   567  	case n <= 12:
   568  		return 12
   569  	case n <= 15:
   570  		return 15
   571  	default:
   572  		panic("too many arguments to system call")
   573  	}
   574  }
   575  
   576  // Syscall determines which SyscallX function to use for function f.
   577  func (f *Fn) Syscall() string {
   578  	c := f.SyscallParamCount()
   579  	if c == 3 {
   580  		return syscalldot() + "Syscall"
   581  	}
   582  	return syscalldot() + "Syscall" + strconv.Itoa(c)
   583  }
   584  
   585  // SyscallParamList returns source code for SyscallX parameters for function f.
   586  func (f *Fn) SyscallParamList() string {
   587  	a := make([]string, 0)
   588  	for _, p := range f.Params {
   589  		a = append(a, p.SyscallArgList()...)
   590  	}
   591  	for len(a) < f.SyscallParamCount() {
   592  		a = append(a, "0")
   593  	}
   594  	return strings.Join(a, ", ")
   595  }
   596  
   597  // HelperCallParamList returns source code of call into function f helper.
   598  func (f *Fn) HelperCallParamList() string {
   599  	a := make([]string, 0, len(f.Params))
   600  	for _, p := range f.Params {
   601  		s := p.Name
   602  		if p.Type == "string" {
   603  			s = p.tmpVar()
   604  		}
   605  		a = append(a, s)
   606  	}
   607  	return strings.Join(a, ", ")
   608  }
   609  
   610  // MaybeAbsent returns source code for handling functions that are possibly unavailable.
   611  func (p *Fn) MaybeAbsent() string {
   612  	if !p.Rets.fnMaybeAbsent {
   613  		return ""
   614  	}
   615  	const code = `%[1]s = proc%[2]s.Find()
   616  	if %[1]s != nil {
   617  		return
   618  	}`
   619  	errorVar := p.Rets.ErrorVarName()
   620  	if errorVar == "" {
   621  		errorVar = "err"
   622  	}
   623  	return fmt.Sprintf(code, errorVar, p.DLLFuncName())
   624  }
   625  
   626  // IsUTF16 is true, if f is W (utf16) function. It is false
   627  // for all A (ascii) functions.
   628  func (f *Fn) IsUTF16() bool {
   629  	s := f.DLLFuncName()
   630  	return s[len(s)-1] == 'W'
   631  }
   632  
   633  // StrconvFunc returns name of Go string to OS string function for f.
   634  func (f *Fn) StrconvFunc() string {
   635  	if f.IsUTF16() {
   636  		return syscalldot() + "UTF16PtrFromString"
   637  	}
   638  	return syscalldot() + "BytePtrFromString"
   639  }
   640  
   641  // StrconvType returns Go type name used for OS string for f.
   642  func (f *Fn) StrconvType() string {
   643  	if f.IsUTF16() {
   644  		return "*uint16"
   645  	}
   646  	return "*byte"
   647  }
   648  
   649  // HasStringParam is true, if f has at least one string parameter.
   650  // Otherwise it is false.
   651  func (f *Fn) HasStringParam() bool {
   652  	for _, p := range f.Params {
   653  		if p.Type == "string" {
   654  			return true
   655  		}
   656  	}
   657  	return false
   658  }
   659  
   660  // HelperName returns name of function f helper.
   661  func (f *Fn) HelperName() string {
   662  	if !f.HasStringParam() {
   663  		return f.Name
   664  	}
   665  	return "_" + f.Name
   666  }
   667  
// DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when
// naming a variable that refers to the DLL.
type DLL struct {
	Name string // DLL name as returned by Fn.DLLName (the template appends ".dll")
	Var  string // identifier from Fn.DLLVar; the template names the variable mod<Var>
}
   674  
// Source is the set of parsed input files and the syscall functions
// declared in them.
type Source struct {
	Funcs           []*Fn    // all parsed functions, sorted by DLL name, then DLL function name
	DLLFuncNames    []*Fn    // first Fn for each distinct DLL function name (set by ParseFiles)
	Files           []string // input file paths, in the order parsed
	StdLibImports   []string // standard library imports of the generated file, kept sorted by Import
	ExternalImports []string // other imports of the generated file, kept sorted by ExternalImport
}
   683  
   684  func (src *Source) Import(pkg string) {
   685  	src.StdLibImports = append(src.StdLibImports, pkg)
   686  	sort.Strings(src.StdLibImports)
   687  }
   688  
   689  func (src *Source) ExternalImport(pkg string) {
   690  	src.ExternalImports = append(src.ExternalImports, pkg)
   691  	sort.Strings(src.ExternalImports)
   692  }
   693  
   694  // ParseFiles parses files listed in fs and extracts all syscall
   695  // functions listed in sys comments. It returns source files
   696  // and functions collection *Source if successful.
   697  func ParseFiles(fs []string) (*Source, error) {
   698  	src := &Source{
   699  		Funcs: make([]*Fn, 0),
   700  		Files: make([]string, 0),
   701  		StdLibImports: []string{
   702  			"unsafe",
   703  		},
   704  		ExternalImports: make([]string, 0),
   705  	}
   706  	for _, file := range fs {
   707  		if err := src.ParseFile(file); err != nil {
   708  			return nil, err
   709  		}
   710  	}
   711  	src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs))
   712  	uniq := make(map[string]bool, len(src.Funcs))
   713  	for _, fn := range src.Funcs {
   714  		name := fn.DLLFuncName()
   715  		if !uniq[name] {
   716  			src.DLLFuncNames = append(src.DLLFuncNames, fn)
   717  			uniq[name] = true
   718  		}
   719  	}
   720  	return src, nil
   721  }
   722  
   723  // DLLs return dll names for a source set src.
   724  func (src *Source) DLLs() []DLL {
   725  	uniq := make(map[string]bool)
   726  	r := make([]DLL, 0)
   727  	for _, f := range src.Funcs {
   728  		id := f.DLLVar()
   729  		if _, found := uniq[id]; !found {
   730  			uniq[id] = true
   731  			r = append(r, DLL{f.DLLName(), id})
   732  		}
   733  	}
   734  	sort.Slice(r, func(i, j int) bool {
   735  		return r[i].Var < r[j].Var
   736  	})
   737  	return r
   738  }
   739  
   740  // ParseFile adds additional file path to a source set src.
   741  func (src *Source) ParseFile(path string) error {
   742  	file, err := os.Open(path)
   743  	if err != nil {
   744  		return err
   745  	}
   746  	defer file.Close()
   747  
   748  	s := bufio.NewScanner(file)
   749  	for s.Scan() {
   750  		t := trim(s.Text())
   751  		if len(t) < 7 {
   752  			continue
   753  		}
   754  		if !strings.HasPrefix(t, "//sys") {
   755  			continue
   756  		}
   757  		t = t[5:]
   758  		if !(t[0] == ' ' || t[0] == '\t') {
   759  			continue
   760  		}
   761  		f, err := newFn(t[1:])
   762  		if err != nil {
   763  			return err
   764  		}
   765  		src.Funcs = append(src.Funcs, f)
   766  	}
   767  	if err := s.Err(); err != nil {
   768  		return err
   769  	}
   770  	src.Files = append(src.Files, path)
   771  	sort.Slice(src.Funcs, func(i, j int) bool {
   772  		fi, fj := src.Funcs[i], src.Funcs[j]
   773  		if fi.DLLName() == fj.DLLName() {
   774  			return fi.DLLFuncName() < fj.DLLFuncName()
   775  		}
   776  		return fi.DLLName() < fj.DLLName()
   777  	})
   778  
   779  	// get package name
   780  	fset := token.NewFileSet()
   781  	_, err = file.Seek(0, 0)
   782  	if err != nil {
   783  		return err
   784  	}
   785  	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
   786  	if err != nil {
   787  		return err
   788  	}
   789  	packageName = pkg.Name.Name
   790  
   791  	return nil
   792  }
   793  
   794  // IsStdRepo reports whether src is part of standard library.
   795  func (src *Source) IsStdRepo() (bool, error) {
   796  	if len(src.Files) == 0 {
   797  		return false, errors.New("no input files provided")
   798  	}
   799  	abspath, err := filepath.Abs(src.Files[0])
   800  	if err != nil {
   801  		return false, err
   802  	}
   803  	goroot := runtime.GOROOT()
   804  	if runtime.GOOS == "windows" {
   805  		abspath = strings.ToLower(abspath)
   806  		goroot = strings.ToLower(goroot)
   807  	}
   808  	sep := string(os.PathSeparator)
   809  	if !strings.HasSuffix(goroot, sep) {
   810  		goroot += sep
   811  	}
   812  	return strings.HasPrefix(abspath, goroot), nil
   813  }
   814  
   815  // Generate output source file from a source set src.
   816  func (src *Source) Generate(w io.Writer) error {
   817  	const (
   818  		pkgStd         = iota // any package in std library
   819  		pkgXSysWindows        // x/sys/windows package
   820  		pkgOther
   821  	)
   822  	isStdRepo, err := src.IsStdRepo()
   823  	if err != nil {
   824  		return err
   825  	}
   826  	var pkgtype int
   827  	switch {
   828  	case isStdRepo:
   829  		pkgtype = pkgStd
   830  	case packageName == "windows":
   831  		// TODO: this needs better logic than just using package name
   832  		pkgtype = pkgXSysWindows
   833  	default:
   834  		pkgtype = pkgOther
   835  	}
   836  	if *systemDLL {
   837  		switch pkgtype {
   838  		case pkgStd:
   839  			src.Import("internal/syscall/windows/sysdll")
   840  		case pkgXSysWindows:
   841  		default:
   842  			src.ExternalImport("golang.org/x/sys/windows")
   843  		}
   844  	}
   845  	if packageName != "syscall" {
   846  		src.Import("syscall")
   847  	}
   848  	funcMap := template.FuncMap{
   849  		"packagename": packagename,
   850  		"syscalldot":  syscalldot,
   851  		"newlazydll": func(dll string) string {
   852  			arg := "\"" + dll + ".dll\""
   853  			if !*systemDLL {
   854  				return syscalldot() + "NewLazyDLL(" + arg + ")"
   855  			}
   856  			switch pkgtype {
   857  			case pkgStd:
   858  				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
   859  			case pkgXSysWindows:
   860  				return "NewLazySystemDLL(" + arg + ")"
   861  			default:
   862  				return "windows.NewLazySystemDLL(" + arg + ")"
   863  			}
   864  		},
   865  	}
   866  	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
   867  	err = t.Execute(w, src)
   868  	if err != nil {
   869  		return errors.New("Failed to execute template: " + err.Error())
   870  	}
   871  	return nil
   872  }
   873  
   874  func writeTempSourceFile(data []byte) (string, error) {
   875  	f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go")
   876  	if err != nil {
   877  		return "", err
   878  	}
   879  	_, err = f.Write(data)
   880  	if closeErr := f.Close(); err == nil {
   881  		err = closeErr
   882  	}
   883  	if err != nil {
   884  		os.Remove(f.Name()) // best effort
   885  		return "", err
   886  	}
   887  	return f.Name(), nil
   888  }
   889  
   890  func usage() {
   891  	fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n")
   892  	flag.PrintDefaults()
   893  	os.Exit(1)
   894  }
   895  
   896  func main() {
   897  	flag.Usage = usage
   898  	flag.Parse()
   899  	if len(flag.Args()) <= 0 {
   900  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   901  		usage()
   902  	}
   903  
   904  	src, err := ParseFiles(flag.Args())
   905  	if err != nil {
   906  		log.Fatal(err)
   907  	}
   908  
   909  	var buf bytes.Buffer
   910  	if err := src.Generate(&buf); err != nil {
   911  		log.Fatal(err)
   912  	}
   913  
   914  	data, err := format.Source(buf.Bytes())
   915  	if err != nil {
   916  		log.Printf("failed to format source: %v", err)
   917  		f, err := writeTempSourceFile(buf.Bytes())
   918  		if err != nil {
   919  			log.Fatalf("failed to write unformatted source to file: %v", err)
   920  		}
   921  		log.Fatalf("for diagnosis, wrote unformatted source to %v", f)
   922  	}
   923  	if *filename == "" {
   924  		_, err = os.Stdout.Write(data)
   925  	} else {
   926  		err = ioutil.WriteFile(*filename, data, 0644)
   927  	}
   928  	if err != nil {
   929  		log.Fatal(err)
   930  	}
   931  }
   932  
// srcTemplate is the text/template that Generate executes with a *Source to
// produce the generated file. It relies on the "packagename", "syscalldot"
// and "newlazydll" functions from Generate's FuncMap. main runs the result
// through go/format, so most whitespace here is normalized afterwards.
//
// TODO: use println instead to print in the following template
const srcTemplate = `

{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT.

package {{packagename}}

import (
{{range .StdLibImports}}"{{.}}"
{{end}}

{{range .ExternalImports}}"{{.}}"
{{end}}
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL error     = {{syscalldot}}EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e {{syscalldot}}Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.Var}} = {{newlazydll .Name}}
{{end}}{{end}}

{{define "funcnames"}}{{range .DLLFuncNames}}	proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}

{{define "funcbody"}}
func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "maybeabsent" .}}	{{template "tmpvars" .}}	{{template "syscall" .}}	{{template "tmpvarsreadback" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}

{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}}
{{end}}{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}

{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}}
{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`