github.com/etecs-ru/go-sys-wineventlog@v0.0.0-20210227233244-4c3abb794018/windows/mkwinsyscall/mkwinsyscall.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  /*
     6  mkwinsyscall generates windows system call bodies
     7  
     8  It parses all files specified on command line containing function
     9  prototypes (like syscall_windows.go) and prints system call bodies
    10  to standard output.
    11  
    12  The prototypes are marked by lines beginning with "//sys" and read
    13  like func declarations if //sys is replaced by func, but:
    14  
    15  * The parameter lists must give a name for each argument. This
    16    includes return parameters.
    17  
    18  * The parameter lists must give a type for each argument:
    19    the (x, y, z int) shorthand is not allowed.
    20  
    21  * If the return parameter is an error number, it must be named err.
    22  
* If the Go function name needs to differ from its WinAPI function name,
  or the function lives in a DLL other than kernel32, the WinAPI name
  (optionally prefixed by "dllname.") can be specified at the end, after
  an "=" sign, like
    25    //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
    26  
* Each function that returns err needs to supply a condition that the
  WinAPI return value is tested against to detect failure. When the
  condition holds, err is set to the Windows "last error"; otherwise it
  is nil. The condition can be provided at the end of the //sys
  declaration, with failretval standing for the return value, like
  //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
  and is [failretval==0] by default.
    33  
* If the WinAPI function name (after the "=" sign) ends in a "?", a missing
  function is not fatal: an error is returned instead of panicking.
    36  
    37  Usage:
    38  	mkwinsyscall [flags] [path ...]
    39  
The flags are:
	-output
		Specify output file name (outputs to console if blank).
	-systemdll
		Load all DLLs from the Windows system directory (default true).
	-trace
		Generate print statement after every syscall.
    45  */
    46  package main
    47  
    48  import (
    49  	"bufio"
    50  	"bytes"
    51  	"errors"
    52  	"flag"
    53  	"fmt"
    54  	"go/format"
    55  	"go/parser"
    56  	"go/token"
    57  	"io"
    58  	"io/ioutil"
    59  	"log"
    60  	"os"
    61  	"path/filepath"
    62  	"runtime"
    63  	"sort"
    64  	"strconv"
    65  	"strings"
    66  	"text/template"
    67  )
    68  
    69  var (
    70  	filename       = flag.String("output", "", "output file name (standard output if omitted)")
    71  	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
    72  	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
    73  )
    74  
    75  func trim(s string) string {
    76  	return strings.Trim(s, " \t")
    77  }
    78  
// packageName is the Go package name of the input files. ParseFile assigns
// it from each parsed file's package clause, so the last file parsed wins.
// It is read by packagename, syscalldot and Generate.
var packageName string
    80  
    81  func packagename() string {
    82  	return packageName
    83  }
    84  
    85  func syscalldot() string {
    86  	if packageName == "syscall" {
    87  		return ""
    88  	}
    89  	return "syscall."
    90  }
    91  
// Param describes one parameter of a //sys prototype. It is either an input
// parameter parsed by extractParams or a result built by Rets.ToParams.
type Param struct {
	Name      string // parameter name as written in the prototype
	Type      string // Go type as written in the prototype, e.g. "uint32" or "[]byte"
	fn        *Fn    // owning function (nil for Params built by Rets.ToParams)
	tmpVarIdx int    // N of the _pN temp variable; extractParams sets -1 so tmpVar assigns it lazily
}
    99  
   100  // tmpVar returns temp variable name that will be used to represent p during syscall.
   101  func (p *Param) tmpVar() string {
   102  	if p.tmpVarIdx < 0 {
   103  		p.tmpVarIdx = p.fn.curTmpVarIdx
   104  		p.fn.curTmpVarIdx++
   105  	}
   106  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
   107  }
   108  
   109  // BoolTmpVarCode returns source code for bool temp variable.
   110  func (p *Param) BoolTmpVarCode() string {
   111  	const code = `var %[1]s uint32
   112  	if %[2]s {
   113  		%[1]s = 1
   114  	}`
   115  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   116  }
   117  
   118  // BoolPointerTmpVarCode returns source code for bool temp variable.
   119  func (p *Param) BoolPointerTmpVarCode() string {
   120  	const code = `var %[1]s uint32
   121  	if *%[2]s {
   122  		%[1]s = 1
   123  	}`
   124  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   125  }
   126  
   127  // SliceTmpVarCode returns source code for slice temp variable.
   128  func (p *Param) SliceTmpVarCode() string {
   129  	const code = `var %s *%s
   130  	if len(%s) > 0 {
   131  		%s = &%s[0]
   132  	}`
   133  	tmp := p.tmpVar()
   134  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   135  }
   136  
   137  // StringTmpVarCode returns source code for string temp variable.
   138  func (p *Param) StringTmpVarCode() string {
   139  	errvar := p.fn.Rets.ErrorVarName()
   140  	if errvar == "" {
   141  		errvar = "_"
   142  	}
   143  	tmp := p.tmpVar()
   144  	const code = `var %s %s
   145  	%s, %s = %s(%s)`
   146  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   147  	if errvar == "-" {
   148  		return s
   149  	}
   150  	const morecode = `
   151  	if %s != nil {
   152  		return
   153  	}`
   154  	return s + fmt.Sprintf(morecode, errvar)
   155  }
   156  
   157  // TmpVarCode returns source code for temp variable.
   158  func (p *Param) TmpVarCode() string {
   159  	switch {
   160  	case p.Type == "bool":
   161  		return p.BoolTmpVarCode()
   162  	case p.Type == "*bool":
   163  		return p.BoolPointerTmpVarCode()
   164  	case strings.HasPrefix(p.Type, "[]"):
   165  		return p.SliceTmpVarCode()
   166  	default:
   167  		return ""
   168  	}
   169  }
   170  
   171  // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable.
   172  func (p *Param) TmpVarReadbackCode() string {
   173  	switch {
   174  	case p.Type == "*bool":
   175  		return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar())
   176  	default:
   177  		return ""
   178  	}
   179  }
   180  
   181  // TmpVarHelperCode returns source code for helper's temp variable.
   182  func (p *Param) TmpVarHelperCode() string {
   183  	if p.Type != "string" {
   184  		return ""
   185  	}
   186  	return p.StringTmpVarCode()
   187  }
   188  
   189  // SyscallArgList returns source code fragments representing p parameter
   190  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   191  // the first element and length.
   192  func (p *Param) SyscallArgList() []string {
   193  	t := p.HelperType()
   194  	var s string
   195  	switch {
   196  	case t == "*bool":
   197  		s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar())
   198  	case t[0] == '*':
   199  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   200  	case t == "bool":
   201  		s = p.tmpVar()
   202  	case strings.HasPrefix(t, "[]"):
   203  		return []string{
   204  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   205  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   206  		}
   207  	default:
   208  		s = p.Name
   209  	}
   210  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   211  }
   212  
   213  // IsError determines if p parameter is used to return error.
   214  func (p *Param) IsError() bool {
   215  	return p.Name == "err" && p.Type == "error"
   216  }
   217  
   218  // HelperType returns type of parameter p used in helper function.
   219  func (p *Param) HelperType() string {
   220  	if p.Type == "string" {
   221  		return p.fn.StrconvType()
   222  	}
   223  	return p.Type
   224  }
   225  
   226  // join concatenates parameters ps into a string with sep separator.
   227  // Each parameter is converted into string by applying fn to it
   228  // before conversion.
   229  func join(ps []*Param, fn func(*Param) string, sep string) string {
   230  	if len(ps) == 0 {
   231  		return ""
   232  	}
   233  	a := make([]string, 0)
   234  	for _, p := range ps {
   235  		a = append(a, fn(p))
   236  	}
   237  	return strings.Join(a, sep)
   238  }
   239  
// Rets describes the results of a //sys function: an optional value result,
// an optional trailing "err error", and how failure is detected.
type Rets struct {
	Name          string // name of the non-error result; empty if there is none
	Type          string // Go type of the non-error result
	ReturnsError  bool   // whether the prototype's last result is "err error"
	FailCond      string // failure test from "[...]" with failretval as placeholder; empty means "== 0"
	fnMaybeAbsent bool   // WinAPI name ended in "?": the procedure may be missing
}
   248  
   249  // ErrorVarName returns error variable name for r.
   250  func (r *Rets) ErrorVarName() string {
   251  	if r.ReturnsError {
   252  		return "err"
   253  	}
   254  	if r.Type == "error" {
   255  		return r.Name
   256  	}
   257  	return ""
   258  }
   259  
   260  // ToParams converts r into slice of *Param.
   261  func (r *Rets) ToParams() []*Param {
   262  	ps := make([]*Param, 0)
   263  	if len(r.Name) > 0 {
   264  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   265  	}
   266  	if r.ReturnsError {
   267  		ps = append(ps, &Param{Name: "err", Type: "error"})
   268  	}
   269  	return ps
   270  }
   271  
   272  // List returns source code of syscall return parameters.
   273  func (r *Rets) List() string {
   274  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   275  	if len(s) > 0 {
   276  		s = "(" + s + ")"
   277  	} else if r.fnMaybeAbsent {
   278  		s = "(err error)"
   279  	}
   280  	return s
   281  }
   282  
   283  // PrintList returns source code of trace printing part correspondent
   284  // to syscall return values.
   285  func (r *Rets) PrintList() string {
   286  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   287  }
   288  
   289  // SetReturnValuesCode returns source code that accepts syscall return values.
   290  func (r *Rets) SetReturnValuesCode() string {
   291  	if r.Name == "" && !r.ReturnsError {
   292  		return ""
   293  	}
   294  	retvar := "r0"
   295  	if r.Name == "" {
   296  		retvar = "r1"
   297  	}
   298  	errvar := "_"
   299  	if r.ReturnsError {
   300  		errvar = "e1"
   301  	}
   302  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   303  }
   304  
   305  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   306  	const code = `if %s {
   307  		err = errnoErr(e1)
   308  	}`
   309  	cond := retvar + " == 0"
   310  	if r.FailCond != "" {
   311  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   312  	}
   313  	return fmt.Sprintf(code, cond)
   314  }
   315  
   316  // SetErrorCode returns source code that sets return parameters.
   317  func (r *Rets) SetErrorCode() string {
   318  	const code = `if r0 != 0 {
   319  		%s = %sErrno(r0)
   320  	}`
   321  	if r.Name == "" && !r.ReturnsError {
   322  		return ""
   323  	}
   324  	if r.Name == "" {
   325  		return r.useLongHandleErrorCode("r1")
   326  	}
   327  	if r.Type == "error" {
   328  		return fmt.Sprintf(code, r.Name, syscalldot())
   329  	}
   330  	s := ""
   331  	switch {
   332  	case r.Type[0] == '*':
   333  		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
   334  	case r.Type == "bool":
   335  		s = fmt.Sprintf("%s = r0 != 0", r.Name)
   336  	default:
   337  		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
   338  	}
   339  	if !r.ReturnsError {
   340  		return s
   341  	}
   342  	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
   343  }
   344  
// Fn describes one //sys function: its Go signature, the WinAPI procedure it
// calls, and code-generation options.
type Fn struct {
	Name        string   // Go function name from the prototype
	Params      []*Param // input parameters, in declaration order
	Rets        *Rets    // results and failure handling
	PrintTrace  bool     // emit a trace print after the call (set from -trace)
	dllname     string   // DLL given as "= dll.Func"; empty selects DLLName's default
	dllfuncname string   // WinAPI name given after "="; empty means use Name
	src         string   // trimmed prototype text, quoted in parse errors
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // next free temp variable index; ensures _pN names are unique
}
   357  
   358  // extractParams parses s to extract function parameters.
   359  func extractParams(s string, f *Fn) ([]*Param, error) {
   360  	s = trim(s)
   361  	if s == "" {
   362  		return nil, nil
   363  	}
   364  	a := strings.Split(s, ",")
   365  	ps := make([]*Param, len(a))
   366  	for i := range ps {
   367  		s2 := trim(a[i])
   368  		b := strings.Split(s2, " ")
   369  		if len(b) != 2 {
   370  			b = strings.Split(s2, "\t")
   371  			if len(b) != 2 {
   372  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   373  			}
   374  		}
   375  		ps[i] = &Param{
   376  			Name:      trim(b[0]),
   377  			Type:      trim(b[1]),
   378  			fn:        f,
   379  			tmpVarIdx: -1,
   380  		}
   381  	}
   382  	return ps, nil
   383  }
   384  
   385  // extractSection extracts text out of string s starting after start
   386  // and ending just before end. found return value will indicate success,
   387  // and prefix, body and suffix will contain correspondent parts of string s.
   388  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   389  	s = trim(s)
   390  	if strings.HasPrefix(s, string(start)) {
   391  		// no prefix
   392  		body = s[1:]
   393  	} else {
   394  		a := strings.SplitN(s, string(start), 2)
   395  		if len(a) != 2 {
   396  			return "", "", s, false
   397  		}
   398  		prefix = a[0]
   399  		body = a[1]
   400  	}
   401  	a := strings.SplitN(body, string(end), 2)
   402  	if len(a) != 2 {
   403  		return "", "", "", false
   404  	}
   405  	return prefix, a[0], a[1], true
   406  }
   407  
   408  // newFn parses string s and return created function Fn.
   409  func newFn(s string) (*Fn, error) {
   410  	s = trim(s)
   411  	f := &Fn{
   412  		Rets:       &Rets{},
   413  		src:        s,
   414  		PrintTrace: *printTraceFlag,
   415  	}
   416  	// function name and args
   417  	prefix, body, s, found := extractSection(s, '(', ')')
   418  	if !found || prefix == "" {
   419  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   420  	}
   421  	f.Name = prefix
   422  	var err error
   423  	f.Params, err = extractParams(body, f)
   424  	if err != nil {
   425  		return nil, err
   426  	}
   427  	// return values
   428  	_, body, s, found = extractSection(s, '(', ')')
   429  	if found {
   430  		r, err := extractParams(body, f)
   431  		if err != nil {
   432  			return nil, err
   433  		}
   434  		switch len(r) {
   435  		case 0:
   436  		case 1:
   437  			if r[0].IsError() {
   438  				f.Rets.ReturnsError = true
   439  			} else {
   440  				f.Rets.Name = r[0].Name
   441  				f.Rets.Type = r[0].Type
   442  			}
   443  		case 2:
   444  			if !r[1].IsError() {
   445  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   446  			}
   447  			f.Rets.ReturnsError = true
   448  			f.Rets.Name = r[0].Name
   449  			f.Rets.Type = r[0].Type
   450  		default:
   451  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   452  		}
   453  	}
   454  	// fail condition
   455  	_, body, s, found = extractSection(s, '[', ']')
   456  	if found {
   457  		f.Rets.FailCond = body
   458  	}
   459  	// dll and dll function names
   460  	s = trim(s)
   461  	if s == "" {
   462  		return f, nil
   463  	}
   464  	if !strings.HasPrefix(s, "=") {
   465  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   466  	}
   467  	s = trim(s[1:])
   468  	a := strings.Split(s, ".")
   469  	switch len(a) {
   470  	case 1:
   471  		f.dllfuncname = a[0]
   472  	case 2:
   473  		f.dllname = a[0]
   474  		f.dllfuncname = a[1]
   475  	default:
   476  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   477  	}
   478  	if n := f.dllfuncname; strings.HasSuffix(n, "?") {
   479  		f.dllfuncname = n[:len(n)-1]
   480  		f.Rets.fnMaybeAbsent = true
   481  	}
   482  	return f, nil
   483  }
   484  
   485  // DLLName returns DLL name for function f.
   486  func (f *Fn) DLLName() string {
   487  	if f.dllname == "" {
   488  		return "kernel32"
   489  	}
   490  	return f.dllname
   491  }
   492  
   493  // DLLName returns DLL function name for function f.
   494  func (f *Fn) DLLFuncName() string {
   495  	if f.dllfuncname == "" {
   496  		return f.Name
   497  	}
   498  	return f.dllfuncname
   499  }
   500  
   501  // ParamList returns source code for function f parameters.
   502  func (f *Fn) ParamList() string {
   503  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   504  }
   505  
   506  // HelperParamList returns source code for helper function f parameters.
   507  func (f *Fn) HelperParamList() string {
   508  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
   509  }
   510  
   511  // ParamPrintList returns source code of trace printing part correspondent
   512  // to syscall input parameters.
   513  func (f *Fn) ParamPrintList() string {
   514  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   515  }
   516  
   517  // ParamCount return number of syscall parameters for function f.
   518  func (f *Fn) ParamCount() int {
   519  	n := 0
   520  	for _, p := range f.Params {
   521  		n += len(p.SyscallArgList())
   522  	}
   523  	return n
   524  }
   525  
   526  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   527  // to use. It returns parameter count for correspondent SyscallX function.
   528  func (f *Fn) SyscallParamCount() int {
   529  	n := f.ParamCount()
   530  	switch {
   531  	case n <= 3:
   532  		return 3
   533  	case n <= 6:
   534  		return 6
   535  	case n <= 9:
   536  		return 9
   537  	case n <= 12:
   538  		return 12
   539  	case n <= 15:
   540  		return 15
   541  	default:
   542  		panic("too many arguments to system call")
   543  	}
   544  }
   545  
   546  // Syscall determines which SyscallX function to use for function f.
   547  func (f *Fn) Syscall() string {
   548  	c := f.SyscallParamCount()
   549  	if c == 3 {
   550  		return syscalldot() + "Syscall"
   551  	}
   552  	return syscalldot() + "Syscall" + strconv.Itoa(c)
   553  }
   554  
   555  // SyscallParamList returns source code for SyscallX parameters for function f.
   556  func (f *Fn) SyscallParamList() string {
   557  	a := make([]string, 0)
   558  	for _, p := range f.Params {
   559  		a = append(a, p.SyscallArgList()...)
   560  	}
   561  	for len(a) < f.SyscallParamCount() {
   562  		a = append(a, "0")
   563  	}
   564  	return strings.Join(a, ", ")
   565  }
   566  
   567  // HelperCallParamList returns source code of call into function f helper.
   568  func (f *Fn) HelperCallParamList() string {
   569  	a := make([]string, 0, len(f.Params))
   570  	for _, p := range f.Params {
   571  		s := p.Name
   572  		if p.Type == "string" {
   573  			s = p.tmpVar()
   574  		}
   575  		a = append(a, s)
   576  	}
   577  	return strings.Join(a, ", ")
   578  }
   579  
   580  // MaybeAbsent returns source code for handling functions that are possibly unavailable.
   581  func (p *Fn) MaybeAbsent() string {
   582  	if !p.Rets.fnMaybeAbsent {
   583  		return ""
   584  	}
   585  	const code = `%[1]s = proc%[2]s.Find()
   586  	if %[1]s != nil {
   587  		return
   588  	}`
   589  	errorVar := p.Rets.ErrorVarName()
   590  	if errorVar == "" {
   591  		errorVar = "err"
   592  	}
   593  	return fmt.Sprintf(code, errorVar, p.DLLFuncName())
   594  }
   595  
   596  // IsUTF16 is true, if f is W (utf16) function. It is false
   597  // for all A (ascii) functions.
   598  func (f *Fn) IsUTF16() bool {
   599  	s := f.DLLFuncName()
   600  	return s[len(s)-1] == 'W'
   601  }
   602  
   603  // StrconvFunc returns name of Go string to OS string function for f.
   604  func (f *Fn) StrconvFunc() string {
   605  	if f.IsUTF16() {
   606  		return syscalldot() + "UTF16PtrFromString"
   607  	}
   608  	return syscalldot() + "BytePtrFromString"
   609  }
   610  
   611  // StrconvType returns Go type name used for OS string for f.
   612  func (f *Fn) StrconvType() string {
   613  	if f.IsUTF16() {
   614  		return "*uint16"
   615  	}
   616  	return "*byte"
   617  }
   618  
   619  // HasStringParam is true, if f has at least one string parameter.
   620  // Otherwise it is false.
   621  func (f *Fn) HasStringParam() bool {
   622  	for _, p := range f.Params {
   623  		if p.Type == "string" {
   624  			return true
   625  		}
   626  	}
   627  	return false
   628  }
   629  
   630  // HelperName returns name of function f helper.
   631  func (f *Fn) HelperName() string {
   632  	if !f.HasStringParam() {
   633  		return f.Name
   634  	}
   635  	return "_" + f.Name
   636  }
   637  
// Source is the set of input files and the //sys functions parsed from them,
// together with the imports the generated file needs.
type Source struct {
	Funcs           []*Fn    // parsed functions, sorted by DLL name then WinAPI name
	Files           []string // input file paths, in the order parsed
	StdLibImports   []string // standard-library imports of the generated file, kept sorted
	ExternalImports []string // other imports (e.g. golang.org/x/sys/windows), kept sorted
}
   645  
   646  func (src *Source) Import(pkg string) {
   647  	src.StdLibImports = append(src.StdLibImports, pkg)
   648  	sort.Strings(src.StdLibImports)
   649  }
   650  
   651  func (src *Source) ExternalImport(pkg string) {
   652  	src.ExternalImports = append(src.ExternalImports, pkg)
   653  	sort.Strings(src.ExternalImports)
   654  }
   655  
   656  // ParseFiles parses files listed in fs and extracts all syscall
   657  // functions listed in sys comments. It returns source files
   658  // and functions collection *Source if successful.
   659  func ParseFiles(fs []string) (*Source, error) {
   660  	src := &Source{
   661  		Funcs: make([]*Fn, 0),
   662  		Files: make([]string, 0),
   663  		StdLibImports: []string{
   664  			"unsafe",
   665  		},
   666  		ExternalImports: make([]string, 0),
   667  	}
   668  	for _, file := range fs {
   669  		if err := src.ParseFile(file); err != nil {
   670  			return nil, err
   671  		}
   672  	}
   673  	return src, nil
   674  }
   675  
   676  // DLLs return dll names for a source set src.
   677  func (src *Source) DLLs() []string {
   678  	uniq := make(map[string]bool)
   679  	r := make([]string, 0)
   680  	for _, f := range src.Funcs {
   681  		name := f.DLLName()
   682  		if _, found := uniq[name]; !found {
   683  			uniq[name] = true
   684  			r = append(r, name)
   685  		}
   686  	}
   687  	sort.Strings(r)
   688  	return r
   689  }
   690  
   691  // ParseFile adds additional file path to a source set src.
   692  func (src *Source) ParseFile(path string) error {
   693  	file, err := os.Open(path)
   694  	if err != nil {
   695  		return err
   696  	}
   697  	defer file.Close()
   698  
   699  	s := bufio.NewScanner(file)
   700  	for s.Scan() {
   701  		t := trim(s.Text())
   702  		if len(t) < 7 {
   703  			continue
   704  		}
   705  		if !strings.HasPrefix(t, "//sys") {
   706  			continue
   707  		}
   708  		t = t[5:]
   709  		if !(t[0] == ' ' || t[0] == '\t') {
   710  			continue
   711  		}
   712  		f, err := newFn(t[1:])
   713  		if err != nil {
   714  			return err
   715  		}
   716  		src.Funcs = append(src.Funcs, f)
   717  	}
   718  	if err := s.Err(); err != nil {
   719  		return err
   720  	}
   721  	src.Files = append(src.Files, path)
   722  	sort.Slice(src.Funcs, func(i, j int) bool {
   723  		fi, fj := src.Funcs[i], src.Funcs[j]
   724  		if fi.DLLName() == fj.DLLName() {
   725  			return fi.DLLFuncName() < fj.DLLFuncName()
   726  		}
   727  		return fi.DLLName() < fj.DLLName()
   728  	})
   729  
   730  	// get package name
   731  	fset := token.NewFileSet()
   732  	_, err = file.Seek(0, 0)
   733  	if err != nil {
   734  		return err
   735  	}
   736  	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
   737  	if err != nil {
   738  		return err
   739  	}
   740  	packageName = pkg.Name.Name
   741  
   742  	return nil
   743  }
   744  
   745  // IsStdRepo reports whether src is part of standard library.
   746  func (src *Source) IsStdRepo() (bool, error) {
   747  	if len(src.Files) == 0 {
   748  		return false, errors.New("no input files provided")
   749  	}
   750  	abspath, err := filepath.Abs(src.Files[0])
   751  	if err != nil {
   752  		return false, err
   753  	}
   754  	goroot := runtime.GOROOT()
   755  	if runtime.GOOS == "windows" {
   756  		abspath = strings.ToLower(abspath)
   757  		goroot = strings.ToLower(goroot)
   758  	}
   759  	sep := string(os.PathSeparator)
   760  	if !strings.HasSuffix(goroot, sep) {
   761  		goroot += sep
   762  	}
   763  	return strings.HasPrefix(abspath, goroot), nil
   764  }
   765  
   766  // Generate output source file from a source set src.
   767  func (src *Source) Generate(w io.Writer) error {
   768  	const (
   769  		pkgStd         = iota // any package in std library
   770  		pkgXSysWindows        // x/sys/windows package
   771  		pkgOther
   772  	)
   773  	isStdRepo, err := src.IsStdRepo()
   774  	if err != nil {
   775  		return err
   776  	}
   777  	var pkgtype int
   778  	switch {
   779  	case isStdRepo:
   780  		pkgtype = pkgStd
   781  	case packageName == "windows":
   782  		// TODO: this needs better logic than just using package name
   783  		pkgtype = pkgXSysWindows
   784  	default:
   785  		pkgtype = pkgOther
   786  	}
   787  	if *systemDLL {
   788  		switch pkgtype {
   789  		case pkgStd:
   790  			src.Import("internal/syscall/windows/sysdll")
   791  		case pkgXSysWindows:
   792  		default:
   793  			src.ExternalImport("golang.org/x/sys/windows")
   794  		}
   795  	}
   796  	if packageName != "syscall" {
   797  		src.Import("syscall")
   798  	}
   799  	funcMap := template.FuncMap{
   800  		"packagename": packagename,
   801  		"syscalldot":  syscalldot,
   802  		"newlazydll": func(dll string) string {
   803  			arg := "\"" + dll + ".dll\""
   804  			if !*systemDLL {
   805  				return syscalldot() + "NewLazyDLL(" + arg + ")"
   806  			}
   807  			switch pkgtype {
   808  			case pkgStd:
   809  				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
   810  			case pkgXSysWindows:
   811  				return "NewLazySystemDLL(" + arg + ")"
   812  			default:
   813  				return "windows.NewLazySystemDLL(" + arg + ")"
   814  			}
   815  		},
   816  	}
   817  	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
   818  	err = t.Execute(w, src)
   819  	if err != nil {
   820  		return errors.New("Failed to execute template: " + err.Error())
   821  	}
   822  	return nil
   823  }
   824  
   825  func usage() {
   826  	fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n")
   827  	flag.PrintDefaults()
   828  	os.Exit(1)
   829  }
   830  
   831  func main() {
   832  	flag.Usage = usage
   833  	flag.Parse()
   834  	if len(flag.Args()) <= 0 {
   835  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   836  		usage()
   837  	}
   838  
   839  	src, err := ParseFiles(flag.Args())
   840  	if err != nil {
   841  		log.Fatal(err)
   842  	}
   843  
   844  	var buf bytes.Buffer
   845  	if err := src.Generate(&buf); err != nil {
   846  		log.Fatal(err)
   847  	}
   848  
   849  	data, err := format.Source(buf.Bytes())
   850  	if err != nil {
   851  		log.Fatal(err)
   852  	}
   853  	if *filename == "" {
   854  		_, err = os.Stdout.Write(data)
   855  	} else {
   856  		err = ioutil.WriteFile(*filename, data, 0644)
   857  	}
   858  	if err != nil {
   859  		log.Fatal(err)
   860  	}
   861  }
   862  
// srcTemplate is the text/template that Generate executes with the *Source as
// data. The "main" template emits the package clause, imports, errno helpers
// and the DLL and procedure variables. Then, for each function, it emits an
// optional string-converting wrapper ("helperbody") followed by the body that
// performs the syscall ("funcbody"). The helpers packagename, syscalldot and
// newlazydll come from Generate's FuncMap; everything else is a method on
// Source, Fn, Rets or Param.
//
// TODO: use println instead to print in the following template
const srcTemplate = `

{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT.

package {{packagename}}

import (
{{range .StdLibImports}}"{{.}}"
{{end}}

{{range .ExternalImports}}"{{.}}"
{{end}}
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL error     = {{syscalldot}}EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e {{syscalldot}}Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.}} = {{newlazydll .}}
{{end}}{{end}}

{{define "funcnames"}}{{range .Funcs}}	proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}

{{define "funcbody"}}
func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "maybeabsent" .}}	{{template "tmpvars" .}}	{{template "syscall" .}}	{{template "tmpvarsreadback" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}

{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}}
{{end}}{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}

{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}}
{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`