github.com/gbl08ma/monkey@v1.1.0/monkey.go (about) 1 package monkey // import "github.com/gbl08ma/monkey" 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 "unsafe" 8 ) 9 10 //go:linkname stopTheWorld runtime.stopTheWorld 11 func stopTheWorld(reason string) 12 13 //go:linkname startTheWorld runtime.startTheWorld 14 func startTheWorld() 15 16 // patch is an applied patch 17 // needed to undo a patch 18 type patch struct { 19 originalBytes []byte 20 replacement *reflect.Value 21 } 22 23 var ( 24 lock = sync.Mutex{} 25 26 patches = make(map[uintptr]patch) 27 ) 28 29 type value struct { 30 _ uintptr 31 ptr unsafe.Pointer 32 } 33 34 func getPtr(v reflect.Value) unsafe.Pointer { 35 return (*value)(unsafe.Pointer(&v)).ptr 36 } 37 38 type PatchGuard struct { 39 lock sync.Mutex 40 target reflect.Value 41 replacement reflect.Value 42 } 43 44 // Unpatch removes the patch 45 func (g *PatchGuard) Unpatch() { 46 unpatchValue(g.target) 47 } 48 49 // Restore restores the replacement function 50 func (g *PatchGuard) Restore() { 51 patchValue(g.target, g.replacement) 52 } 53 54 // UnpatchLock removes the patch in a thread-safe way, expecting the patch to be restored later using RestoreLock 55 func (g *PatchGuard) UnpatchLock() { 56 g.lock.Lock() 57 unpatchValue(g.target) 58 } 59 60 // RestoreLock restores the replacement function in a thread-safe way. Must only be called if UnpatchLock was used previously 61 func (g *PatchGuard) RestoreLock() { 62 patchValue(g.target, g.replacement) 63 g.lock.Unlock() 64 } 65 66 // Patch replaces a function with another 67 func Patch(target, replacement interface{}) *PatchGuard { 68 t := reflect.ValueOf(target) 69 r := reflect.ValueOf(replacement) 70 patchValue(t, r) 71 72 return &PatchGuard{sync.Mutex{}, t, r} 73 } 74 75 // PatchInstanceMethod replaces an instance method methodName for the type target with replacement 76 // Replacement should expect the receiver (of type target) as the first argument 77 func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *PatchGuard { 78 m, ok := target.MethodByName(methodName) 79 if !ok { 80 panic(fmt.Sprintf("unknown method %s", methodName)) 81 } 82 r := reflect.ValueOf(replacement) 83 patchValue(m.Func, r) 84 85 return &PatchGuard{sync.Mutex{}, m.Func, r} 86 } 87 88 // PatchCount returns the number of currently patched functions and type methods 89 func PatchCount() int { 90 lock.Lock() 91 defer lock.Unlock() 92 return len(patches) 93 } 94 95 func patchValue(target, replacement reflect.Value) { 96 lock.Lock() 97 defer lock.Unlock() 98 99 if target.Kind() != reflect.Func { 100 panic("target has to be a Func") 101 } 102 103 if replacement.Kind() != reflect.Func { 104 panic("replacement has to be a Func") 105 } 106 107 if target.Type() != replacement.Type() { 108 panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type())) 109 } 110 111 // once the world is stopped, we cannot panic, or we will be forced to crash 112 stopTheWorld("monkey patch") 113 defer startTheWorld() 114 115 if patch, ok := patches[target.Pointer()]; ok { 116 unpatch(target.Pointer(), patch) 117 } 118 119 bytes := replaceFunction(target.Pointer(), (uintptr)(getPtr(replacement))) 120 patches[target.Pointer()] = patch{bytes, &replacement} 121 } 122 123 // Unpatch removes any monkey patches on target 124 // returns whether target was patched in the first place 125 func Unpatch(target interface{}) bool { 126 return unpatchValue(reflect.ValueOf(target)) 127 } 128 129 // UnpatchInstanceMethod removes the patch on methodName of the target 130 // returns whether it was patched in the first place 131 func UnpatchInstanceMethod(target reflect.Type, methodName string) bool { 132 m, ok := target.MethodByName(methodName) 133 if !ok { 134 panic(fmt.Sprintf("unknown method %s", methodName)) 135 } 136 return unpatchValue(m.Func) 137 } 138 139 // UnpatchAll removes all applied monkeypatches 140 func UnpatchAll() { 141 lock.Lock() 142 defer lock.Unlock() 143 stopTheWorld("monkey unpatch all") 144 defer startTheWorld() 145 for target, p := range patches { 146 unpatch(target, p) 147 delete(patches, target) 148 } 149 } 150 151 // Unpatch removes a monkeypatch from the specified function 152 // returns whether the function was patched in the first place 153 func unpatchValue(target reflect.Value) bool { 154 lock.Lock() 155 defer lock.Unlock() 156 stopTheWorld("monkey unpatch") 157 defer startTheWorld() 158 patch, ok := patches[target.Pointer()] 159 if !ok { 160 return false 161 } 162 unpatch(target.Pointer(), patch) 163 delete(patches, target.Pointer()) 164 return true 165 } 166 167 func unpatch(target uintptr, p patch) { 168 copyToLocation(target, p.originalBytes) 169 }