github.com/tcnksm/go@v0.0.0-20141208075154-439b32936367/src/runtime/syscall_windows_test.go (about)

     1  // Copyright 2010 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package runtime_test
     6  
     7  import (
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"os/exec"
    12  	"path/filepath"
    13  	"runtime"
    14  	"strings"
    15  	"syscall"
    16  	"testing"
    17  	"unsafe"
    18  )
    19  
    20  type DLL struct {
    21  	*syscall.DLL
    22  	t *testing.T
    23  }
    24  
    25  func GetDLL(t *testing.T, name string) *DLL {
    26  	d, e := syscall.LoadDLL(name)
    27  	if e != nil {
    28  		t.Fatal(e)
    29  	}
    30  	return &DLL{DLL: d, t: t}
    31  }
    32  
    33  func (d *DLL) Proc(name string) *syscall.Proc {
    34  	p, e := d.FindProc(name)
    35  	if e != nil {
    36  		d.t.Fatal(e)
    37  	}
    38  	return p
    39  }
    40  
    41  func TestStdCall(t *testing.T) {
    42  	type Rect struct {
    43  		left, top, right, bottom int32
    44  	}
    45  	res := Rect{}
    46  	expected := Rect{1, 1, 40, 60}
    47  	a, _, _ := GetDLL(t, "user32.dll").Proc("UnionRect").Call(
    48  		uintptr(unsafe.Pointer(&res)),
    49  		uintptr(unsafe.Pointer(&Rect{10, 1, 14, 60})),
    50  		uintptr(unsafe.Pointer(&Rect{1, 2, 40, 50})))
    51  	if a != 1 || res.left != expected.left ||
    52  		res.top != expected.top ||
    53  		res.right != expected.right ||
    54  		res.bottom != expected.bottom {
    55  		t.Error("stdcall USER32.UnionRect returns", a, "res=", res)
    56  	}
    57  }
    58  
    59  func Test64BitReturnStdCall(t *testing.T) {
    60  
    61  	const (
    62  		VER_BUILDNUMBER      = 0x0000004
    63  		VER_MAJORVERSION     = 0x0000002
    64  		VER_MINORVERSION     = 0x0000001
    65  		VER_PLATFORMID       = 0x0000008
    66  		VER_PRODUCT_TYPE     = 0x0000080
    67  		VER_SERVICEPACKMAJOR = 0x0000020
    68  		VER_SERVICEPACKMINOR = 0x0000010
    69  		VER_SUITENAME        = 0x0000040
    70  
    71  		VER_EQUAL         = 1
    72  		VER_GREATER       = 2
    73  		VER_GREATER_EQUAL = 3
    74  		VER_LESS          = 4
    75  		VER_LESS_EQUAL    = 5
    76  
    77  		ERROR_OLD_WIN_VERSION syscall.Errno = 1150
    78  	)
    79  
    80  	type OSVersionInfoEx struct {
    81  		OSVersionInfoSize uint32
    82  		MajorVersion      uint32
    83  		MinorVersion      uint32
    84  		BuildNumber       uint32
    85  		PlatformId        uint32
    86  		CSDVersion        [128]uint16
    87  		ServicePackMajor  uint16
    88  		ServicePackMinor  uint16
    89  		SuiteMask         uint16
    90  		ProductType       byte
    91  		Reserve           byte
    92  	}
    93  
    94  	d := GetDLL(t, "kernel32.dll")
    95  
    96  	var m1, m2 uintptr
    97  	VerSetConditionMask := d.Proc("VerSetConditionMask")
    98  	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_MAJORVERSION, VER_GREATER_EQUAL)
    99  	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_MINORVERSION, VER_GREATER_EQUAL)
   100  	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_SERVICEPACKMAJOR, VER_GREATER_EQUAL)
   101  	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_SERVICEPACKMINOR, VER_GREATER_EQUAL)
   102  
   103  	vi := OSVersionInfoEx{
   104  		MajorVersion:     5,
   105  		MinorVersion:     1,
   106  		ServicePackMajor: 2,
   107  		ServicePackMinor: 0,
   108  	}
   109  	vi.OSVersionInfoSize = uint32(unsafe.Sizeof(vi))
   110  	r, _, e2 := d.Proc("VerifyVersionInfoW").Call(
   111  		uintptr(unsafe.Pointer(&vi)),
   112  		VER_MAJORVERSION|VER_MINORVERSION|VER_SERVICEPACKMAJOR|VER_SERVICEPACKMINOR,
   113  		m1, m2)
   114  	if r == 0 && e2 != ERROR_OLD_WIN_VERSION {
   115  		t.Errorf("VerifyVersionInfo failed: %s", e2)
   116  	}
   117  }
   118  
   119  func TestCDecl(t *testing.T) {
   120  	var buf [50]byte
   121  	fmtp, _ := syscall.BytePtrFromString("%d %d %d")
   122  	a, _, _ := GetDLL(t, "user32.dll").Proc("wsprintfA").Call(
   123  		uintptr(unsafe.Pointer(&buf[0])),
   124  		uintptr(unsafe.Pointer(fmtp)),
   125  		1000, 2000, 3000)
   126  	if string(buf[:a]) != "1000 2000 3000" {
   127  		t.Error("cdecl USER32.wsprintfA returns", a, "buf=", buf[:a])
   128  	}
   129  }
   130  
   131  func TestEnumWindows(t *testing.T) {
   132  	d := GetDLL(t, "user32.dll")
   133  	isWindows := d.Proc("IsWindow")
   134  	counter := 0
   135  	cb := syscall.NewCallback(func(hwnd syscall.Handle, lparam uintptr) uintptr {
   136  		if lparam != 888 {
   137  			t.Error("lparam was not passed to callback")
   138  		}
   139  		b, _, _ := isWindows.Call(uintptr(hwnd))
   140  		if b == 0 {
   141  			t.Error("USER32.IsWindow returns FALSE")
   142  		}
   143  		counter++
   144  		return 1 // continue enumeration
   145  	})
   146  	a, _, _ := d.Proc("EnumWindows").Call(cb, 888)
   147  	if a == 0 {
   148  		t.Error("USER32.EnumWindows returns FALSE")
   149  	}
   150  	if counter == 0 {
   151  		t.Error("Callback has been never called or your have no windows")
   152  	}
   153  }
   154  
// callback is the Go function that nestedCall registers with
// syscall.NewCallback and hands to EnumWindows. It reinterprets the bits of
// lparam as a Go func() value, invokes it, and returns 0 so that EnumWindows
// stops after the first window it reports.
func callback(hwnd syscall.Handle, lparam uintptr) uintptr {
	// lparam holds the pointer word of a func() value (see nestedCall);
	// viewing &lparam as *func() turns it back into a callable func.
	(*(*func())(unsafe.Pointer(&lparam)))()
	return 0 // stop enumeration
}

// nestedCall calls into Windows, back into Go, and finally to f.
//
// The pointer word of f is passed as the lparam argument of EnumWindows,
// which forwards it to callback. f therefore runs only if EnumWindows calls
// callback at least once. The EnumWindows result is not checked, and the
// user32 DLL handle is released when nestedCall returns.
// NOTE(review): while the call is in flight, f is referenced from Windows only
// through a uintptr. Confirm that the parameter f keeps the closure reachable
// for the whole call under the current garbage collector.
func nestedCall(t *testing.T, f func()) {
	c := syscall.NewCallback(callback)
	d := GetDLL(t, "user32.dll")
	defer d.Release()
	d.Proc("EnumWindows").Call(c, uintptr(*(*unsafe.Pointer)(unsafe.Pointer(&f))))
}
   167  
   168  func TestCallback(t *testing.T) {
   169  	var x = false
   170  	nestedCall(t, func() { x = true })
   171  	if !x {
   172  		t.Fatal("nestedCall did not call func")
   173  	}
   174  }
   175  
   176  func TestCallbackGC(t *testing.T) {
   177  	nestedCall(t, runtime.GC)
   178  }
   179  
   180  func TestCallbackPanicLocked(t *testing.T) {
   181  	runtime.LockOSThread()
   182  	defer runtime.UnlockOSThread()
   183  
   184  	if !runtime.LockedOSThread() {
   185  		t.Fatal("runtime.LockOSThread didn't")
   186  	}
   187  	defer func() {
   188  		s := recover()
   189  		if s == nil {
   190  			t.Fatal("did not panic")
   191  		}
   192  		if s.(string) != "callback panic" {
   193  			t.Fatal("wrong panic:", s)
   194  		}
   195  		if !runtime.LockedOSThread() {
   196  			t.Fatal("lost lock on OS thread after panic")
   197  		}
   198  	}()
   199  	nestedCall(t, func() { panic("callback panic") })
   200  	panic("nestedCall returned")
   201  }
   202  
   203  func TestCallbackPanic(t *testing.T) {
   204  	// Make sure panic during callback unwinds properly.
   205  	if runtime.LockedOSThread() {
   206  		t.Fatal("locked OS thread on entry to TestCallbackPanic")
   207  	}
   208  	defer func() {
   209  		s := recover()
   210  		if s == nil {
   211  			t.Fatal("did not panic")
   212  		}
   213  		if s.(string) != "callback panic" {
   214  			t.Fatal("wrong panic:", s)
   215  		}
   216  		if runtime.LockedOSThread() {
   217  			t.Fatal("locked OS thread on exit from TestCallbackPanic")
   218  		}
   219  	}()
   220  	nestedCall(t, func() { panic("callback panic") })
   221  	panic("nestedCall returned")
   222  }
   223  
   224  func TestCallbackPanicLoop(t *testing.T) {
   225  	// Make sure we don't blow out m->g0 stack.
   226  	for i := 0; i < 100000; i++ {
   227  		TestCallbackPanic(t)
   228  	}
   229  }
   230  
   231  func TestBlockingCallback(t *testing.T) {
   232  	c := make(chan int)
   233  	go func() {
   234  		for i := 0; i < 10; i++ {
   235  			c <- <-c
   236  		}
   237  	}()
   238  	nestedCall(t, func() {
   239  		for i := 0; i < 10; i++ {
   240  			c <- i
   241  			if j := <-c; j != i {
   242  				t.Errorf("out of sync %d != %d", j, i)
   243  			}
   244  		}
   245  	})
   246  }
   247  
   248  func TestCallbackInAnotherThread(t *testing.T) {
   249  	// TODO: test a function which calls back in another thread: QueueUserAPC() or CreateThread()
   250  }
   251  
   252  type cbDLLFunc int // int determines number of callback parameters
   253  
   254  func (f cbDLLFunc) stdcallName() string {
   255  	return fmt.Sprintf("stdcall%d", f)
   256  }
   257  
   258  func (f cbDLLFunc) cdeclName() string {
   259  	return fmt.Sprintf("cdecl%d", f)
   260  }
   261  
   262  func (f cbDLLFunc) buildOne(stdcall bool) string {
   263  	var funcname, attr string
   264  	if stdcall {
   265  		funcname = f.stdcallName()
   266  		attr = "__stdcall"
   267  	} else {
   268  		funcname = f.cdeclName()
   269  		attr = "__cdecl"
   270  	}
   271  	typename := "t" + funcname
   272  	p := make([]string, f)
   273  	for i := range p {
   274  		p[i] = "void*"
   275  	}
   276  	params := strings.Join(p, ",")
   277  	for i := range p {
   278  		p[i] = fmt.Sprintf("%d", i+1)
   279  	}
   280  	args := strings.Join(p, ",")
   281  	return fmt.Sprintf(`
   282  typedef void %s (*%s)(%s);
   283  void %s(%s f, void *n) {
   284  	int i;
   285  	for(i=0;i<(int)n;i++){
   286  		f(%s);
   287  	}
   288  }
   289  	`, attr, typename, params, funcname, typename, args)
   290  }
   291  
   292  func (f cbDLLFunc) build() string {
   293  	return f.buildOne(false) + f.buildOne(true)
   294  }
   295  
   296  var cbFuncs = [...]interface{}{
   297  	2: func(i1, i2 uintptr) uintptr {
   298  		if i1+i2 != 3 {
   299  			panic("bad input")
   300  		}
   301  		return 0
   302  	},
   303  	3: func(i1, i2, i3 uintptr) uintptr {
   304  		if i1+i2+i3 != 6 {
   305  			panic("bad input")
   306  		}
   307  		return 0
   308  	},
   309  	4: func(i1, i2, i3, i4 uintptr) uintptr {
   310  		if i1+i2+i3+i4 != 10 {
   311  			panic("bad input")
   312  		}
   313  		return 0
   314  	},
   315  	5: func(i1, i2, i3, i4, i5 uintptr) uintptr {
   316  		if i1+i2+i3+i4+i5 != 15 {
   317  			panic("bad input")
   318  		}
   319  		return 0
   320  	},
   321  	6: func(i1, i2, i3, i4, i5, i6 uintptr) uintptr {
   322  		if i1+i2+i3+i4+i5+i6 != 21 {
   323  			panic("bad input")
   324  		}
   325  		return 0
   326  	},
   327  	7: func(i1, i2, i3, i4, i5, i6, i7 uintptr) uintptr {
   328  		if i1+i2+i3+i4+i5+i6+i7 != 28 {
   329  			panic("bad input")
   330  		}
   331  		return 0
   332  	},
   333  	8: func(i1, i2, i3, i4, i5, i6, i7, i8 uintptr) uintptr {
   334  		if i1+i2+i3+i4+i5+i6+i7+i8 != 36 {
   335  			panic("bad input")
   336  		}
   337  		return 0
   338  	},
   339  	9: func(i1, i2, i3, i4, i5, i6, i7, i8, i9 uintptr) uintptr {
   340  		if i1+i2+i3+i4+i5+i6+i7+i8+i9 != 45 {
   341  			panic("bad input")
   342  		}
   343  		return 0
   344  	},
   345  }
   346  
   347  type cbDLL struct {
   348  	name      string
   349  	buildArgs func(out, src string) []string
   350  }
   351  
   352  func (d *cbDLL) buildSrc(t *testing.T, path string) {
   353  	f, err := os.Create(path)
   354  	if err != nil {
   355  		t.Fatalf("failed to create source file: %v", err)
   356  	}
   357  	defer f.Close()
   358  
   359  	for i := 2; i < 10; i++ {
   360  		fmt.Fprint(f, cbDLLFunc(i).build())
   361  	}
   362  }
   363  
   364  func (d *cbDLL) build(t *testing.T, dir string) string {
   365  	srcname := d.name + ".c"
   366  	d.buildSrc(t, filepath.Join(dir, srcname))
   367  	outname := d.name + ".dll"
   368  	args := d.buildArgs(outname, srcname)
   369  	cmd := exec.Command(args[0], args[1:]...)
   370  	cmd.Dir = dir
   371  	out, err := cmd.CombinedOutput()
   372  	if err != nil {
   373  		t.Fatalf("failed to build dll: %v - %v", err, string(out))
   374  	}
   375  	return filepath.Join(dir, outname)
   376  }
   377  
   378  var cbDLLs = []cbDLL{
   379  	{
   380  		"test",
   381  		func(out, src string) []string {
   382  			return []string{"gcc", "-shared", "-s", "-o", out, src}
   383  		},
   384  	},
   385  	{
   386  		"testO2",
   387  		func(out, src string) []string {
   388  			return []string{"gcc", "-shared", "-s", "-o", out, "-O2", src}
   389  		},
   390  	},
   391  }
   392  
   393  type cbTest struct {
   394  	n     int     // number of callback parameters
   395  	param uintptr // dll function parameter
   396  }
   397  
   398  func (test *cbTest) run(t *testing.T, dllpath string) {
   399  	dll := syscall.MustLoadDLL(dllpath)
   400  	defer dll.Release()
   401  	cb := cbFuncs[test.n]
   402  	stdcall := syscall.NewCallback(cb)
   403  	f := cbDLLFunc(test.n)
   404  	test.runOne(t, dll, f.stdcallName(), stdcall)
   405  	cdecl := syscall.NewCallbackCDecl(cb)
   406  	test.runOne(t, dll, f.cdeclName(), cdecl)
   407  }
   408  
   409  func (test *cbTest) runOne(t *testing.T, dll *syscall.DLL, proc string, cb uintptr) {
   410  	defer func() {
   411  		if r := recover(); r != nil {
   412  			t.Errorf("dll call %v(..., %d) failed: %v", proc, test.param, r)
   413  		}
   414  	}()
   415  	dll.MustFindProc(proc).Call(cb, test.param)
   416  }
   417  
   418  var cbTests = []cbTest{
   419  	{2, 1},
   420  	{2, 10000},
   421  	{3, 3},
   422  	{4, 5},
   423  	{4, 6},
   424  	{5, 2},
   425  	{6, 7},
   426  	{6, 8},
   427  	{7, 6},
   428  	{8, 1},
   429  	{9, 8},
   430  	{9, 10000},
   431  	{3, 4},
   432  	{5, 3},
   433  	{7, 7},
   434  	{8, 2},
   435  	{9, 9},
   436  }
   437  
   438  func TestStdcallAndCDeclCallbacks(t *testing.T) {
   439  	tmp, err := ioutil.TempDir("", "TestCDeclCallback")
   440  	if err != nil {
   441  		t.Fatal("TempDir failed: ", err)
   442  	}
   443  	defer os.RemoveAll(tmp)
   444  
   445  	for _, dll := range cbDLLs {
   446  		dllPath := dll.build(t, tmp)
   447  		for _, test := range cbTests {
   448  			test.run(t, dllPath)
   449  		}
   450  	}
   451  }
   452  
   453  func TestRegisterClass(t *testing.T) {
   454  	kernel32 := GetDLL(t, "kernel32.dll")
   455  	user32 := GetDLL(t, "user32.dll")
   456  	mh, _, _ := kernel32.Proc("GetModuleHandleW").Call(0)
   457  	cb := syscall.NewCallback(func(hwnd syscall.Handle, msg uint32, wparam, lparam uintptr) (rc uintptr) {
   458  		t.Fatal("callback should never get called")
   459  		return 0
   460  	})
   461  	type Wndclassex struct {
   462  		Size       uint32
   463  		Style      uint32
   464  		WndProc    uintptr
   465  		ClsExtra   int32
   466  		WndExtra   int32
   467  		Instance   syscall.Handle
   468  		Icon       syscall.Handle
   469  		Cursor     syscall.Handle
   470  		Background syscall.Handle
   471  		MenuName   *uint16
   472  		ClassName  *uint16
   473  		IconSm     syscall.Handle
   474  	}
   475  	name := syscall.StringToUTF16Ptr("test_window")
   476  	wc := Wndclassex{
   477  		WndProc:   cb,
   478  		Instance:  syscall.Handle(mh),
   479  		ClassName: name,
   480  	}
   481  	wc.Size = uint32(unsafe.Sizeof(wc))
   482  	a, _, err := user32.Proc("RegisterClassExW").Call(uintptr(unsafe.Pointer(&wc)))
   483  	if a == 0 {
   484  		t.Fatalf("RegisterClassEx failed: %v", err)
   485  	}
   486  	r, _, err := user32.Proc("UnregisterClassW").Call(uintptr(unsafe.Pointer(name)), 0)
   487  	if r == 0 {
   488  		t.Fatalf("UnregisterClass failed: %v", err)
   489  	}
   490  }
   491  
   492  func TestOutputDebugString(t *testing.T) {
   493  	d := GetDLL(t, "kernel32.dll")
   494  	p := syscall.StringToUTF16Ptr("testing OutputDebugString")
   495  	d.Proc("OutputDebugStringW").Call(uintptr(unsafe.Pointer(p)))
   496  }
   497  
   498  func TestRaiseException(t *testing.T) {
   499  	o := executeTest(t, raiseExceptionSource, nil)
   500  	if strings.Contains(o, "RaiseException should not return") {
   501  		t.Fatalf("RaiseException did not crash program: %v", o)
   502  	}
   503  	if !strings.Contains(o, "Exception 0xbad") {
   504  		t.Fatalf("No stack trace: %v", o)
   505  	}
   506  }
   507  
   508  const raiseExceptionSource = `
   509  package main
   510  import "syscall"
   511  func main() {
   512  	const EXCEPTION_NONCONTINUABLE = 1
   513  	mod := syscall.MustLoadDLL("kernel32.dll")
   514  	proc := mod.MustFindProc("RaiseException")
   515  	proc.Call(0xbad, EXCEPTION_NONCONTINUABLE, 0, 0)
   516  	println("RaiseException should not return")
   517  }
   518  `
   519  
   520  func TestZeroDivisionException(t *testing.T) {
   521  	o := executeTest(t, zeroDivisionExceptionSource, nil)
   522  	if !strings.Contains(o, "panic: runtime error: integer divide by zero") {
   523  		t.Fatalf("No stack trace: %v", o)
   524  	}
   525  }
   526  
   527  const zeroDivisionExceptionSource = `
   528  package main
   529  func main() {
   530  	x := 1
   531  	y := 0
   532  	z := x / y
   533  	println(z)
   534  }
   535  `