github.com/wfusion/gofusion@v1.1.14/common/utils/gomonkey/patch.go (about)

     1  // Fork from github.com/agiledragon/gomonkey@v2.10.1
     2  
     3  package gomonkey
     4  
     5  import (
     6  	"fmt"
     7  	"reflect"
     8  	"syscall"
     9  	"unsafe"
    10  
    11  	"github.com/wfusion/gofusion/common/utils/gomonkey/creflect"
    12  )
    13  
    type (
	// Patches collects the patches applied through it so that Reset can
	// undo them.
	Patches struct {
		// originals maps a patched function entry address to the bytes that
		// the first patch of that address overwrote.
		originals map[uintptr][]byte
		// values maps a pointer to a patched variable to the value recorded
		// for that variable before it was patched.
		values map[reflect.Value]reflect.Value
		// valueHolders retains every double handed to the core patchers;
		// presumably this keeps the doubles reachable while their code is
		// jumped to — confirm.
		valueHolders map[reflect.Value]reflect.Value
	}

	// Params holds the values a double returns for one call.
	Params []interface{}

	// OutputCell is one step of a sequence double: Values is returned Times
	// times (1 or less counts as once; -1 means on every call).
	OutputCell struct {
		Values Params
		Times  int
	}
)
    25  
    26  func ApplyFunc(target, double interface{}) *Patches {
    27  	return create().ApplyFunc(target, double)
    28  }
    29  
    30  func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
    31  	return create().ApplyMethod(target, methodName, double)
    32  }
    33  
    34  func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
    35  	return create().ApplyMethodFunc(target, methodName, doubleFunc)
    36  }
    37  
    38  func ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
    39  	return create().ApplyPrivateMethod(target, methodName, double)
    40  }
    41  
    42  func ApplyGlobalVar(target, double interface{}) *Patches {
    43  	return create().ApplyGlobalVar(target, double)
    44  }
    45  
    46  func ApplyFuncVar(target, double interface{}) *Patches {
    47  	return create().ApplyFuncVar(target, double)
    48  }
    49  
    50  func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
    51  	return create().ApplyFuncSeq(target, outputs)
    52  }
    53  
    54  func ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
    55  	return create().ApplyMethodSeq(target, methodName, outputs)
    56  }
    57  
    58  func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
    59  	return create().ApplyFuncVarSeq(target, outputs)
    60  }
    61  
    62  func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
    63  	return create().ApplyFuncReturn(target, output...)
    64  }
    65  
    66  func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
    67  	return create().ApplyMethodReturn(target, methodName, output...)
    68  }
    69  
    70  func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
    71  	return create().ApplyFuncVarReturn(target, output...)
    72  }
    73  
    74  func create() *Patches {
    75  	return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
    76  }
    77  
    78  func NewPatches() *Patches {
    79  	return create()
    80  }
    81  
    82  func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
    83  	t := reflect.ValueOf(target)
    84  	d := reflect.ValueOf(double)
    85  	return this.ApplyCore(t, d)
    86  }
    87  
    88  func (this *Patches) ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
    89  	m, ok := castRType(target).MethodByName(methodName)
    90  	if !ok {
    91  		panic("retrieve method by name failed")
    92  	}
    93  	d := reflect.ValueOf(double)
    94  	return this.ApplyCore(m.Func, d)
    95  }
    96  
    97  func (this *Patches) ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
    98  	m, ok := castRType(target).MethodByName(methodName)
    99  	if !ok {
   100  		panic("retrieve method by name failed")
   101  	}
   102  	d := funcToMethod(m.Type, doubleFunc)
   103  	return this.ApplyCore(m.Func, d)
   104  }
   105  
   106  func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
   107  	m, ok := creflect.MethodByName(castRType(target), methodName)
   108  	if !ok {
   109  		panic("retrieve method by name failed")
   110  	}
   111  	d := reflect.ValueOf(double)
   112  	return this.ApplyCoreOnlyForPrivateMethod(m, d)
   113  }
   114  
   115  func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches {
   116  	t := reflect.ValueOf(target)
   117  	if t.Type().Kind() != reflect.Ptr {
   118  		panic("target is not a pointer")
   119  	}
   120  
   121  	this.values[t] = reflect.ValueOf(t.Elem().Interface())
   122  	d := reflect.ValueOf(double)
   123  	t.Elem().Set(d)
   124  	return this
   125  }
   126  
   127  func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches {
   128  	t := reflect.ValueOf(target)
   129  	d := reflect.ValueOf(double)
   130  	if t.Type().Kind() != reflect.Ptr {
   131  		panic("target is not a pointer")
   132  	}
   133  	this.check(t.Elem(), d)
   134  	return this.ApplyGlobalVar(target, double)
   135  }
   136  
   137  func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
   138  	funcType := reflect.TypeOf(target)
   139  	t := reflect.ValueOf(target)
   140  	d := getDoubleFunc(funcType, outputs)
   141  	return this.ApplyCore(t, d)
   142  }
   143  
   144  func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
   145  	m, ok := castRType(target).MethodByName(methodName)
   146  	if !ok {
   147  		panic("retrieve method by name failed")
   148  	}
   149  	d := getDoubleFunc(m.Type, outputs)
   150  	return this.ApplyCore(m.Func, d)
   151  }
   152  
   153  func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
   154  	t := reflect.ValueOf(target)
   155  	if t.Type().Kind() != reflect.Ptr {
   156  		panic("target is not a pointer")
   157  	}
   158  	if t.Elem().Kind() != reflect.Func {
   159  		panic("target is not a func")
   160  	}
   161  
   162  	funcType := reflect.TypeOf(target).Elem()
   163  	double := getDoubleFunc(funcType, outputs).Interface()
   164  	return this.ApplyGlobalVar(target, double)
   165  }
   166  
   167  func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
   168  	funcType := reflect.TypeOf(target)
   169  	t := reflect.ValueOf(target)
   170  	outputs := []OutputCell{{Values: returns, Times: -1}}
   171  	d := getDoubleFunc(funcType, outputs)
   172  	return this.ApplyCore(t, d)
   173  }
   174  
   175  func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
   176  	m, ok := reflect.TypeOf(target).MethodByName(methodName)
   177  	if !ok {
   178  		panic("retrieve method by name failed")
   179  	}
   180  
   181  	outputs := []OutputCell{{Values: returns, Times: -1}}
   182  	d := getDoubleFunc(m.Type, outputs)
   183  	return this.ApplyCore(m.Func, d)
   184  }
   185  
   186  func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
   187  	t := reflect.ValueOf(target)
   188  	if t.Type().Kind() != reflect.Ptr {
   189  		panic("target is not a pointer")
   190  	}
   191  	if t.Elem().Kind() != reflect.Func {
   192  		panic("target is not a func")
   193  	}
   194  
   195  	funcType := reflect.TypeOf(target).Elem()
   196  	outputs := []OutputCell{{Values: returns, Times: -1}}
   197  	double := getDoubleFunc(funcType, outputs).Interface()
   198  	return this.ApplyGlobalVar(target, double)
   199  }
   200  
   201  func (this *Patches) Reset() {
   202  	for target, bytes := range this.originals {
   203  		modifyBinary(target, bytes)
   204  		delete(this.originals, target)
   205  	}
   206  
   207  	for target, variable := range this.values {
   208  		target.Elem().Set(variable)
   209  	}
   210  }
   211  
   212  func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
   213  	this.check(target, double)
   214  	assTarget := *(*uintptr)(getPointer(target))
   215  	original := replace(assTarget, uintptr(getPointer(double)))
   216  	if _, ok := this.originals[assTarget]; !ok {
   217  		this.originals[assTarget] = original
   218  	}
   219  	this.valueHolders[double] = double
   220  	return this
   221  }
   222  
   223  func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches {
   224  	if double.Kind() != reflect.Func {
   225  		panic("double is not a func")
   226  	}
   227  	assTarget := *(*uintptr)(target)
   228  	original := replace(assTarget, uintptr(getPointer(double)))
   229  	if _, ok := this.originals[assTarget]; !ok {
   230  		this.originals[assTarget] = original
   231  	}
   232  	this.valueHolders[double] = double
   233  	return this
   234  }
   235  
   236  func (this *Patches) check(target, double reflect.Value) {
   237  	if target.Kind() != reflect.Func {
   238  		panic("target is not a func")
   239  	}
   240  
   241  	if double.Kind() != reflect.Func {
   242  		panic("double is not a func")
   243  	}
   244  
   245  	targetType := target.Type()
   246  	doubleType := double.Type()
   247  
   248  	if targetType.NumIn() < doubleType.NumIn() ||
   249  		targetType.NumOut() != doubleType.NumOut() ||
   250  		(targetType.NumIn() == doubleType.NumIn() && targetType.IsVariadic() != doubleType.IsVariadic()) {
   251  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   252  	}
   253  
   254  	for i, size := 0, doubleType.NumIn(); i < size; i++ {
   255  		targetIn := targetType.In(i)
   256  		doubleIn := doubleType.In(i)
   257  
   258  		if targetIn.AssignableTo(doubleIn) {
   259  			continue
   260  		}
   261  
   262  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   263  	}
   264  }
   265  
   266  func replace(target, double uintptr) []byte {
   267  	code := buildJmpDirective(double)
   268  	bytes := entryAddress(target, len(code))
   269  	original := make([]byte, len(bytes))
   270  	copy(original, bytes)
   271  	modifyBinary(target, code)
   272  	return original
   273  }
   274  
   275  func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
   276  	if funcType.NumOut() != len(outputs[0].Values) {
   277  		panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
   278  			funcType.NumOut(), len(outputs[0].Values)))
   279  	}
   280  
   281  	needReturn := false
   282  	slice := make([]Params, 0)
   283  	for _, output := range outputs {
   284  		if output.Times == -1 {
   285  			needReturn = true
   286  			slice = []Params{output.Values}
   287  			break
   288  		}
   289  		t := 0
   290  		if output.Times <= 1 {
   291  			t = 1
   292  		} else {
   293  			t = output.Times
   294  		}
   295  		for j := 0; j < t; j++ {
   296  			slice = append(slice, output.Values)
   297  		}
   298  	}
   299  
   300  	i := 0
   301  	lenOutputs := len(slice)
   302  	return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
   303  		if needReturn {
   304  			return GetResultValues(funcType, slice[0]...)
   305  		}
   306  		if i < lenOutputs {
   307  			i++
   308  			return GetResultValues(funcType, slice[i-1]...)
   309  		}
   310  		panic("double seq is less than call seq")
   311  	})
   312  }
   313  
   314  func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value {
   315  	var resultValues []reflect.Value
   316  	for i, r := range results {
   317  		var resultValue reflect.Value
   318  		if r == nil {
   319  			resultValue = reflect.Zero(funcType.Out(i))
   320  		} else {
   321  			v := reflect.New(funcType.Out(i))
   322  			v.Elem().Set(reflect.ValueOf(r))
   323  			resultValue = v.Elem()
   324  		}
   325  		resultValues = append(resultValues, resultValue)
   326  	}
   327  	return resultValues
   328  }
   329  
   // funcValue mirrors the first two words of a reflect.Value so that its
// data pointer can be read directly. NOTE(review): this assumes the
// runtime layout of reflect.Value starts with {type pointer, data pointer}
// — re-verify for each Go release this is built with.
type funcValue struct {
	_ uintptr
	p unsafe.Pointer
}

// getPointer returns the data pointer held inside v. For a func Value this
// is presumably the pointer to the function value itself — TODO confirm.
func getPointer(v reflect.Value) unsafe.Pointer {
	header := (*funcValue)(unsafe.Pointer(&v))
	return header.p
}
   338  
   339  func entryAddress(p uintptr, l int) []byte {
   340  	return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
   341  }
   342  
   // pageStart rounds ptr down to the start of its memory page, using the
// page size from syscall.Getpagesize. The masking assumes the page size is
// a power of two.
func pageStart(ptr uintptr) uintptr {
	mask := uintptr(syscall.Getpagesize()) - 1
	return ptr &^ mask
}
   346  
   // funcToMethod adapts doubleFunc, a plain function that takes a method's
// arguments without its receiver, to funcType. funcType is the full method
// signature whose first input is the receiver (as in reflect.Method.Type).
// The returned function drops the receiver and forwards the remaining
// arguments to doubleFunc. For a variadic method, the trailing slice is
// passed through with CallSlice.
//
// It panics with "doubleFunc is not a func" if doubleFunc is nil or is not
// a function.
func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
	vf := reflect.ValueOf(doubleFunc)
	// reflect.ValueOf(nil) has Kind Invalid, so a nil doubleFunc is rejected
	// here. Calling Kind on reflect.TypeOf(nil) would dereference a nil Type.
	if vf.Kind() != reflect.Func {
		panic("doubleFunc is not a func")
	}

	variadic := funcType.IsVariadic()
	return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
		args := in[1:] // drop the receiver
		if variadic {
			return vf.CallSlice(args)
		}
		return vf.Call(args)
	})
}
   361  
   // castRType returns val unchanged when it already is a reflect.Type, and
// reflect.TypeOf(val) otherwise (nil when val is nil).
func castRType(val interface{}) reflect.Type {
	switch typ := val.(type) {
	case reflect.Type:
		return typ
	default:
		return reflect.TypeOf(val)
	}
}