github.com/undefinedlabs/go-mpatch@v1.0.8-0.20230904093002-fbac8a0d7853/patcher.go (about)

     1  package mpatch // import "github.com/undefinedlabs/go-mpatch"
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"sync"
     8  	"syscall"
     9  	"unsafe"
    10  )
    11  
    12  type (
    13  	Patch struct {
    14  		targetBytes []byte
    15  		target      *reflect.Value
    16  		redirection *reflect.Value
    17  	}
    18  	sliceHeader struct {
    19  		Data unsafe.Pointer
    20  		Len  int
    21  		Cap  int
    22  	}
    23  )
    24  
// getInternalPtrFromValue has no body in this package: the go:linkname
// directive below binds it to reflect's unexported (*Value).pointer method.
// Callers in this file pass func Values and use the returned pointer either
// directly (redirection) or dereferenced once (see getCodePointer).
//
//go:linkname getInternalPtrFromValue reflect.(*Value).pointer
func getInternalPtrFromValue(v *reflect.Value) unsafe.Pointer
    27  
    28  var (
    29  	patchLock = sync.Mutex{}
    30  	patches   = make(map[unsafe.Pointer]*Patch)
    31  	pageSize  = syscall.Getpagesize()
    32  )
    33  
    34  // Patches a target func to redirect calls to "redirection" func. Both function must have same arguments and return types.
    35  func PatchMethod(target, redirection interface{}) (*Patch, error) {
    36  	tValue := getValueFrom(target)
    37  	rValue := getValueFrom(redirection)
    38  	if err := isPatchable(&tValue, &rValue); err != nil {
    39  		return nil, err
    40  	}
    41  	patch := &Patch{target: &tValue, redirection: &rValue}
    42  	if err := applyPatch(patch); err != nil {
    43  		return nil, err
    44  	}
    45  	return patch, nil
    46  }
    47  
    48  // Patches an instance func by using two parameters, the target struct type and the method name inside that type,
    49  // this func will be redirected to the "redirection" func. Note: The first parameter of the redirection func must be the object instance.
    50  func PatchInstanceMethodByName(target reflect.Type, methodName string, redirection interface{}) (*Patch, error) {
    51  	method, ok := target.MethodByName(methodName)
    52  	if !ok && target.Kind() == reflect.Struct {
    53  		target = reflect.PtrTo(target)
    54  		method, ok = target.MethodByName(methodName)
    55  	}
    56  	if !ok {
    57  		return nil, errors.New(fmt.Sprintf("Method '%v' not found", methodName))
    58  	}
    59  	return PatchMethodByReflect(method, redirection)
    60  }
    61  
    62  // Patches a target func by passing the reflect.Method of the func. The target func will be redirected to the "redirection" func.
    63  // Both function must have same arguments and return types.
    64  func PatchMethodByReflect(target reflect.Method, redirection interface{}) (*Patch, error) {
    65  	return PatchMethodByReflectValue(target.Func, redirection)
    66  }
    67  
    68  // Patches a target func with a "redirection" function created at runtime by using "reflect.MakeFunc".
    69  func PatchMethodWithMakeFunc(target reflect.Method, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
    70  	return PatchMethodByReflect(target, reflect.MakeFunc(target.Type, fn))
    71  }
    72  
    73  // Patches a target func by passing the reflect.ValueOf of the func. The target func will be redirected to the "redirection" func.
    74  // Both function must have same arguments and return types.
    75  func PatchMethodByReflectValue(target reflect.Value, redirection interface{}) (*Patch, error) {
    76  	tValue := &target
    77  	rValue := getValueFrom(redirection)
    78  	if err := isPatchable(tValue, &rValue); err != nil {
    79  		return nil, err
    80  	}
    81  	patch := &Patch{target: tValue, redirection: &rValue}
    82  	if err := applyPatch(patch); err != nil {
    83  		return nil, err
    84  	}
    85  	return patch, nil
    86  }
    87  
    88  // Patches a target func with a "redirection" function created at runtime by using "reflect.MakeFunc".
    89  func PatchMethodWithMakeFuncValue(target reflect.Value, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
    90  	return PatchMethodByReflectValue(target, reflect.MakeFunc(target.Type(), fn))
    91  }
    92  
    93  // Patch the target func with the redirection func.
    94  func (p *Patch) Patch() error {
    95  	if p == nil {
    96  		return errors.New("patch is nil")
    97  	}
    98  	if err := isPatchable(p.target, p.redirection); err != nil {
    99  		return err
   100  	}
   101  	if err := applyPatch(p); err != nil {
   102  		return err
   103  	}
   104  	return nil
   105  }
   106  
   107  // Unpatch the target func and recover the original func.
   108  func (p *Patch) Unpatch() error {
   109  	if p == nil {
   110  		return errors.New("patch is nil")
   111  	}
   112  	return applyUnpatch(p)
   113  }
   114  
   115  func isPatchable(target, redirection *reflect.Value) error {
   116  	patchLock.Lock()
   117  	defer patchLock.Unlock()
   118  	if target.Kind() != reflect.Func || redirection.Kind() != reflect.Func {
   119  		return errors.New("the target and/or redirection is not a Func")
   120  	}
   121  	if target.Type() != redirection.Type() {
   122  		return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
   123  	}
   124  	if _, ok := patches[getCodePointer(target)]; ok {
   125  		return errors.New("the target is already patched")
   126  	}
   127  	return nil
   128  }
   129  
   130  func applyPatch(patch *Patch) error {
   131  	patchLock.Lock()
   132  	defer patchLock.Unlock()
   133  	tPointer := getCodePointer(patch.target)
   134  	rPointer := getInternalPtrFromValue(patch.redirection)
   135  	rPointerJumpBytes, err := getJumpFuncBytes(rPointer)
   136  	if err != nil {
   137  		return err
   138  	}
   139  	tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes))
   140  	targetBytes := make([]byte, len(tPointerBytes))
   141  	copy(targetBytes, tPointerBytes)
   142  	if err := writeDataToPointer(tPointer, rPointerJumpBytes); err != nil {
   143  		return err
   144  	}
   145  	patch.targetBytes = targetBytes
   146  	patches[tPointer] = patch
   147  	return nil
   148  }
   149  
   150  func applyUnpatch(patch *Patch) error {
   151  	patchLock.Lock()
   152  	defer patchLock.Unlock()
   153  	if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
   154  		return errors.New("the target is not patched")
   155  	}
   156  	tPointer := getCodePointer(patch.target)
   157  	if _, ok := patches[tPointer]; !ok {
   158  		return errors.New("the target is not patched")
   159  	}
   160  	delete(patches, tPointer)
   161  	err := writeDataToPointer(tPointer, patch.targetBytes)
   162  	if err != nil {
   163  		return err
   164  	}
   165  	return nil
   166  }
   167  
   168  func getValueFrom(data interface{}) reflect.Value {
   169  	if cValue, ok := data.(reflect.Value); ok {
   170  		return cValue
   171  	} else {
   172  		return reflect.ValueOf(data)
   173  	}
   174  }
   175  
   176  // Extracts a memory slice from a pointer
   177  func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
   178  	return *(*[]byte)(unsafe.Pointer(&sliceHeader{
   179  		Data: p,
   180  		Len:  length,
   181  		Cap:  length,
   182  	}))
   183  }
   184  
   185  // Gets the code pointer of a func
   186  func getCodePointer(value *reflect.Value) unsafe.Pointer {
   187  	p := getInternalPtrFromValue(value)
   188  	if p != nil {
   189  		p = *(*unsafe.Pointer)(p)
   190  	}
   191  	return p
   192  }