github.com/agiledragon/gomonkey/v2@v2.11.1-0.20240427155748-d56c6823ec17/patch.go (about) 1 package gomonkey 2 3 import ( 4 "fmt" 5 "reflect" 6 "syscall" 7 "unsafe" 8 9 "github.com/agiledragon/gomonkey/v2/creflect" 10 ) 11 12 type Patches struct { 13 originals map[uintptr][]byte 14 values map[reflect.Value]reflect.Value 15 valueHolders map[reflect.Value]reflect.Value 16 } 17 18 type Params []interface{} 19 type OutputCell struct { 20 Values Params 21 Times int 22 } 23 24 func ApplyFunc(target, double interface{}) *Patches { 25 return create().ApplyFunc(target, double) 26 } 27 28 func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches { 29 return create().ApplyMethod(target, methodName, double) 30 } 31 32 func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches { 33 return create().ApplyMethodFunc(target, methodName, doubleFunc) 34 } 35 36 func ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches { 37 return create().ApplyPrivateMethod(target, methodName, double) 38 } 39 40 func ApplyGlobalVar(target, double interface{}) *Patches { 41 return create().ApplyGlobalVar(target, double) 42 } 43 44 func ApplyFuncVar(target, double interface{}) *Patches { 45 return create().ApplyFuncVar(target, double) 46 } 47 48 func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches { 49 return create().ApplyFuncSeq(target, outputs) 50 } 51 52 func ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches { 53 return create().ApplyMethodSeq(target, methodName, outputs) 54 } 55 56 func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches { 57 return create().ApplyFuncVarSeq(target, outputs) 58 } 59 60 func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches { 61 return create().ApplyFuncReturn(target, output...) 62 } 63 64 func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches { 65 return create().ApplyMethodReturn(target, methodName, output...) 66 } 67 68 func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches { 69 return create().ApplyFuncVarReturn(target, output...) 70 } 71 72 func create() *Patches { 73 return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)} 74 } 75 76 func NewPatches() *Patches { 77 return create() 78 } 79 80 func (this *Patches) ApplyFunc(target, double interface{}) *Patches { 81 t := reflect.ValueOf(target) 82 d := reflect.ValueOf(double) 83 return this.ApplyCore(t, d) 84 } 85 86 func (this *Patches) ApplyMethod(target interface{}, methodName string, double interface{}) *Patches { 87 m, ok := castRType(target).MethodByName(methodName) 88 if !ok { 89 panic("retrieve method by name failed") 90 } 91 d := reflect.ValueOf(double) 92 return this.ApplyCore(m.Func, d) 93 } 94 95 func (this *Patches) ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches { 96 m, ok := castRType(target).MethodByName(methodName) 97 if !ok { 98 panic("retrieve method by name failed") 99 } 100 d := funcToMethod(m.Type, doubleFunc) 101 return this.ApplyCore(m.Func, d) 102 } 103 104 func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches { 105 m, ok := creflect.MethodByName(castRType(target), methodName) 106 if !ok { 107 panic("retrieve method by name failed") 108 } 109 d := reflect.ValueOf(double) 110 return this.ApplyCoreOnlyForPrivateMethod(m, d) 111 } 112 113 func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches { 114 t := reflect.ValueOf(target) 115 if t.Type().Kind() != reflect.Ptr { 116 panic("target is not a pointer") 117 } 118 119 this.values[t] = reflect.ValueOf(t.Elem().Interface()) 120 d := reflect.ValueOf(double) 121 t.Elem().Set(d) 122 return this 123 } 124 125 func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches { 126 t := reflect.ValueOf(target) 127 d := reflect.ValueOf(double) 128 if t.Type().Kind() != reflect.Ptr { 129 panic("target is not a pointer") 130 } 131 this.check(t.Elem(), d) 132 return this.ApplyGlobalVar(target, double) 133 } 134 135 func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches { 136 funcType := reflect.TypeOf(target) 137 t := reflect.ValueOf(target) 138 d := getDoubleFunc(funcType, outputs) 139 return this.ApplyCore(t, d) 140 } 141 142 func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches { 143 m, ok := castRType(target).MethodByName(methodName) 144 if !ok { 145 panic("retrieve method by name failed") 146 } 147 d := getDoubleFunc(m.Type, outputs) 148 return this.ApplyCore(m.Func, d) 149 } 150 151 func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches { 152 t := reflect.ValueOf(target) 153 if t.Type().Kind() != reflect.Ptr { 154 panic("target is not a pointer") 155 } 156 if t.Elem().Kind() != reflect.Func { 157 panic("target is not a func") 158 } 159 160 funcType := reflect.TypeOf(target).Elem() 161 double := getDoubleFunc(funcType, outputs).Interface() 162 return this.ApplyGlobalVar(target, double) 163 } 164 165 func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches { 166 funcType := reflect.TypeOf(target) 167 t := reflect.ValueOf(target) 168 outputs := []OutputCell{{Values: returns, Times: -1}} 169 d := getDoubleFunc(funcType, outputs) 170 return this.ApplyCore(t, d) 171 } 172 173 func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches { 174 m, ok := reflect.TypeOf(target).MethodByName(methodName) 175 if !ok { 176 panic("retrieve method by name failed") 177 } 178 179 outputs := []OutputCell{{Values: returns, Times: -1}} 180 d := getDoubleFunc(m.Type, outputs) 181 return this.ApplyCore(m.Func, d) 182 } 183 184 func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches { 185 t := reflect.ValueOf(target) 186 if t.Type().Kind() != reflect.Ptr { 187 panic("target is not a pointer") 188 } 189 if t.Elem().Kind() != reflect.Func { 190 panic("target is not a func") 191 } 192 193 funcType := reflect.TypeOf(target).Elem() 194 outputs := []OutputCell{{Values: returns, Times: -1}} 195 double := getDoubleFunc(funcType, outputs).Interface() 196 return this.ApplyGlobalVar(target, double) 197 } 198 199 func (this *Patches) Reset() { 200 for target, bytes := range this.originals { 201 modifyBinary(target, bytes) 202 delete(this.originals, target) 203 } 204 205 for target, variable := range this.values { 206 target.Elem().Set(variable) 207 } 208 } 209 210 func (this *Patches) ApplyCore(target, double reflect.Value) *Patches { 211 this.check(target, double) 212 assTarget := *(*uintptr)(getPointer(target)) 213 original := replace(assTarget, uintptr(getPointer(double))) 214 if _, ok := this.originals[assTarget]; !ok { 215 this.originals[assTarget] = original 216 } 217 this.valueHolders[double] = double 218 return this 219 } 220 221 func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches { 222 if double.Kind() != reflect.Func { 223 panic("double is not a func") 224 } 225 assTarget := *(*uintptr)(target) 226 original := replace(assTarget, uintptr(getPointer(double))) 227 if _, ok := this.originals[assTarget]; !ok { 228 this.originals[assTarget] = original 229 } 230 this.valueHolders[double] = double 231 return this 232 } 233 234 func (this *Patches) check(target, double reflect.Value) { 235 if target.Kind() != reflect.Func { 236 panic("target is not a func") 237 } 238 239 if double.Kind() != reflect.Func { 240 panic("double is not a func") 241 } 242 243 targetType := target.Type() 244 doubleType := double.Type() 245 246 if targetType.NumIn() < doubleType.NumIn() || 247 targetType.NumOut() != doubleType.NumOut() || 248 (targetType.NumIn() == doubleType.NumIn() && targetType.IsVariadic() != doubleType.IsVariadic()) { 249 panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type())) 250 } 251 252 for i, size := 0, doubleType.NumIn(); i < size; i++ { 253 targetIn := targetType.In(i) 254 doubleIn := doubleType.In(i) 255 256 if targetIn.AssignableTo(doubleIn) { 257 continue 258 } 259 260 panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type())) 261 } 262 263 for i, size := 0, doubleType.NumOut(); i < size; i++ { 264 targetOut := targetType.Out(i) 265 doubleOut := doubleType.Out(i) 266 267 if targetOut.AssignableTo(doubleOut) { 268 continue 269 } 270 271 panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type())) 272 } 273 } 274 275 func replace(target, double uintptr) []byte { 276 code := buildJmpDirective(double) 277 bytes := entryAddress(target, len(code)) 278 original := make([]byte, len(bytes)) 279 copy(original, bytes) 280 modifyBinary(target, code) 281 return original 282 } 283 284 func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value { 285 if funcType.NumOut() != len(outputs[0].Values) { 286 panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double", 287 funcType.NumOut(), len(outputs[0].Values))) 288 } 289 290 needReturn := false 291 slice := make([]Params, 0) 292 for _, output := range outputs { 293 if output.Times == -1 { 294 needReturn = true 295 slice = []Params{output.Values} 296 break 297 } 298 t := 0 299 if output.Times <= 1 { 300 t = 1 301 } else { 302 t = output.Times 303 } 304 for j := 0; j < t; j++ { 305 slice = append(slice, output.Values) 306 } 307 } 308 309 i := 0 310 lenOutputs := len(slice) 311 return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value { 312 if needReturn { 313 return GetResultValues(funcType, slice[0]...) 314 } 315 if i < lenOutputs { 316 i++ 317 return GetResultValues(funcType, slice[i-1]...) 318 } 319 panic("double seq is less than call seq") 320 }) 321 } 322 323 func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value { 324 var resultValues []reflect.Value 325 for i, r := range results { 326 var resultValue reflect.Value 327 if r == nil { 328 resultValue = reflect.Zero(funcType.Out(i)) 329 } else { 330 v := reflect.New(funcType.Out(i)) 331 v.Elem().Set(reflect.ValueOf(r)) 332 resultValue = v.Elem() 333 } 334 resultValues = append(resultValues, resultValue) 335 } 336 return resultValues 337 } 338 339 type funcValue struct { 340 _ uintptr 341 p unsafe.Pointer 342 } 343 344 func getPointer(v reflect.Value) unsafe.Pointer { 345 return (*funcValue)(unsafe.Pointer(&v)).p 346 } 347 348 func entryAddress(p uintptr, l int) []byte { 349 return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l})) 350 } 351 352 func pageStart(ptr uintptr) uintptr { 353 return ptr & ^(uintptr(syscall.Getpagesize() - 1)) 354 } 355 356 func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value { 357 rf := reflect.TypeOf(doubleFunc) 358 if rf.Kind() != reflect.Func { 359 panic("doubleFunc is not a func") 360 } 361 vf := reflect.ValueOf(doubleFunc) 362 return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value { 363 if funcType.IsVariadic() { 364 return vf.CallSlice(in[1:]) 365 } else { 366 return vf.Call(in[1:]) 367 } 368 }) 369 } 370 371 func castRType(val interface{}) reflect.Type { 372 if rTypeVal, ok := val.(reflect.Type); ok { 373 return rTypeVal 374 } 375 return reflect.TypeOf(val) 376 }