github.com/undefinedlabs/go-mpatch@v1.0.8-0.20230904093002-fbac8a0d7853/patcher_test.go (about)

     1  package mpatch
     2  
     3  import (
     4  	"math/rand"
     5  	"reflect"
     6  	"testing"
     7  )
     8  
     9  //go:noinline
    10  func methodA() int {
    11  	x := rand.Int() >> 48
    12  	y := rand.Int() >> 48
    13  	return x + y
    14  }
    15  
    16  //go:noinline
    17  func methodB() int {
    18  	x := rand.Int() >> 48
    19  	y := rand.Int() >> 48
    20  	return -(x + y)
    21  }
    22  
    23  type myStruct struct {
    24  }
    25  
    26  //go:noinline
    27  func (s *myStruct) Method() int {
    28  	return 1
    29  }
    30  
    31  //go:noinline
    32  func (s myStruct) ValueMethod() int {
    33  	return 1
    34  }
    35  
    36  func TestPatcher(t *testing.T) {
    37  	patch, err := PatchMethod(methodA, methodB)
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	if methodA() > 0 {
    42  		t.Fatal("The patch did not work")
    43  	}
    44  	err = patch.Unpatch()
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  	if methodA() < 0 {
    49  		t.Fatal("The unpatch did not work")
    50  	}
    51  }
    52  
    53  func TestPatcherUsingReflect(t *testing.T) {
    54  	reflectA := reflect.ValueOf(methodA)
    55  	patch, err := PatchMethodByReflectValue(reflectA, methodB)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	if methodA() > 0 {
    60  		t.Fatal("The patch did not work")
    61  	}
    62  
    63  	err = patch.Unpatch()
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	if methodA() < 0 {
    68  		t.Fatal("The unpatch did not work")
    69  	}
    70  }
    71  
    72  func TestPatcherUsingMakeFunc(t *testing.T) {
    73  	reflectA := reflect.ValueOf(methodA)
    74  	patch, err := PatchMethodWithMakeFuncValue(reflectA,
    75  		func(args []reflect.Value) (results []reflect.Value) {
    76  			return []reflect.Value{reflect.ValueOf(42)}
    77  		})
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  	if methodA() != 42 {
    82  		t.Fatal("The patch did not work")
    83  	}
    84  
    85  	err = patch.Unpatch()
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	if methodA() < 0 {
    90  		t.Fatal("The unpatch did not work")
    91  	}
    92  }
    93  
    94  func TestInstancePatcher(t *testing.T) {
    95  	mStruct := myStruct{}
    96  
    97  	var patch *Patch
    98  	var err error
    99  	patch, err = PatchInstanceMethodByName(reflect.TypeOf(mStruct), "Method", func(m *myStruct) int {
   100  		patch.Unpatch()
   101  		defer patch.Patch()
   102  		return 41 + m.Method()
   103  	})
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	if mStruct.Method() != 42 {
   109  		t.Fatal("The patch did not work")
   110  	}
   111  	err = patch.Unpatch()
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	if mStruct.Method() != 1 {
   116  		t.Fatal("The unpatch did not work")
   117  	}
   118  }
   119  
   120  func TestInstanceValuePatcher(t *testing.T) {
   121  	mStruct := myStruct{}
   122  
   123  	var patch *Patch
   124  	var err error
   125  	patch, err = PatchInstanceMethodByName(reflect.TypeOf(mStruct), "ValueMethod", func(m myStruct) int {
   126  		patch.Unpatch()
   127  		defer patch.Patch()
   128  		return 41 + m.Method()
   129  	})
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  
   134  	if mStruct.ValueMethod() != 42 {
   135  		t.Fatal("The patch did not work")
   136  	}
   137  	err = patch.Unpatch()
   138  	if err != nil {
   139  		t.Fatal(err)
   140  	}
   141  	if mStruct.ValueMethod() != 1 {
   142  		t.Fatal("The unpatch did not work")
   143  	}
   144  }