github.com/undefinedlabs/go-mpatch@v1.0.8-0.20230904093002-fbac8a0d7853/patcher.go (about) 1 package mpatch // import "github.com/undefinedlabs/go-mpatch" 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "sync" 8 "syscall" 9 "unsafe" 10 ) 11 12 type ( 13 Patch struct { 14 targetBytes []byte 15 target *reflect.Value 16 redirection *reflect.Value 17 } 18 sliceHeader struct { 19 Data unsafe.Pointer 20 Len int 21 Cap int 22 } 23 ) 24 25 //go:linkname getInternalPtrFromValue reflect.(*Value).pointer 26 func getInternalPtrFromValue(v *reflect.Value) unsafe.Pointer 27 28 var ( 29 patchLock = sync.Mutex{} 30 patches = make(map[unsafe.Pointer]*Patch) 31 pageSize = syscall.Getpagesize() 32 ) 33 34 // Patches a target func to redirect calls to "redirection" func. Both function must have same arguments and return types. 35 func PatchMethod(target, redirection interface{}) (*Patch, error) { 36 tValue := getValueFrom(target) 37 rValue := getValueFrom(redirection) 38 if err := isPatchable(&tValue, &rValue); err != nil { 39 return nil, err 40 } 41 patch := &Patch{target: &tValue, redirection: &rValue} 42 if err := applyPatch(patch); err != nil { 43 return nil, err 44 } 45 return patch, nil 46 } 47 48 // Patches an instance func by using two parameters, the target struct type and the method name inside that type, 49 // this func will be redirected to the "redirection" func. Note: The first parameter of the redirection func must be the object instance. 50 func PatchInstanceMethodByName(target reflect.Type, methodName string, redirection interface{}) (*Patch, error) { 51 method, ok := target.MethodByName(methodName) 52 if !ok && target.Kind() == reflect.Struct { 53 target = reflect.PtrTo(target) 54 method, ok = target.MethodByName(methodName) 55 } 56 if !ok { 57 return nil, errors.New(fmt.Sprintf("Method '%v' not found", methodName)) 58 } 59 return PatchMethodByReflect(method, redirection) 60 } 61 62 // Patches a target func by passing the reflect.Method of the func. The target func will be redirected to the "redirection" func. 63 // Both function must have same arguments and return types. 64 func PatchMethodByReflect(target reflect.Method, redirection interface{}) (*Patch, error) { 65 return PatchMethodByReflectValue(target.Func, redirection) 66 } 67 68 // Patches a target func with a "redirection" function created at runtime by using "reflect.MakeFunc". 69 func PatchMethodWithMakeFunc(target reflect.Method, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) { 70 return PatchMethodByReflect(target, reflect.MakeFunc(target.Type, fn)) 71 } 72 73 // Patches a target func by passing the reflect.ValueOf of the func. The target func will be redirected to the "redirection" func. 74 // Both function must have same arguments and return types. 75 func PatchMethodByReflectValue(target reflect.Value, redirection interface{}) (*Patch, error) { 76 tValue := &target 77 rValue := getValueFrom(redirection) 78 if err := isPatchable(tValue, &rValue); err != nil { 79 return nil, err 80 } 81 patch := &Patch{target: tValue, redirection: &rValue} 82 if err := applyPatch(patch); err != nil { 83 return nil, err 84 } 85 return patch, nil 86 } 87 88 // Patches a target func with a "redirection" function created at runtime by using "reflect.MakeFunc". 89 func PatchMethodWithMakeFuncValue(target reflect.Value, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) { 90 return PatchMethodByReflectValue(target, reflect.MakeFunc(target.Type(), fn)) 91 } 92 93 // Patch the target func with the redirection func. 94 func (p *Patch) Patch() error { 95 if p == nil { 96 return errors.New("patch is nil") 97 } 98 if err := isPatchable(p.target, p.redirection); err != nil { 99 return err 100 } 101 if err := applyPatch(p); err != nil { 102 return err 103 } 104 return nil 105 } 106 107 // Unpatch the target func and recover the original func. 108 func (p *Patch) Unpatch() error { 109 if p == nil { 110 return errors.New("patch is nil") 111 } 112 return applyUnpatch(p) 113 } 114 115 func isPatchable(target, redirection *reflect.Value) error { 116 patchLock.Lock() 117 defer patchLock.Unlock() 118 if target.Kind() != reflect.Func || redirection.Kind() != reflect.Func { 119 return errors.New("the target and/or redirection is not a Func") 120 } 121 if target.Type() != redirection.Type() { 122 return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type())) 123 } 124 if _, ok := patches[getCodePointer(target)]; ok { 125 return errors.New("the target is already patched") 126 } 127 return nil 128 } 129 130 func applyPatch(patch *Patch) error { 131 patchLock.Lock() 132 defer patchLock.Unlock() 133 tPointer := getCodePointer(patch.target) 134 rPointer := getInternalPtrFromValue(patch.redirection) 135 rPointerJumpBytes, err := getJumpFuncBytes(rPointer) 136 if err != nil { 137 return err 138 } 139 tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes)) 140 targetBytes := make([]byte, len(tPointerBytes)) 141 copy(targetBytes, tPointerBytes) 142 if err := writeDataToPointer(tPointer, rPointerJumpBytes); err != nil { 143 return err 144 } 145 patch.targetBytes = targetBytes 146 patches[tPointer] = patch 147 return nil 148 } 149 150 func applyUnpatch(patch *Patch) error { 151 patchLock.Lock() 152 defer patchLock.Unlock() 153 if patch.targetBytes == nil || len(patch.targetBytes) == 0 { 154 return errors.New("the target is not patched") 155 } 156 tPointer := getCodePointer(patch.target) 157 if _, ok := patches[tPointer]; !ok { 158 return errors.New("the target is not patched") 159 } 160 delete(patches, tPointer) 161 err := writeDataToPointer(tPointer, patch.targetBytes) 162 if err != nil { 163 return err 164 } 165 return nil 166 } 167 168 func getValueFrom(data interface{}) reflect.Value { 169 if cValue, ok := data.(reflect.Value); ok { 170 return cValue 171 } else { 172 return reflect.ValueOf(data) 173 } 174 } 175 176 // Extracts a memory slice from a pointer 177 func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte { 178 return *(*[]byte)(unsafe.Pointer(&sliceHeader{ 179 Data: p, 180 Len: length, 181 Cap: length, 182 })) 183 } 184 185 // Gets the code pointer of a func 186 func getCodePointer(value *reflect.Value) unsafe.Pointer { 187 p := getInternalPtrFromValue(value) 188 if p != nil { 189 p = *(*unsafe.Pointer)(p) 190 } 191 return p 192 }