github.com/bytedance/mockey@v1.2.10/mock_test.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mockey
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/bytedance/mockey/internal/tool"
    26  	. "github.com/smartystreets/goconvey/convey"
    27  )
    28  
    29  func Fun(a string) string {
    30  	fmt.Println(a)
    31  	return a
    32  }
    33  
    34  type Class struct{}
    35  
    36  func (*Class) FunA(a string) string {
    37  	fmt.Println(a)
    38  	return a
    39  }
    40  
    41  func (*Class) VariantParam(a string, b ...string) string {
    42  	fmt.Println("VariantParam")
    43  	return a
    44  }
    45  
    46  func MultiReturn() (int, int) {
    47  	return 0, 0
    48  }
    49  
    50  func MultiReturnErr() (int, int, error) {
    51  	return 0, 0, errors.New("old")
    52  }
    53  
    54  func VariantParam(a int, b ...int) (int, int) {
    55  	return a, b[0]
    56  }
    57  
// ShortFun has an empty body. TestMockUnsafe patches it with MockUnsafe
// rather than Mock — presumably because the function is too short for the
// regular Mock path. Keep the body empty so that test keeps covering this case.
func ShortFun() {}
    59  
    60  func TestNoConvey(t *testing.T) {
    61  	origin := Fun
    62  	mock := func(p string) string {
    63  		fmt.Println("b")
    64  		origin(p)
    65  		return "b"
    66  	}
    67  	mock2 := Mock(Fun).When(func(p string) bool { return p == "a" }).To(mock).Origin(&origin).Build()
    68  	defer mock2.UnPatch()
    69  	r := Fun("a")
    70  	if r != "b" {
    71  		t.Errorf("result = %s, expected = b", r)
    72  	}
    73  }
    74  
    75  func TestMock(t *testing.T) {
    76  	PatchConvey("test mock", t, func() {
    77  		PatchConvey("test to", func() {
    78  			origin := Fun
    79  			mock := func(p string) string {
    80  				fmt.Println("b")
    81  				origin(p)
    82  				return "b"
    83  			}
    84  			mock2 := Mock(Fun).When(func(p string) bool { return p == "a" }).To(mock).Origin(&origin).Build()
    85  			r := Fun("a")
    86  			So(r, ShouldEqual, "b")
    87  			So(mock2.Times(), ShouldEqual, 1)
    88  		})
    89  		r := Fun("a")
    90  		So(r, ShouldEqual, "a")
    91  		PatchConvey("test return", func() {
    92  			mock3 := Mock(Fun).When(func(p string) bool { return p == "a" }).Return("c").Build()
    93  			r := Fun("a")
    94  			So(r, ShouldEqual, "c")
    95  			So(mock3.Times(), ShouldEqual, 1)
    96  		})
    97  
    98  		PatchConvey("test multi_return", func() {
    99  			mock3 := Mock(MultiReturn).Return(1, 1).Build()
   100  			a, b := MultiReturn()
   101  			So(a, ShouldEqual, 1)
   102  			So(b, ShouldEqual, 1)
   103  			So(mock3.Times(), ShouldEqual, 1)
   104  		})
   105  
   106  		PatchConvey("test multi_return_err", func() {
   107  			newErr := errors.New("new")
   108  			mock3 := Mock(MultiReturnErr).Return(1, 1, newErr).Build()
   109  			a, b, e := MultiReturnErr()
   110  			So(a, ShouldEqual, 1)
   111  			So(b, ShouldEqual, 1)
   112  			So(e, ShouldBeError, newErr)
   113  			So(mock3.Times(), ShouldEqual, 1)
   114  		})
   115  
   116  		PatchConvey("test variant param", func() {
   117  			when := func(a int, bs ...int) bool {
   118  				return bs[0] == 1
   119  			}
   120  			to := func(a int, bs ...int) (int, int) {
   121  				return a + 1, bs[1]
   122  			}
   123  			mock4 := Mock(VariantParam).When(when).To(to).Build()
   124  			a, b := VariantParam(0, 1, 2, 3)
   125  			So(a, ShouldEqual, 1)
   126  			So(b, ShouldEqual, 2)
   127  			So(mock4.Times(), ShouldEqual, 1)
   128  			So(mock4.MockTimes(), ShouldEqual, 1)
   129  		})
   130  	})
   131  }
   132  
   133  func TestParam(t *testing.T) {
   134  	fmt.Printf("gid:%+v\n", tool.GetGoroutineID())
   135  	PatchConvey("test variant param", t, func() {
   136  		when := func(a int, bs ...int) bool {
   137  			return bs[0] == 1
   138  		}
   139  		to := func(a int, bs ...int) (int, int) {
   140  			return a + 1, bs[1]
   141  		}
   142  		mock4 := Mock(VariantParam).When(when).To(to).Build()
   143  		PatchConvey("test when", func() {
   144  			a, b := VariantParam(0, 1, 2, 3)
   145  			So(a, ShouldEqual, 1)
   146  			So(b, ShouldEqual, 2)
   147  			So(mock4.Times(), ShouldEqual, 1)
   148  			So(mock4.MockTimes(), ShouldEqual, 1)
   149  		})
   150  		PatchConvey("test no when", func() {
   151  			a1, b1 := VariantParam(0, 2, 2, 3)
   152  			So(a1, ShouldEqual, 0)
   153  			So(b1, ShouldEqual, 2)
   154  			So(mock4.Times(), ShouldEqual, 1)
   155  			So(mock4.MockTimes(), ShouldEqual, 0)
   156  		})
   157  	})
   158  }
   159  
   160  func TestClass(t *testing.T) {
   161  	PatchConvey("test class", t, func() {
   162  		PatchConvey("test mock", func() {
   163  			mock := func(self *Class, p string) string {
   164  				fmt.Print("b")
   165  				return "b"
   166  			}
   167  			m := Mock((*Class).FunA).When(func(self *Class, p string) bool { return p == "a" }).To(mock).Build()
   168  			c := Class{}
   169  			str := c.FunA("a")
   170  			So(m.MockTimes(), ShouldEqual, 1)
   171  			So(str, ShouldEqual, "b")
   172  		})
   173  		PatchConvey("test class variant param mock", func() {
   174  			mock := func(self *Class, a string, b ...string) string { return b[0] }
   175  			m := Mock((*Class).VariantParam).When(func(self *Class, a string, b ...string) bool { return a == "a" }).To(mock).Build()
   176  			c := Class{}
   177  			str := c.VariantParam("a", "b")
   178  			So(m.MockTimes(), ShouldEqual, 1)
   179  			So(str, ShouldEqual, "b")
   180  		})
   181  
   182  		PatchConvey("test  missing receiver mock", func() {
   183  			mock := func(p string) string {
   184  				fmt.Print("b")
   185  				return "b"
   186  			}
   187  			m := Mock((*Class).FunA).When(func(p string) bool { return p == "a" }).To(mock).Build()
   188  			c := Class{}
   189  			str := c.FunA("a")
   190  			So(m.MockTimes(), ShouldEqual, 1)
   191  			So(str, ShouldEqual, "b")
   192  		})
   193  		PatchConvey("test missing receiver and more args", func() {
   194  			mock := func() string {
   195  				fmt.Print("b")
   196  				return "b"
   197  			}
   198  
   199  			So(func() { Mock((*Class).FunA).When(func(p string) bool { return p == "a" }).To(mock).Build() }, ShouldPanic)
   200  		})
   201  	})
   202  }
   203  
   204  type TestImpl struct {
   205  	a string
   206  }
   207  
   208  func (i *TestImpl) A() string {
   209  	fmt.Println(i.a)
   210  	return i.a
   211  }
   212  
   213  type TestI interface {
   214  	A() string
   215  }
   216  
   217  func ReturnImpl() TestI {
   218  	return &TestImpl{a: "a"}
   219  }
   220  
   221  func TestInterface(t *testing.T) {
   222  	PatchConvey("TestInterface", t, func() {
   223  		PatchConvey("test mock", func() {
   224  			m := Mock(ReturnImpl).Return(&TestImpl{a: "b"}).Build()
   225  			str := ReturnImpl().A()
   226  			So(m.MockTimes(), ShouldEqual, 1)
   227  			So(str, ShouldEqual, "b")
   228  		})
   229  	})
   230  }
   231  
   232  func TestFilterGoRoutine(t *testing.T) {
   233  	PatchConvey("filter go routine", t, func() {
   234  		mock := Mock(Fun).ExcludeCurrentGoRoutine().Return("b").Build()
   235  		r := Fun("a")
   236  		So(r, ShouldEqual, "a")
   237  		So(mock.Times(), ShouldEqual, 1)
   238  		So(mock.MockTimes(), ShouldEqual, 0)
   239  
   240  		mock.IncludeCurrentGoRoutine()
   241  		r = Fun("a")
   242  		So(r, ShouldEqual, "b")
   243  		So(mock.Times(), ShouldEqual, 1)
   244  		So(mock.MockTimes(), ShouldEqual, 1)
   245  
   246  		mock.IncludeCurrentGoRoutine()
   247  		go Fun("a")
   248  		time.Sleep(1 * time.Second)
   249  		So(mock.Times(), ShouldEqual, 1)
   250  		So(mock.MockTimes(), ShouldEqual, 0)
   251  	})
   252  }
   253  
// TestResetPatch checks that a built mocker can be reconfigured in place:
// changing When, switching to Return, and replacing To together with Origin.
// goconvey executes the "test to" body once per leaf below. Each leaf
// therefore starts from a freshly built mock2 (To+Origin, When p == "a"),
// with the parent's assertions having just run.
func TestResetPatch(t *testing.T) {
	PatchConvey("test mock", t, func() {
		PatchConvey("test to", func() {
			origin := Fun
			mock := func(p string) string {
				fmt.Println("b")
				origin(p)
				return "b"
			}
			mock2 := Mock(Fun).When(func(p string) bool { return p == "a" }).To(mock).Origin(&origin).Build()
			r := Fun("a")
			So(r, ShouldEqual, "b")
			So(mock2.Times(), ShouldEqual, 1)

			PatchConvey("test reset when", func() {
				// The new condition no longer matches "a", so the original Fun runs.
				mock2.When(func(p string) bool { return p == "b" })
				r := Fun("a")
				So(r, ShouldEqual, "a")
				So(mock2.MockTimes(), ShouldEqual, 0)
			})

			PatchConvey("test reset return", func() {
				// Switch to a fixed return value. MockTimes counts only the
				// call made after the reconfiguration; the parent's earlier
				// mocked call is not included.
				mock2.Return("c")
				r := Fun("a")
				So(r, ShouldEqual, "c")
				So(mock2.MockTimes(), ShouldEqual, 1)
			})

			PatchConvey("test reset to and origin", func() {
				// Swap in a new hook and bind a new origin variable; the hook
				// reaches the real Fun through origin2, hence "d" + "a".
				origin2 := Fun
				mock := func(p string) string {
					fmt.Println("d")
					return origin2("d") + p
				}
				mock2.To(mock).Origin(&origin2)
				r := Fun("a")
				So(r, ShouldEqual, "da")
				So(mock2.MockTimes(), ShouldEqual, 1)
			})
		})
	})
}
   296  
   297  func TestRePatch(t *testing.T) {
   298  	Convey("TestRePatch", t, func() {
   299  		origin := Fun
   300  		mock := func(p string) string {
   301  			fmt.Println("b")
   302  			origin(p)
   303  			return "b"
   304  		}
   305  		mock2 := Mock(Fun).When(func(p string) bool { return p == "a" }).To(mock).Origin(&origin).Build().Patch().Patch()
   306  		defer mock2.UnPatch()
   307  		r := Fun("a")
   308  		So(r, ShouldEqual, "b")
   309  		mock2.UnPatch()
   310  		mock2.UnPatch()
   311  		mock2.UnPatch()
   312  		fmt.Printf("re unpatch can be run")
   313  	})
   314  }
   315  
   316  func TestMockUnsafe(t *testing.T) {
   317  	Convey("TestMockUnsafe", t, func() {
   318  		mock := MockUnsafe(ShortFun).To(func() { panic("in hook") }).Build()
   319  		defer mock.UnPatch()
   320  		So(func() { ShortFun() }, ShouldPanicWith, "in hook")
   321  	})
   322  }
   323  
   324  type foo struct{ i int }
   325  
   326  func (f *foo) Name(i int) string { return fmt.Sprintf("Fn-%v-%v", f.i, i) }
   327  
   328  func (f *foo) Foo() int { return f.i }
   329  
// TestMockOrigin checks how Origin binds a variable to the real method.
// It covers:
//   - Origin rebound on an already-built mocker;
//   - a hook and origin that take the receiver explicitly;
//   - a hook and origin that omit the receiver (the real receiver is still
//     used, per the expected "Fn-100-1" results);
//   - a method obtained through GetMethod (issue #15).
func TestMockOrigin(t *testing.T) {
	PatchConvey("struct-origin", t, func() {
		PatchConvey("with receiver", func() {
			// The hook routes i == 1 through ori1 and everything else through ori2.
			var ori1 func(*foo, int) string
			var ori2 func(*foo, int) string
			mocker := Mock((*foo).Name).To(func(f *foo, i int) string {
				if i == 1 {
					return ori1(f, i)
				}
				return ori2(f, i)
			}).Origin(&ori1).Build()

			// ori1 is the real Name; ori2 is a hand-written fake.
			ori2 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
			So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
			So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
			So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
			So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

			// Swap roles: overwrite ori1 with a fake and rebind Origin to ori2,
			// so i == 1 now hits the fake and other values reach the real Name.
			ori1 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
			mocker.Origin(&ori2)
			So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
			So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
			So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
			So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
		})
		PatchConvey("without receiver", func() {
			// Same scenario as above, but the hook and origins omit the
			// receiver; the real calls still format the actual receiver's i.
			var ori1 func(int) string
			var ori2 func(int) string
			mocker := Mock((*foo).Name).To(func(i int) string {
				if i == 1 {
					return ori1(i)
				}
				return ori2(i)
			}).Origin(&ori1).Build()

			ori2 = func(i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
			So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
			So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
			So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
			So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

			ori1 = func(i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
			mocker.Origin(&ori2)
			So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
			So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
			So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
			So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
		})
	})
	PatchConvey("issue https://github.com/bytedance/mockey/issues/15", t, func() {
		// The method is looked up via GetMethod on one instance, yet origin()
		// must act on whichever receiver each call uses (1, 2, 3 -> +1).
		var origin func() int
		f := &foo{}
		Mock(GetMethod(f, "Foo")).To(func() int { return origin() + 1 }).Origin(&origin).Build()
		So((&foo{1}).Foo(), ShouldEqual, 2)
		So((&foo{2}).Foo(), ShouldEqual, 3)
		So((&foo{3}).Foo(), ShouldEqual, 4)
	})
}
   388  
   389  func TestMultiArgs(t *testing.T) {
   390  	PatchConvey("multi-arg-result", t, func() {
   391  		PatchConvey("multi-arg", func() {
   392  			// Go supports passing function arguments from go 1.17
   393  			//
   394  			// Mockey used to use X10 register to make BR instruction in
   395  			// arm64, which will cause arguments and results get a wrong value
   396  			//
   397  			// _0~_15 use x0~x15 register
   398  			fn := func(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20 int64) {
   399  				fmt.Println(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20)
   400  			}
   401  			ori := fn
   402  			Mock(fn).To(func(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20 int64) {
   403  				for _, _x := range []int64{_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20} {
   404  					So(_x, ShouldEqual, 0)
   405  				}
   406  			}).Origin(&ori).Build()
   407  			fn(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
   408  		})
   409  		PatchConvey("multi-result", func() {
   410  			fn := func() (_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20 int64) {
   411  				return 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
   412  			}
   413  			ori := fn
   414  			Mock(fn).To(func() (_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20 int64) {
   415  				return ori()
   416  			}).Origin(&ori).Build()
   417  			_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20 := fn()
   418  			for _, _x := range []int64{_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20} {
   419  				So(_x, ShouldEqual, 0)
   420  			}
   421  		})
   422  	})
   423  }