github.com/agiledragon/gomonkey/v2@v2.11.1-0.20240427155748-d56c6823ec17/patch.go (about)

     1  package gomonkey
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"syscall"
     7  	"unsafe"
     8  
     9  	"github.com/agiledragon/gomonkey/v2/creflect"
    10  )
    11  
    12  type Patches struct {
    13  	originals    map[uintptr][]byte
    14  	values       map[reflect.Value]reflect.Value
    15  	valueHolders map[reflect.Value]reflect.Value
    16  }
    17  
    18  type Params []interface{}
    19  type OutputCell struct {
    20  	Values Params
    21  	Times  int
    22  }
    23  
    24  func ApplyFunc(target, double interface{}) *Patches {
    25  	return create().ApplyFunc(target, double)
    26  }
    27  
    28  func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
    29  	return create().ApplyMethod(target, methodName, double)
    30  }
    31  
    32  func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
    33  	return create().ApplyMethodFunc(target, methodName, doubleFunc)
    34  }
    35  
    36  func ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
    37  	return create().ApplyPrivateMethod(target, methodName, double)
    38  }
    39  
    40  func ApplyGlobalVar(target, double interface{}) *Patches {
    41  	return create().ApplyGlobalVar(target, double)
    42  }
    43  
    44  func ApplyFuncVar(target, double interface{}) *Patches {
    45  	return create().ApplyFuncVar(target, double)
    46  }
    47  
    48  func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
    49  	return create().ApplyFuncSeq(target, outputs)
    50  }
    51  
    52  func ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
    53  	return create().ApplyMethodSeq(target, methodName, outputs)
    54  }
    55  
    56  func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
    57  	return create().ApplyFuncVarSeq(target, outputs)
    58  }
    59  
    60  func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
    61  	return create().ApplyFuncReturn(target, output...)
    62  }
    63  
    64  func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
    65  	return create().ApplyMethodReturn(target, methodName, output...)
    66  }
    67  
    68  func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
    69  	return create().ApplyFuncVarReturn(target, output...)
    70  }
    71  
    72  func create() *Patches {
    73  	return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
    74  }
    75  
    76  func NewPatches() *Patches {
    77  	return create()
    78  }
    79  
    80  func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
    81  	t := reflect.ValueOf(target)
    82  	d := reflect.ValueOf(double)
    83  	return this.ApplyCore(t, d)
    84  }
    85  
    86  func (this *Patches) ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
    87  	m, ok := castRType(target).MethodByName(methodName)
    88  	if !ok {
    89  		panic("retrieve method by name failed")
    90  	}
    91  	d := reflect.ValueOf(double)
    92  	return this.ApplyCore(m.Func, d)
    93  }
    94  
    95  func (this *Patches) ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
    96  	m, ok := castRType(target).MethodByName(methodName)
    97  	if !ok {
    98  		panic("retrieve method by name failed")
    99  	}
   100  	d := funcToMethod(m.Type, doubleFunc)
   101  	return this.ApplyCore(m.Func, d)
   102  }
   103  
   104  func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
   105  	m, ok := creflect.MethodByName(castRType(target), methodName)
   106  	if !ok {
   107  		panic("retrieve method by name failed")
   108  	}
   109  	d := reflect.ValueOf(double)
   110  	return this.ApplyCoreOnlyForPrivateMethod(m, d)
   111  }
   112  
   113  func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches {
   114  	t := reflect.ValueOf(target)
   115  	if t.Type().Kind() != reflect.Ptr {
   116  		panic("target is not a pointer")
   117  	}
   118  
   119  	this.values[t] = reflect.ValueOf(t.Elem().Interface())
   120  	d := reflect.ValueOf(double)
   121  	t.Elem().Set(d)
   122  	return this
   123  }
   124  
   125  func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches {
   126  	t := reflect.ValueOf(target)
   127  	d := reflect.ValueOf(double)
   128  	if t.Type().Kind() != reflect.Ptr {
   129  		panic("target is not a pointer")
   130  	}
   131  	this.check(t.Elem(), d)
   132  	return this.ApplyGlobalVar(target, double)
   133  }
   134  
   135  func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
   136  	funcType := reflect.TypeOf(target)
   137  	t := reflect.ValueOf(target)
   138  	d := getDoubleFunc(funcType, outputs)
   139  	return this.ApplyCore(t, d)
   140  }
   141  
   142  func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
   143  	m, ok := castRType(target).MethodByName(methodName)
   144  	if !ok {
   145  		panic("retrieve method by name failed")
   146  	}
   147  	d := getDoubleFunc(m.Type, outputs)
   148  	return this.ApplyCore(m.Func, d)
   149  }
   150  
   151  func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
   152  	t := reflect.ValueOf(target)
   153  	if t.Type().Kind() != reflect.Ptr {
   154  		panic("target is not a pointer")
   155  	}
   156  	if t.Elem().Kind() != reflect.Func {
   157  		panic("target is not a func")
   158  	}
   159  
   160  	funcType := reflect.TypeOf(target).Elem()
   161  	double := getDoubleFunc(funcType, outputs).Interface()
   162  	return this.ApplyGlobalVar(target, double)
   163  }
   164  
   165  func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
   166  	funcType := reflect.TypeOf(target)
   167  	t := reflect.ValueOf(target)
   168  	outputs := []OutputCell{{Values: returns, Times: -1}}
   169  	d := getDoubleFunc(funcType, outputs)
   170  	return this.ApplyCore(t, d)
   171  }
   172  
   173  func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
   174  	m, ok := reflect.TypeOf(target).MethodByName(methodName)
   175  	if !ok {
   176  		panic("retrieve method by name failed")
   177  	}
   178  
   179  	outputs := []OutputCell{{Values: returns, Times: -1}}
   180  	d := getDoubleFunc(m.Type, outputs)
   181  	return this.ApplyCore(m.Func, d)
   182  }
   183  
   184  func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
   185  	t := reflect.ValueOf(target)
   186  	if t.Type().Kind() != reflect.Ptr {
   187  		panic("target is not a pointer")
   188  	}
   189  	if t.Elem().Kind() != reflect.Func {
   190  		panic("target is not a func")
   191  	}
   192  
   193  	funcType := reflect.TypeOf(target).Elem()
   194  	outputs := []OutputCell{{Values: returns, Times: -1}}
   195  	double := getDoubleFunc(funcType, outputs).Interface()
   196  	return this.ApplyGlobalVar(target, double)
   197  }
   198  
   199  func (this *Patches) Reset() {
   200  	for target, bytes := range this.originals {
   201  		modifyBinary(target, bytes)
   202  		delete(this.originals, target)
   203  	}
   204  
   205  	for target, variable := range this.values {
   206  		target.Elem().Set(variable)
   207  	}
   208  }
   209  
   210  func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
   211  	this.check(target, double)
   212  	assTarget := *(*uintptr)(getPointer(target))
   213  	original := replace(assTarget, uintptr(getPointer(double)))
   214  	if _, ok := this.originals[assTarget]; !ok {
   215  		this.originals[assTarget] = original
   216  	}
   217  	this.valueHolders[double] = double
   218  	return this
   219  }
   220  
   221  func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches {
   222  	if double.Kind() != reflect.Func {
   223  		panic("double is not a func")
   224  	}
   225  	assTarget := *(*uintptr)(target)
   226  	original := replace(assTarget, uintptr(getPointer(double)))
   227  	if _, ok := this.originals[assTarget]; !ok {
   228  		this.originals[assTarget] = original
   229  	}
   230  	this.valueHolders[double] = double
   231  	return this
   232  }
   233  
   234  func (this *Patches) check(target, double reflect.Value) {
   235  	if target.Kind() != reflect.Func {
   236  		panic("target is not a func")
   237  	}
   238  
   239  	if double.Kind() != reflect.Func {
   240  		panic("double is not a func")
   241  	}
   242  
   243  	targetType := target.Type()
   244  	doubleType := double.Type()
   245  
   246  	if targetType.NumIn() < doubleType.NumIn() ||
   247  		targetType.NumOut() != doubleType.NumOut() ||
   248  		(targetType.NumIn() == doubleType.NumIn() && targetType.IsVariadic() != doubleType.IsVariadic()) {
   249  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   250  	}
   251  
   252  	for i, size := 0, doubleType.NumIn(); i < size; i++ {
   253  		targetIn := targetType.In(i)
   254  		doubleIn := doubleType.In(i)
   255  
   256  		if targetIn.AssignableTo(doubleIn) {
   257  			continue
   258  		}
   259  
   260  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   261  	}
   262  
   263  	for i, size := 0, doubleType.NumOut(); i < size; i++ {
   264  		targetOut := targetType.Out(i)
   265  		doubleOut := doubleType.Out(i)
   266  
   267  		if targetOut.AssignableTo(doubleOut) {
   268  			continue
   269  		}
   270  
   271  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   272  	}
   273  }
   274  
   275  func replace(target, double uintptr) []byte {
   276  	code := buildJmpDirective(double)
   277  	bytes := entryAddress(target, len(code))
   278  	original := make([]byte, len(bytes))
   279  	copy(original, bytes)
   280  	modifyBinary(target, code)
   281  	return original
   282  }
   283  
   284  func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
   285  	if funcType.NumOut() != len(outputs[0].Values) {
   286  		panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
   287  			funcType.NumOut(), len(outputs[0].Values)))
   288  	}
   289  
   290  	needReturn := false
   291  	slice := make([]Params, 0)
   292  	for _, output := range outputs {
   293  		if output.Times == -1 {
   294  			needReturn = true
   295  			slice = []Params{output.Values}
   296  			break
   297  		}
   298  		t := 0
   299  		if output.Times <= 1 {
   300  			t = 1
   301  		} else {
   302  			t = output.Times
   303  		}
   304  		for j := 0; j < t; j++ {
   305  			slice = append(slice, output.Values)
   306  		}
   307  	}
   308  
   309  	i := 0
   310  	lenOutputs := len(slice)
   311  	return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
   312  		if needReturn {
   313  			return GetResultValues(funcType, slice[0]...)
   314  		}
   315  		if i < lenOutputs {
   316  			i++
   317  			return GetResultValues(funcType, slice[i-1]...)
   318  		}
   319  		panic("double seq is less than call seq")
   320  	})
   321  }
   322  
   323  func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value {
   324  	var resultValues []reflect.Value
   325  	for i, r := range results {
   326  		var resultValue reflect.Value
   327  		if r == nil {
   328  			resultValue = reflect.Zero(funcType.Out(i))
   329  		} else {
   330  			v := reflect.New(funcType.Out(i))
   331  			v.Elem().Set(reflect.ValueOf(r))
   332  			resultValue = v.Elem()
   333  		}
   334  		resultValues = append(resultValues, resultValue)
   335  	}
   336  	return resultValues
   337  }
   338  
// funcValue overlays the in-memory layout of a reflect.Value so that
// getPointer can read its second word (p) directly.
//
// NOTE(review): this assumes reflect.Value begins with one pointer-sized word
// followed by its data pointer. That is a reflect implementation detail, not
// a documented API; re-check it when upgrading Go.
type funcValue struct {
	_ uintptr
	p unsafe.Pointer
}
   343  
   344  func getPointer(v reflect.Value) unsafe.Pointer {
   345  	return (*funcValue)(unsafe.Pointer(&v)).p
   346  }
   347  
   348  func entryAddress(p uintptr, l int) []byte {
   349  	return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
   350  }
   351  
   352  func pageStart(ptr uintptr) uintptr {
   353  	return ptr & ^(uintptr(syscall.Getpagesize() - 1))
   354  }
   355  
   356  func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
   357  	rf := reflect.TypeOf(doubleFunc)
   358  	if rf.Kind() != reflect.Func {
   359  		panic("doubleFunc is not a func")
   360  	}
   361  	vf := reflect.ValueOf(doubleFunc)
   362  	return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
   363  		if funcType.IsVariadic() {
   364  			return vf.CallSlice(in[1:])
   365  		} else {
   366  			return vf.Call(in[1:])
   367  		}
   368  	})
   369  }
   370  
   371  func castRType(val interface{}) reflect.Type {
   372  	if rTypeVal, ok := val.(reflect.Type); ok {
   373  		return rTypeVal
   374  	}
   375  	return reflect.TypeOf(val)
   376  }