github.com/wfusion/gofusion@v1.1.14/common/utils/gomonkey/patch.go (about) 1 // Fork from github.com/agiledragon/gomonkey@v2.10.1 2 3 package gomonkey 4 5 import ( 6 "fmt" 7 "reflect" 8 "syscall" 9 "unsafe" 10 11 "github.com/wfusion/gofusion/common/utils/gomonkey/creflect" 12 ) 13 14 type Patches struct { 15 originals map[uintptr][]byte 16 values map[reflect.Value]reflect.Value 17 valueHolders map[reflect.Value]reflect.Value 18 } 19 20 type Params []interface{} 21 type OutputCell struct { 22 Values Params 23 Times int 24 } 25 26 func ApplyFunc(target, double interface{}) *Patches { 27 return create().ApplyFunc(target, double) 28 } 29 30 func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches { 31 return create().ApplyMethod(target, methodName, double) 32 } 33 34 func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches { 35 return create().ApplyMethodFunc(target, methodName, doubleFunc) 36 } 37 38 func ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches { 39 return create().ApplyPrivateMethod(target, methodName, double) 40 } 41 42 func ApplyGlobalVar(target, double interface{}) *Patches { 43 return create().ApplyGlobalVar(target, double) 44 } 45 46 func ApplyFuncVar(target, double interface{}) *Patches { 47 return create().ApplyFuncVar(target, double) 48 } 49 50 func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches { 51 return create().ApplyFuncSeq(target, outputs) 52 } 53 54 func ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches { 55 return create().ApplyMethodSeq(target, methodName, outputs) 56 } 57 58 func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches { 59 return create().ApplyFuncVarSeq(target, outputs) 60 } 61 62 func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches { 63 return create().ApplyFuncReturn(target, output...) 64 } 65 66 func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches { 67 return create().ApplyMethodReturn(target, methodName, output...) 68 } 69 70 func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches { 71 return create().ApplyFuncVarReturn(target, output...) 72 } 73 74 func create() *Patches { 75 return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)} 76 } 77 78 func NewPatches() *Patches { 79 return create() 80 } 81 82 func (this *Patches) ApplyFunc(target, double interface{}) *Patches { 83 t := reflect.ValueOf(target) 84 d := reflect.ValueOf(double) 85 return this.ApplyCore(t, d) 86 } 87 88 func (this *Patches) ApplyMethod(target interface{}, methodName string, double interface{}) *Patches { 89 m, ok := castRType(target).MethodByName(methodName) 90 if !ok { 91 panic("retrieve method by name failed") 92 } 93 d := reflect.ValueOf(double) 94 return this.ApplyCore(m.Func, d) 95 } 96 97 func (this *Patches) ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches { 98 m, ok := castRType(target).MethodByName(methodName) 99 if !ok { 100 panic("retrieve method by name failed") 101 } 102 d := funcToMethod(m.Type, doubleFunc) 103 return this.ApplyCore(m.Func, d) 104 } 105 106 func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches { 107 m, ok := creflect.MethodByName(castRType(target), methodName) 108 if !ok { 109 panic("retrieve method by name failed") 110 } 111 d := reflect.ValueOf(double) 112 return this.ApplyCoreOnlyForPrivateMethod(m, d) 113 } 114 115 func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches { 116 t := reflect.ValueOf(target) 117 if t.Type().Kind() != reflect.Ptr { 118 panic("target is not a pointer") 119 } 120 121 this.values[t] = reflect.ValueOf(t.Elem().Interface()) 122 d := reflect.ValueOf(double) 123 t.Elem().Set(d) 124 return this 125 } 126 127 func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches { 128 t := reflect.ValueOf(target) 129 d := reflect.ValueOf(double) 130 if t.Type().Kind() != reflect.Ptr { 131 panic("target is not a pointer") 132 } 133 this.check(t.Elem(), d) 134 return this.ApplyGlobalVar(target, double) 135 } 136 137 func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches { 138 funcType := reflect.TypeOf(target) 139 t := reflect.ValueOf(target) 140 d := getDoubleFunc(funcType, outputs) 141 return this.ApplyCore(t, d) 142 } 143 144 func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches { 145 m, ok := castRType(target).MethodByName(methodName) 146 if !ok { 147 panic("retrieve method by name failed") 148 } 149 d := getDoubleFunc(m.Type, outputs) 150 return this.ApplyCore(m.Func, d) 151 } 152 153 func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches { 154 t := reflect.ValueOf(target) 155 if t.Type().Kind() != reflect.Ptr { 156 panic("target is not a pointer") 157 } 158 if t.Elem().Kind() != reflect.Func { 159 panic("target is not a func") 160 } 161 162 funcType := reflect.TypeOf(target).Elem() 163 double := getDoubleFunc(funcType, outputs).Interface() 164 return this.ApplyGlobalVar(target, double) 165 } 166 167 func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches { 168 funcType := reflect.TypeOf(target) 169 t := reflect.ValueOf(target) 170 outputs := []OutputCell{{Values: returns, Times: -1}} 171 d := getDoubleFunc(funcType, outputs) 172 return this.ApplyCore(t, d) 173 } 174 175 func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches { 176 m, ok := reflect.TypeOf(target).MethodByName(methodName) 177 if !ok { 178 panic("retrieve method by name failed") 179 } 180 181 outputs := []OutputCell{{Values: returns, Times: -1}} 182 d := getDoubleFunc(m.Type, outputs) 183 return this.ApplyCore(m.Func, d) 184 } 185 186 func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches { 187 t := reflect.ValueOf(target) 188 if t.Type().Kind() != reflect.Ptr { 189 panic("target is not a pointer") 190 } 191 if t.Elem().Kind() != reflect.Func { 192 panic("target is not a func") 193 } 194 195 funcType := reflect.TypeOf(target).Elem() 196 outputs := []OutputCell{{Values: returns, Times: -1}} 197 double := getDoubleFunc(funcType, outputs).Interface() 198 return this.ApplyGlobalVar(target, double) 199 } 200 201 func (this *Patches) Reset() { 202 for target, bytes := range this.originals { 203 modifyBinary(target, bytes) 204 delete(this.originals, target) 205 } 206 207 for target, variable := range this.values { 208 target.Elem().Set(variable) 209 } 210 } 211 212 func (this *Patches) ApplyCore(target, double reflect.Value) *Patches { 213 this.check(target, double) 214 assTarget := *(*uintptr)(getPointer(target)) 215 original := replace(assTarget, uintptr(getPointer(double))) 216 if _, ok := this.originals[assTarget]; !ok { 217 this.originals[assTarget] = original 218 } 219 this.valueHolders[double] = double 220 return this 221 } 222 223 func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches { 224 if double.Kind() != reflect.Func { 225 panic("double is not a func") 226 } 227 assTarget := *(*uintptr)(target) 228 original := replace(assTarget, uintptr(getPointer(double))) 229 if _, ok := this.originals[assTarget]; !ok { 230 this.originals[assTarget] = original 231 } 232 this.valueHolders[double] = double 233 return this 234 } 235 236 func (this *Patches) check(target, double reflect.Value) { 237 if target.Kind() != reflect.Func { 238 panic("target is not a func") 239 } 240 241 if double.Kind() != reflect.Func { 242 panic("double is not a func") 243 } 244 245 targetType := target.Type() 246 doubleType := double.Type() 247 248 if targetType.NumIn() < doubleType.NumIn() || 249 targetType.NumOut() != doubleType.NumOut() || 250 (targetType.NumIn() == doubleType.NumIn() && targetType.IsVariadic() != doubleType.IsVariadic()) { 251 panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type())) 252 } 253 254 for i, size := 0, doubleType.NumIn(); i < size; i++ { 255 targetIn := targetType.In(i) 256 doubleIn := doubleType.In(i) 257 258 if targetIn.AssignableTo(doubleIn) { 259 continue 260 } 261 262 panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type())) 263 } 264 } 265 266 func replace(target, double uintptr) []byte { 267 code := buildJmpDirective(double) 268 bytes := entryAddress(target, len(code)) 269 original := make([]byte, len(bytes)) 270 copy(original, bytes) 271 modifyBinary(target, code) 272 return original 273 } 274 275 func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value { 276 if funcType.NumOut() != len(outputs[0].Values) { 277 panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double", 278 funcType.NumOut(), len(outputs[0].Values))) 279 } 280 281 needReturn := false 282 slice := make([]Params, 0) 283 for _, output := range outputs { 284 if output.Times == -1 { 285 needReturn = true 286 slice = []Params{output.Values} 287 break 288 } 289 t := 0 290 if output.Times <= 1 { 291 t = 1 292 } else { 293 t = output.Times 294 } 295 for j := 0; j < t; j++ { 296 slice = append(slice, output.Values) 297 } 298 } 299 300 i := 0 301 lenOutputs := len(slice) 302 return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value { 303 if needReturn { 304 return GetResultValues(funcType, slice[0]...) 305 } 306 if i < lenOutputs { 307 i++ 308 return GetResultValues(funcType, slice[i-1]...) 309 } 310 panic("double seq is less than call seq") 311 }) 312 } 313 314 func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value { 315 var resultValues []reflect.Value 316 for i, r := range results { 317 var resultValue reflect.Value 318 if r == nil { 319 resultValue = reflect.Zero(funcType.Out(i)) 320 } else { 321 v := reflect.New(funcType.Out(i)) 322 v.Elem().Set(reflect.ValueOf(r)) 323 resultValue = v.Elem() 324 } 325 resultValues = append(resultValues, resultValue) 326 } 327 return resultValues 328 } 329 330 type funcValue struct { 331 _ uintptr 332 p unsafe.Pointer 333 } 334 335 func getPointer(v reflect.Value) unsafe.Pointer { 336 return (*funcValue)(unsafe.Pointer(&v)).p 337 } 338 339 func entryAddress(p uintptr, l int) []byte { 340 return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l})) 341 } 342 343 func pageStart(ptr uintptr) uintptr { 344 return ptr & ^(uintptr(syscall.Getpagesize() - 1)) 345 } 346 347 func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value { 348 rf := reflect.TypeOf(doubleFunc) 349 if rf.Kind() != reflect.Func { 350 panic("doubleFunc is not a func") 351 } 352 vf := reflect.ValueOf(doubleFunc) 353 return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value { 354 if funcType.IsVariadic() { 355 return vf.CallSlice(in[1:]) 356 } else { 357 return vf.Call(in[1:]) 358 } 359 }) 360 } 361 362 func castRType(val interface{}) reflect.Type { 363 if rTypeVal, ok := val.(reflect.Type); ok { 364 return rTypeVal 365 } 366 return reflect.TypeOf(val) 367 }