github.com/ader1990/go@v0.0.0-20140630135419-8c24447fa791/src/pkg/syscall/mksyscall_windows.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build ignore
     6  
     7  /*
     8  mksyscall_windows generates windows system call bodies
     9  
    10  It parses all files specified on command line containing function
    11  prototypes (like syscall_windows.go) and prints system call bodies
    12  to standard output.
    13  
    14  The prototypes are marked by lines beginning with "//sys" and read
    15  like func declarations if //sys is replaced by func, but:
    16  
    17  * The parameter lists must give a name for each argument. This
    18    includes return parameters.
    19  
    20  * The parameter lists must give a type for each argument:
    21    the (x, y, z int) shorthand is not allowed.
    22  
    23  * If the return parameter is an error number, it must be named err.
    24  
* If the Go func name needs to be different from its winapi dll name,
    26    the winapi name could be specified at the end, after "=" sign, like
    27    //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
    28  
* Each function that returns err needs a condition that the winapi
  return value is tested against to detect failure. When the condition
  holds, err is set to the windows "last-error" (or EINVAL if the
  last-error is 0); otherwise it is nil. The condition can be provided
  at the end of the //sys declaration, like
    33    //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
    34    and is [failretval==0] by default.
    35  
    36  Usage:
    37  	mksyscall_windows [flags] [path ...]
    38  
    39  The flags are:
    40  	-trace
    41  		Generate print statement after every syscall.
    42  */
    43  package main
    44  
    45  import (
    46  	"bufio"
    47  	"errors"
    48  	"flag"
    49  	"fmt"
    50  	"io"
    51  	"log"
    52  	"os"
    53  	"strconv"
    54  	"strings"
    55  	"text/template"
    56  )
    57  
// PrintTraceFlag holds the value of the -trace command-line flag. newFn
// copies it into every Fn, and the "printtrace" template then emits a
// print statement after each generated syscall.
var PrintTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
    59  
    60  func trim(s string) string {
    61  	return strings.Trim(s, " \t")
    62  }
    63  
// Param describes one function parameter: a single "name type" entry of a
// //sys prototype (or, via Rets.ToParams, a return parameter).
type Param struct {
	Name      string // parameter name as written in the prototype
	Type      string // parameter type as written in the prototype, e.g. "*uint16" or "[]byte"
	fn        *Fn    // owning function, source of temp variable indices; nil for Params built by Rets.ToParams
	tmpVarIdx int    // index N of the "_pN" temp variable; tmpVar assigns one while this is negative (extractParams starts it at -1)
}
    71  
    72  // tmpVar returns temp variable name that will be used to represent p during syscall.
    73  func (p *Param) tmpVar() string {
    74  	if p.tmpVarIdx < 0 {
    75  		p.tmpVarIdx = p.fn.curTmpVarIdx
    76  		p.fn.curTmpVarIdx++
    77  	}
    78  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
    79  }
    80  
    81  // BoolTmpVarCode returns source code for bool temp variable.
    82  func (p *Param) BoolTmpVarCode() string {
    83  	const code = `var %s uint32
    84  	if %s {
    85  		%s = 1
    86  	} else {
    87  		%s = 0
    88  	}`
    89  	tmp := p.tmpVar()
    90  	return fmt.Sprintf(code, tmp, p.Name, tmp, tmp)
    91  }
    92  
    93  // SliceTmpVarCode returns source code for slice temp variable.
    94  func (p *Param) SliceTmpVarCode() string {
    95  	const code = `var %s *%s
    96  	if len(%s) > 0 {
    97  		%s = &%s[0]
    98  	}`
    99  	tmp := p.tmpVar()
   100  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   101  }
   102  
   103  // StringTmpVarCode returns source code for string temp variable.
   104  func (p *Param) StringTmpVarCode() string {
   105  	errvar := p.fn.Rets.ErrorVarName()
   106  	if errvar == "" {
   107  		errvar = "_"
   108  	}
   109  	tmp := p.tmpVar()
   110  	const code = `var %s %s
   111  	%s, %s = %s(%s)`
   112  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   113  	if errvar == "-" {
   114  		return s
   115  	}
   116  	const morecode = `
   117  	if %s != nil {
   118  		return
   119  	}`
   120  	return s + fmt.Sprintf(morecode, errvar)
   121  }
   122  
   123  // TmpVarCode returns source code for temp variable.
   124  func (p *Param) TmpVarCode() string {
   125  	switch {
   126  	case p.Type == "string":
   127  		return p.StringTmpVarCode()
   128  	case p.Type == "bool":
   129  		return p.BoolTmpVarCode()
   130  	case strings.HasPrefix(p.Type, "[]"):
   131  		return p.SliceTmpVarCode()
   132  	default:
   133  		return ""
   134  	}
   135  }
   136  
   137  // SyscallArgList returns source code fragments representing p parameter
   138  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   139  // the first element and length.
   140  func (p *Param) SyscallArgList() []string {
   141  	var s string
   142  	switch {
   143  	case p.Type[0] == '*':
   144  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   145  	case p.Type == "string":
   146  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar())
   147  	case p.Type == "bool":
   148  		s = p.tmpVar()
   149  	case strings.HasPrefix(p.Type, "[]"):
   150  		return []string{
   151  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   152  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   153  		}
   154  	default:
   155  		s = p.Name
   156  	}
   157  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   158  }
   159  
   160  // IsError determines if p parameter is used to return error.
   161  func (p *Param) IsError() bool {
   162  	return p.Name == "err" && p.Type == "error"
   163  }
   164  
   165  // join concatenates parameters ps into a string with sep separator.
   166  // Each parameter is converted into string by applying fn to it
   167  // before conversion.
   168  func join(ps []*Param, fn func(*Param) string, sep string) string {
   169  	if len(ps) == 0 {
   170  		return ""
   171  	}
   172  	a := make([]string, 0)
   173  	for _, p := range ps {
   174  		a = append(a, fn(p))
   175  	}
   176  	return strings.Join(a, sep)
   177  }
   178  
// Rets describes function return parameters: at most one non-error value
// plus an optional trailing "err error".
type Rets struct {
	Name         string // name of the non-error return value; "" when there is none
	Type         string // type of the non-error return value
	ReturnsError bool   // whether the function also returns "err error"
	FailCond     string // text inside [...] on the //sys line, with "failretval" as placeholder; "" selects the default "== 0" test
}
   186  
   187  // ErrorVarName returns error variable name for r.
   188  func (r *Rets) ErrorVarName() string {
   189  	if r.ReturnsError {
   190  		return "err"
   191  	}
   192  	if r.Type == "error" {
   193  		return r.Name
   194  	}
   195  	return ""
   196  }
   197  
   198  // ToParams converts r into slice of *Param.
   199  func (r *Rets) ToParams() []*Param {
   200  	ps := make([]*Param, 0)
   201  	if len(r.Name) > 0 {
   202  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   203  	}
   204  	if r.ReturnsError {
   205  		ps = append(ps, &Param{Name: "err", Type: "error"})
   206  	}
   207  	return ps
   208  }
   209  
   210  // List returns source code of syscall return parameters.
   211  func (r *Rets) List() string {
   212  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   213  	if len(s) > 0 {
   214  		s = "(" + s + ")"
   215  	}
   216  	return s
   217  }
   218  
   219  // PrintList returns source code of trace printing part correspondent
   220  // to syscall return values.
   221  func (r *Rets) PrintList() string {
   222  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   223  }
   224  
   225  // SetReturnValuesCode returns source code that accepts syscall return values.
   226  func (r *Rets) SetReturnValuesCode() string {
   227  	if r.Name == "" && !r.ReturnsError {
   228  		return ""
   229  	}
   230  	retvar := "r0"
   231  	if r.Name == "" {
   232  		retvar = "r1"
   233  	}
   234  	errvar := "_"
   235  	if r.ReturnsError {
   236  		errvar = "e1"
   237  	}
   238  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   239  }
   240  
   241  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   242  	const code = `if %s {
   243  		if e1 != 0 {
   244  			err = error(e1)
   245  		} else {
   246  			err = EINVAL
   247  		}
   248  	}`
   249  	cond := retvar + " == 0"
   250  	if r.FailCond != "" {
   251  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   252  	}
   253  	return fmt.Sprintf(code, cond)
   254  }
   255  
   256  // SetErrorCode returns source code that sets return parameters.
   257  func (r *Rets) SetErrorCode() string {
   258  	const code = `if r0 != 0 {
   259  		%s = Errno(r0)
   260  	}`
   261  	if r.Name == "" && !r.ReturnsError {
   262  		return ""
   263  	}
   264  	if r.Name == "" {
   265  		return r.useLongHandleErrorCode("r1")
   266  	}
   267  	if r.Type == "error" {
   268  		return fmt.Sprintf(code, r.Name)
   269  	}
   270  	s := ""
   271  	if r.Type[0] == '*' {
   272  		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
   273  	} else {
   274  		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
   275  	}
   276  	if !r.ReturnsError {
   277  		return s
   278  	}
   279  	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
   280  }
   281  
// Fn describes one syscall function parsed from a //sys line.
type Fn struct {
	Name        string   // Go function name
	Params      []*Param // input parameters, in declaration order
	Rets        *Rets    // return parameters
	PrintTrace  bool     // emit a trace print after the syscall (copied from -trace)
	dllname     string   // DLL named after "=" (without ".dll"); "" means kernel32, see DLLName
	dllfuncname string   // DLL entry point named after "="; "" means Name, see DLLFuncName
	src         string   // trimmed //sys declaration text, quoted in parse errors
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // next free temp variable index; ensures temp variables get unique names
}
   294  
   295  // extractParams parses s to extract function parameters.
   296  func extractParams(s string, f *Fn) ([]*Param, error) {
   297  	s = trim(s)
   298  	if s == "" {
   299  		return nil, nil
   300  	}
   301  	a := strings.Split(s, ",")
   302  	ps := make([]*Param, len(a))
   303  	for i := range ps {
   304  		s2 := trim(a[i])
   305  		b := strings.Split(s2, " ")
   306  		if len(b) != 2 {
   307  			b = strings.Split(s2, "\t")
   308  			if len(b) != 2 {
   309  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   310  			}
   311  		}
   312  		ps[i] = &Param{
   313  			Name:      trim(b[0]),
   314  			Type:      trim(b[1]),
   315  			fn:        f,
   316  			tmpVarIdx: -1,
   317  		}
   318  	}
   319  	return ps, nil
   320  }
   321  
   322  // extractSection extracts text out of string s starting after start
   323  // and ending just before end. found return value will indicate success,
   324  // and prefix, body and suffix will contain correspondent parts of string s.
   325  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   326  	s = trim(s)
   327  	if strings.HasPrefix(s, string(start)) {
   328  		// no prefix
   329  		body = s[1:]
   330  	} else {
   331  		a := strings.SplitN(s, string(start), 2)
   332  		if len(a) != 2 {
   333  			return "", "", s, false
   334  		}
   335  		prefix = a[0]
   336  		body = a[1]
   337  	}
   338  	a := strings.SplitN(body, string(end), 2)
   339  	if len(a) != 2 {
   340  		return "", "", "", false
   341  	}
   342  	return prefix, a[0], a[1], true
   343  }
   344  
   345  // newFn parses string s and return created function Fn.
   346  func newFn(s string) (*Fn, error) {
   347  	s = trim(s)
   348  	f := &Fn{
   349  		Rets:       &Rets{},
   350  		src:        s,
   351  		PrintTrace: *PrintTraceFlag,
   352  	}
   353  	// function name and args
   354  	prefix, body, s, found := extractSection(s, '(', ')')
   355  	if !found || prefix == "" {
   356  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   357  	}
   358  	f.Name = prefix
   359  	var err error
   360  	f.Params, err = extractParams(body, f)
   361  	if err != nil {
   362  		return nil, err
   363  	}
   364  	// return values
   365  	_, body, s, found = extractSection(s, '(', ')')
   366  	if found {
   367  		r, err := extractParams(body, f)
   368  		if err != nil {
   369  			return nil, err
   370  		}
   371  		switch len(r) {
   372  		case 0:
   373  		case 1:
   374  			if r[0].IsError() {
   375  				f.Rets.ReturnsError = true
   376  			} else {
   377  				f.Rets.Name = r[0].Name
   378  				f.Rets.Type = r[0].Type
   379  			}
   380  		case 2:
   381  			if !r[1].IsError() {
   382  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   383  			}
   384  			f.Rets.ReturnsError = true
   385  			f.Rets.Name = r[0].Name
   386  			f.Rets.Type = r[0].Type
   387  		default:
   388  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   389  		}
   390  	}
   391  	// fail condition
   392  	_, body, s, found = extractSection(s, '[', ']')
   393  	if found {
   394  		f.Rets.FailCond = body
   395  	}
   396  	// dll and dll function names
   397  	s = trim(s)
   398  	if s == "" {
   399  		return f, nil
   400  	}
   401  	if !strings.HasPrefix(s, "=") {
   402  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   403  	}
   404  	s = trim(s[1:])
   405  	a := strings.Split(s, ".")
   406  	switch len(a) {
   407  	case 1:
   408  		f.dllfuncname = a[0]
   409  	case 2:
   410  		f.dllname = a[0]
   411  		f.dllfuncname = a[1]
   412  	default:
   413  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   414  	}
   415  	return f, nil
   416  }
   417  
   418  // DLLName returns DLL name for function f.
   419  func (f *Fn) DLLName() string {
   420  	if f.dllname == "" {
   421  		return "kernel32"
   422  	}
   423  	return f.dllname
   424  }
   425  
   426  // DLLName returns DLL function name for function f.
   427  func (f *Fn) DLLFuncName() string {
   428  	if f.dllfuncname == "" {
   429  		return f.Name
   430  	}
   431  	return f.dllfuncname
   432  }
   433  
   434  // ParamList returns source code for function f parameters.
   435  func (f *Fn) ParamList() string {
   436  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   437  }
   438  
   439  // ParamPrintList returns source code of trace printing part correspondent
   440  // to syscall input parameters.
   441  func (f *Fn) ParamPrintList() string {
   442  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   443  }
   444  
   445  // ParamCount return number of syscall parameters for function f.
   446  func (f *Fn) ParamCount() int {
   447  	n := 0
   448  	for _, p := range f.Params {
   449  		n += len(p.SyscallArgList())
   450  	}
   451  	return n
   452  }
   453  
   454  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   455  // to use. It returns parameter count for correspondent SyscallX function.
   456  func (f *Fn) SyscallParamCount() int {
   457  	n := f.ParamCount()
   458  	switch {
   459  	case n <= 3:
   460  		return 3
   461  	case n <= 6:
   462  		return 6
   463  	case n <= 9:
   464  		return 9
   465  	case n <= 12:
   466  		return 12
   467  	case n <= 15:
   468  		return 15
   469  	default:
   470  		panic("too many arguments to system call")
   471  	}
   472  }
   473  
   474  // Syscall determines which SyscallX function to use for function f.
   475  func (f *Fn) Syscall() string {
   476  	c := f.SyscallParamCount()
   477  	if c == 3 {
   478  		return "Syscall"
   479  	}
   480  	return "Syscall" + strconv.Itoa(c)
   481  }
   482  
   483  // SyscallParamList returns source code for SyscallX parameters for function f.
   484  func (f *Fn) SyscallParamList() string {
   485  	a := make([]string, 0)
   486  	for _, p := range f.Params {
   487  		a = append(a, p.SyscallArgList()...)
   488  	}
   489  	for len(a) < f.SyscallParamCount() {
   490  		a = append(a, "0")
   491  	}
   492  	return strings.Join(a, ", ")
   493  }
   494  
   495  // IsUTF16 is true, if f is W (utf16) function. It is false
   496  // for all A (ascii) functions.
   497  func (f *Fn) IsUTF16() bool {
   498  	s := f.DLLFuncName()
   499  	return s[len(s)-1] == 'W'
   500  }
   501  
   502  // StrconvFunc returns name of Go string to OS string function for f.
   503  func (f *Fn) StrconvFunc() string {
   504  	if f.IsUTF16() {
   505  		return "UTF16PtrFromString"
   506  	}
   507  	return "BytePtrFromString"
   508  }
   509  
   510  // StrconvType returns Go type name used for OS string for f.
   511  func (f *Fn) StrconvType() string {
   512  	if f.IsUTF16() {
   513  		return "*uint16"
   514  	}
   515  	return "*byte"
   516  }
   517  
// Source holds the parsed source files and the syscall functions found in them.
type Source struct {
	Funcs []*Fn    // functions in the order their //sys lines were parsed
	Files []string // paths of the files parsed successfully, in order
}
   523  
   524  // ParseFiles parses files listed in fs and extracts all syscall
   525  // functions listed in  sys comments. It returns source files
   526  // and functions collection *Source if successful.
   527  func ParseFiles(fs []string) (*Source, error) {
   528  	src := &Source{
   529  		Funcs: make([]*Fn, 0),
   530  		Files: make([]string, 0),
   531  	}
   532  	for _, file := range fs {
   533  		if err := src.ParseFile(file); err != nil {
   534  			return nil, err
   535  		}
   536  	}
   537  	return src, nil
   538  }
   539  
   540  // DLLs return dll names for a source set src.
   541  func (src *Source) DLLs() []string {
   542  	uniq := make(map[string]bool)
   543  	r := make([]string, 0)
   544  	for _, f := range src.Funcs {
   545  		name := f.DLLName()
   546  		if _, found := uniq[name]; !found {
   547  			uniq[name] = true
   548  			r = append(r, name)
   549  		}
   550  	}
   551  	return r
   552  }
   553  
   554  // ParseFile adds adition file path to a source set src.
   555  func (src *Source) ParseFile(path string) error {
   556  	file, err := os.Open(path)
   557  	if err != nil {
   558  		return err
   559  	}
   560  	defer file.Close()
   561  
   562  	s := bufio.NewScanner(file)
   563  	for s.Scan() {
   564  		t := trim(s.Text())
   565  		if len(t) < 7 {
   566  			continue
   567  		}
   568  		if !strings.HasPrefix(t, "//sys") {
   569  			continue
   570  		}
   571  		t = t[5:]
   572  		if !(t[0] == ' ' || t[0] == '\t') {
   573  			continue
   574  		}
   575  		f, err := newFn(t[1:])
   576  		if err != nil {
   577  			return err
   578  		}
   579  		src.Funcs = append(src.Funcs, f)
   580  	}
   581  	if err := s.Err(); err != nil {
   582  		return err
   583  	}
   584  	src.Files = append(src.Files, path)
   585  	return nil
   586  }
   587  
   588  // Generate output source file from a source set src.
   589  func (src *Source) Generate(w io.Writer) error {
   590  	t := template.Must(template.New("main").Parse(srcTemplate))
   591  	err := t.Execute(w, src)
   592  	if err != nil {
   593  		return errors.New("Failed to execute template: " + err.Error())
   594  	}
   595  	return nil
   596  }
   597  
   598  func usage() {
   599  	fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n")
   600  	flag.PrintDefaults()
   601  	os.Exit(1)
   602  }
   603  
   604  func main() {
   605  	flag.Usage = usage
   606  	flag.Parse()
   607  	if len(os.Args) <= 1 {
   608  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   609  		usage()
   610  	}
   611  	src, err := ParseFiles(os.Args[1:])
   612  	if err != nil {
   613  		log.Fatal(err)
   614  	}
   615  	if err := src.Generate(os.Stdout); err != nil {
   616  		log.Fatal(err)
   617  	}
   618  }
   619  
// srcTemplate is the text/template that Source.Generate executes to produce
// the generated file. The "main" template emits the file header, one
// NewLazyDLL variable per DLL, one NewProc variable per function, and then
// a body for every function. The smaller templates after it render the
// pieces of each body from the methods of Source, Fn, Param and Rets.
//
// TODO: use println instead to print in the following template
const srcTemplate = `

{{define "main"}}// go build mksyscall_windows.go && ./mksyscall_windows{{range .Files}} {{.}}{{end}}
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT

package syscall

import "unsafe"

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.}} = NewLazyDLL("{{.}}.dll")
{{end}}{{end}}

{{define "funcnames"}}{{range .Funcs}}	proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "funcbody"}}
func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{
{{template "tmpvars" .}}	{{template "syscall" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`