github.com/google/grumpy@v0.0.0-20171122020858-3ec87959189c/runtime/module_test.go (about)

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grumpy
    16  
    17  import (
    18  	"io/ioutil"
    19  	"os"
    20  	"testing"
    21  )
    22  
    23  func TestImportModule(t *testing.T) {
    24  	f := NewRootFrame()
    25  	invalidModule := newObject(ObjectType)
    26  	foo := newTestModule("foo", "foo/__init__.py")
    27  	bar := newTestModule("foo.bar", "foo/bar/__init__.py")
    28  	baz := newTestModule("foo.bar.baz", "foo/bar/baz/__init__.py")
    29  	qux := newTestModule("foo.qux", "foo/qux/__init__.py")
    30  	fooCode := NewCode("<module>", "foo/__init__.py", nil, 0, func(*Frame, []*Object) (*Object, *BaseException) { return None, nil })
    31  	barCode := NewCode("<module>", "foo/bar/__init__.py", nil, 0, func(*Frame, []*Object) (*Object, *BaseException) { return None, nil })
    32  	bazCode := NewCode("<module>", "foo/bar/baz/__init__.py", nil, 0, func(*Frame, []*Object) (*Object, *BaseException) { return None, nil })
    33  	quxCode := NewCode("<module>", "foo/qux/__init__.py", nil, 0, func(*Frame, []*Object) (*Object, *BaseException) { return None, nil })
    34  	raisesCode := NewCode("<module", "raises.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) {
    35  		return nil, f.RaiseType(ValueErrorType, "uh oh")
    36  	})
    37  	circularImported := false
    38  	circularCode := NewCode("<module>", "circular.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) {
    39  		if circularImported {
    40  			return nil, f.RaiseType(AssertionErrorType, "circular imported recursively")
    41  		}
    42  		circularImported = true
    43  		if _, raised := ImportModule(f, "circular"); raised != nil {
    44  			return nil, raised
    45  		}
    46  		return None, nil
    47  	})
    48  	circularTestModule := newTestModule("circular", "circular.py").ToObject()
    49  	clearCode := NewCode("<module>", "clear.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) {
    50  		if _, raised := SysModules.DelItemString(f, "clear"); raised != nil {
    51  			return nil, raised
    52  		}
    53  		return None, nil
    54  	})
    55  	// NOTE: This test progressively evolves sys.modules, checking after
    56  	// each test case that it's populated appropriately.
    57  	oldSysModules := SysModules
    58  	oldModuleRegistry := moduleRegistry
    59  	defer func() {
    60  		SysModules = oldSysModules
    61  		moduleRegistry = oldModuleRegistry
    62  	}()
    63  	SysModules = newStringDict(map[string]*Object{"invalid": invalidModule})
    64  	moduleRegistry = map[string]*Code{
    65  		"foo":         fooCode,
    66  		"foo.bar":     barCode,
    67  		"foo.bar.baz": bazCode,
    68  		"foo.qux":     quxCode,
    69  		"raises":      raisesCode,
    70  		"circular":    circularCode,
    71  		"clear":       clearCode,
    72  	}
    73  	cases := []struct {
    74  		name           string
    75  		want           *Object
    76  		wantExc        *BaseException
    77  		wantSysModules *Dict
    78  	}{
    79  		{
    80  			"noexist",
    81  			nil,
    82  			mustCreateException(ImportErrorType, "noexist"),
    83  			newStringDict(map[string]*Object{"invalid": invalidModule}),
    84  		},
    85  		{
    86  			"invalid",
    87  			NewTuple(invalidModule).ToObject(),
    88  			nil,
    89  			newStringDict(map[string]*Object{"invalid": invalidModule}),
    90  		},
    91  		{
    92  			"raises",
    93  			nil,
    94  			mustCreateException(ValueErrorType, "uh oh"),
    95  			newStringDict(map[string]*Object{"invalid": invalidModule}),
    96  		},
    97  		{
    98  			"foo",
    99  			NewTuple(foo.ToObject()).ToObject(),
   100  			nil,
   101  			newStringDict(map[string]*Object{
   102  				"foo":     foo.ToObject(),
   103  				"invalid": invalidModule,
   104  			}),
   105  		},
   106  		{
   107  			"foo",
   108  			NewTuple(foo.ToObject()).ToObject(),
   109  			nil,
   110  			newStringDict(map[string]*Object{
   111  				"foo":     foo.ToObject(),
   112  				"invalid": invalidModule,
   113  			}),
   114  		},
   115  		{
   116  			"foo.qux",
   117  			NewTuple(foo.ToObject(), qux.ToObject()).ToObject(),
   118  			nil,
   119  			newStringDict(map[string]*Object{
   120  				"foo":     foo.ToObject(),
   121  				"foo.qux": qux.ToObject(),
   122  				"invalid": invalidModule,
   123  			}),
   124  		},
   125  		{
   126  			"foo.bar.baz",
   127  			NewTuple(foo.ToObject(), bar.ToObject(), baz.ToObject()).ToObject(),
   128  			nil,
   129  			newStringDict(map[string]*Object{
   130  				"foo":         foo.ToObject(),
   131  				"foo.bar":     bar.ToObject(),
   132  				"foo.bar.baz": baz.ToObject(),
   133  				"foo.qux":     qux.ToObject(),
   134  				"invalid":     invalidModule,
   135  			}),
   136  		},
   137  		{
   138  			"circular",
   139  			NewTuple(circularTestModule).ToObject(),
   140  			nil,
   141  			newStringDict(map[string]*Object{
   142  				"circular":    circularTestModule,
   143  				"foo":         foo.ToObject(),
   144  				"foo.bar":     bar.ToObject(),
   145  				"foo.bar.baz": baz.ToObject(),
   146  				"foo.qux":     qux.ToObject(),
   147  				"invalid":     invalidModule,
   148  			}),
   149  		},
   150  		{
   151  			"clear",
   152  			nil,
   153  			mustCreateException(ImportErrorType, "Loaded module clear not found in sys.modules"),
   154  			newStringDict(map[string]*Object{
   155  				"circular":    circularTestModule,
   156  				"foo":         foo.ToObject(),
   157  				"foo.bar":     bar.ToObject(),
   158  				"foo.bar.baz": baz.ToObject(),
   159  				"foo.qux":     qux.ToObject(),
   160  				"invalid":     invalidModule,
   161  			}),
   162  		},
   163  	}
   164  	for _, cas := range cases {
   165  		mods, raised := ImportModule(f, cas.name)
   166  		var got *Object
   167  		if raised == nil {
   168  			got = NewTuple(mods...).ToObject()
   169  		}
   170  		switch checkResult(got, cas.want, raised, cas.wantExc) {
   171  		case checkInvokeResultExceptionMismatch:
   172  			t.Errorf("ImportModule(%q) raised %v, want %v", cas.name, raised, cas.wantExc)
   173  		case checkInvokeResultReturnValueMismatch:
   174  			t.Errorf("ImportModule(%q) = %v, want %v", cas.name, got, cas.want)
   175  		}
   176  		ne := mustNotRaise(NE(f, SysModules.ToObject(), cas.wantSysModules.ToObject()))
   177  		b, raised := IsTrue(f, ne)
   178  		if raised != nil {
   179  			panic(raised)
   180  		}
   181  		if b {
   182  			msg := "ImportModule(%q): sys.modules = %v, want %v"
   183  			t.Errorf(msg, cas.name, SysModules, cas.wantSysModules)
   184  		}
   185  	}
   186  }
   187  
   188  func TestModuleGetNameAndFilename(t *testing.T) {
   189  	fun := wrapFuncForTest(func(f *Frame, m *Module) (*Tuple, *BaseException) {
   190  		name, raised := m.GetName(f)
   191  		if raised != nil {
   192  			return nil, raised
   193  		}
   194  		filename, raised := m.GetFilename(f)
   195  		if raised != nil {
   196  			return nil, raised
   197  		}
   198  		return newTestTuple(name, filename), nil
   199  	})
   200  	cases := []invokeTestCase{
   201  		{args: wrapArgs(newModule("foo", "foo.py")), want: newTestTuple("foo", "foo.py").ToObject()},
   202  		{args: Args{mustNotRaise(ModuleType.Call(NewRootFrame(), wrapArgs("foo"), nil))}, wantExc: mustCreateException(SystemErrorType, "module filename missing")},
   203  		{args: wrapArgs(&Module{Object: Object{typ: ModuleType, dict: NewDict()}}), wantExc: mustCreateException(SystemErrorType, "nameless module")},
   204  	}
   205  	for _, cas := range cases {
   206  		if err := runInvokeTestCase(fun, &cas); err != "" {
   207  			t.Error(err)
   208  		}
   209  	}
   210  }
   211  
   212  func TestModuleInit(t *testing.T) {
   213  	fun := wrapFuncForTest(func(f *Frame, args ...*Object) (*Tuple, *BaseException) {
   214  		o, raised := ModuleType.Call(f, args, nil)
   215  		if raised != nil {
   216  			return nil, raised
   217  		}
   218  		name, raised := GetAttr(f, o, internedName, None)
   219  		if raised != nil {
   220  			return nil, raised
   221  		}
   222  		doc, raised := GetAttr(f, o, NewStr("__doc__"), None)
   223  		if raised != nil {
   224  			return nil, raised
   225  		}
   226  		return NewTuple(name, doc), nil
   227  	})
   228  	cases := []invokeTestCase{
   229  		{args: wrapArgs("foo"), want: newTestTuple("foo", None).ToObject()},
   230  		{args: wrapArgs("foo", 123), want: newTestTuple("foo", 123).ToObject()},
   231  		{args: wrapArgs(newObject(ObjectType)), wantExc: mustCreateException(TypeErrorType, `'__init__' requires a 'str' object but received a "object"`)},
   232  		{wantExc: mustCreateException(TypeErrorType, "'__init__' requires 2 arguments")},
   233  	}
   234  	for _, cas := range cases {
   235  		if err := runInvokeTestCase(fun, &cas); err != "" {
   236  			t.Error(err)
   237  		}
   238  	}
   239  }
   240  
   241  func TestModuleStrRepr(t *testing.T) {
   242  	cases := []invokeTestCase{
   243  		{args: wrapArgs(newModule("foo", "<test>")), want: NewStr("<module 'foo' from '<test>'>").ToObject()},
   244  		{args: wrapArgs(newModule("foo.bar.baz", "<test>")), want: NewStr("<module 'foo.bar.baz' from '<test>'>").ToObject()},
   245  		{args: Args{mustNotRaise(ModuleType.Call(NewRootFrame(), wrapArgs("foo"), nil))}, want: NewStr("<module 'foo' (built-in)>").ToObject()},
   246  		{args: wrapArgs(&Module{Object: Object{typ: ModuleType, dict: newTestDict("__file__", "foo.py")}}), want: NewStr("<module '?' from 'foo.py'>").ToObject()},
   247  	}
   248  	for _, cas := range cases {
   249  		if err := runInvokeTestCase(wrapFuncForTest(ToStr), &cas); err != "" {
   250  			t.Error(err)
   251  		}
   252  		if err := runInvokeTestCase(wrapFuncForTest(Repr), &cas); err != "" {
   253  			t.Error(err)
   254  		}
   255  	}
   256  }
   257  
   258  func TestRunMain(t *testing.T) {
   259  	oldSysModules := SysModules
   260  	defer func() {
   261  		SysModules = oldSysModules
   262  	}()
   263  	cases := []struct {
   264  		code       *Code
   265  		wantCode   int
   266  		wantOutput string
   267  	}{
   268  		{NewCode("<test>", "test.py", nil, 0, func(*Frame, []*Object) (*Object, *BaseException) { return None, nil }), 0, ""},
   269  		{NewCode("<test>", "test.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) {
   270  			return nil, f.Raise(SystemExitType.ToObject(), None, nil)
   271  		}), 0, ""},
   272  		{NewCode("<test>", "test.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) { return nil, f.RaiseType(TypeErrorType, "foo") }), 1, "TypeError: foo\n"},
   273  		{NewCode("<test>", "test.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) { return nil, f.RaiseType(SystemExitType, "foo") }), 1, "foo\n"},
   274  		{NewCode("<test>", "test.py", nil, 0, func(f *Frame, _ []*Object) (*Object, *BaseException) {
   275  			return nil, f.Raise(SystemExitType.ToObject(), NewInt(12).ToObject(), nil)
   276  		}), 12, ""},
   277  	}
   278  	for _, cas := range cases {
   279  		SysModules = NewDict()
   280  		if gotCode, gotOutput, err := runMainAndCaptureStderr(cas.code); err != nil {
   281  			t.Errorf("runMainRedirectStderr() failed: %v", err)
   282  		} else if gotCode != cas.wantCode {
   283  			t.Errorf("RunMain() = %v, want %v", gotCode, cas.wantCode)
   284  		} else if gotOutput != cas.wantOutput {
   285  			t.Errorf("RunMain() output %q, want %q", gotOutput, cas.wantOutput)
   286  		}
   287  	}
   288  }
   289  
   290  func runMainAndCaptureStderr(code *Code) (int, string, error) {
   291  	oldStderr := Stderr
   292  	defer func() {
   293  		Stderr = oldStderr
   294  	}()
   295  	r, w, err := os.Pipe()
   296  	if err != nil {
   297  		return 0, "", err
   298  	}
   299  	Stderr = NewFileFromFD(w.Fd(), nil)
   300  	c := make(chan int)
   301  	go func() {
   302  		defer w.Close()
   303  		c <- RunMain(code)
   304  	}()
   305  	result := <-c
   306  	data, err := ioutil.ReadAll(r)
   307  	if err != nil {
   308  		return 0, "", err
   309  	}
   310  	return result, string(data), nil
   311  }
   312  
   313  var testModuleType *Type
   314  
   315  func init() {
   316  	testModuleType, _ = newClass(NewRootFrame(), TypeType, "testModule", []*Type{ModuleType}, newStringDict(map[string]*Object{
   317  		"__eq__": newBuiltinFunction("__eq__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) {
   318  			if raised := checkMethodArgs(f, "__eq__", args, ModuleType, ObjectType); raised != nil {
   319  				return nil, raised
   320  			}
   321  			if !args[1].isInstance(ModuleType) {
   322  				return NotImplemented, nil
   323  			}
   324  			m1, m2 := toModuleUnsafe(args[0]), toModuleUnsafe(args[1])
   325  			name1, raised := m1.GetName(f)
   326  			if raised != nil {
   327  				return nil, raised
   328  			}
   329  			name2, raised := m2.GetName(f)
   330  			if raised != nil {
   331  				return nil, raised
   332  			}
   333  			if name1.Value() != name2.Value() {
   334  				return False.ToObject(), nil
   335  			}
   336  			file1, raised := m1.GetFilename(f)
   337  			if raised != nil {
   338  				return nil, raised
   339  			}
   340  			file2, raised := m2.GetFilename(f)
   341  			if raised != nil {
   342  				return nil, raised
   343  			}
   344  			return GetBool(file1.Value() == file2.Value()).ToObject(), nil
   345  		}).ToObject(),
   346  		"__ne__": newBuiltinFunction("__ne__", func(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) {
   347  			if raised := checkMethodArgs(f, "__ne__", args, ModuleType, ObjectType); raised != nil {
   348  				return nil, raised
   349  			}
   350  			eq, raised := Eq(f, args[0], args[1])
   351  			if raised != nil {
   352  				return nil, raised
   353  			}
   354  			isEq, raised := IsTrue(f, eq)
   355  			if raised != nil {
   356  				return nil, raised
   357  			}
   358  			return GetBool(!isEq).ToObject(), nil
   359  		}).ToObject(),
   360  	}))
   361  }
   362  
   363  func newTestModule(name, filename string) *Module {
   364  	return &Module{Object: Object{typ: testModuleType, dict: newTestDict("__name__", name, "__file__", filename)}}
   365  }