github.com/tencent/goom@v1.0.1/internal/patch/monkey.go (about)

     1  package patch
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/tencent/goom/internal/logger"
     8  )
     9  
    10  // Patch 将函数调用指定代理函数
    11  // origin 原始函数
    12  // replacement 代理函数
    13  func Patch(origin, replacement interface{}) (*Guard, error) {
    14  	return Trampoline(origin, replacement, nil)
    15  }
    16  
    17  // Trampoline 将函数调用指定代理函数
    18  // origin 原始函数
    19  // replacement 代理函数
    20  // trampoline 指定跳板函数(可不指定,传 nil)
    21  func Trampoline(origin, replacement interface{}, trampoline interface{}) (*Guard, error) {
    22  	patch := &patch{
    23  		origin:      origin,
    24  		replacement: replacement,
    25  		trampoline:  trampoline,
    26  	}
    27  
    28  	err := patch.patch()
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	return patch.Guard(), nil
    34  }
    35  
    36  // UnsafePatch 未受类型检查的 patch
    37  // origin 原始函数
    38  // replacement 代理函数
    39  func UnsafePatch(origin, replacement interface{}) (*Guard, error) {
    40  	return UnsafePatchTrampoline(origin, replacement, nil)
    41  }
    42  
    43  // UnsafePatchTrampoline 未受类型检查的 patch
    44  // origin 原始函数
    45  // replacement 代理函数
    46  // trampoline 指定跳板函数(可不指定,传 nil)
    47  func UnsafePatchTrampoline(origin, replacement interface{}, trampoline interface{}) (*Guard, error) {
    48  	patch := &patch{
    49  		origin:           origin,
    50  		replacement:      replacement,
    51  		trampoline:       trampoline,
    52  		originValue:      reflect.ValueOf(origin),
    53  		replacementValue: reflect.ValueOf(replacement),
    54  	}
    55  
    56  	if err := patch.unsafePatchValue(); err != nil {
    57  		return nil, err
    58  	}
    59  	return patch.Guard(), nil
    60  }
    61  
    62  // Ptr 直接将函数跳转的新函数
    63  // 此方式为经过函数签名检查,可能会导致栈帧无法对其导致堆栈调用异常,因此不安全请谨慎使用
    64  // originPtr 原始函数地址
    65  // replacement 代理函数
    66  func Ptr(originPtr uintptr, replacement interface{}) (*Guard, error) {
    67  	return PtrTrampoline(originPtr, replacement, nil)
    68  }
    69  
    70  // PtrTrampoline 直接将函数跳转的新函数(指定跳板函数)
    71  // 此方式为经过函数签名检查,可能会导致栈帧无法对其导致堆栈调用异常,因此不安全请谨慎使用
    72  // originPtr 原始函数地址
    73  // replacement 代理函数
    74  // trampoline 跳板函数地址(可不指定,传 nil)
    75  func PtrTrampoline(originPtr uintptr, replacement, trampoline interface{}) (*Guard, error) {
    76  	patch := &patch{
    77  		replacement: replacement,
    78  		trampoline:  trampoline,
    79  
    80  		replacementValue: reflect.ValueOf(replacement),
    81  
    82  		originPtr: originPtr,
    83  	}
    84  
    85  	err := patch.unsafePatchPtr()
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	return patch.Guard(), nil
    90  }
    91  
    92  // InstanceMethod replaces an instance method methodName for the type target with replacementValue
    93  // Replacement should expect the receiver (of type target) as the first argument
    94  func InstanceMethod(originType reflect.Type, methodName string, replacement interface{}) (*Guard, error) {
    95  	return InstanceMethodTrampoline(originType, methodName, replacement, nil)
    96  }
    97  
    98  // InstanceMethodTrampoline replaces an instance method methodName for the type target with replacementValue
    99  // Replacement should expect the receiver (of type target) as the first argument
   100  func InstanceMethodTrampoline(originType reflect.Type, methodName string, replacement interface{},
   101  	trampoline interface{}) (*Guard, error) {
   102  	m, ok := originType.MethodByName(methodName)
   103  	if !ok {
   104  		return nil, fmt.Errorf("unknown method %s", methodName)
   105  	}
   106  
   107  	patch := &patch{
   108  		replacement: replacement,
   109  		trampoline:  trampoline,
   110  
   111  		originValue:      m.Func,
   112  		replacementValue: reflect.ValueOf(replacement),
   113  	}
   114  
   115  	err := patch.patchValue()
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	return patch.Guard(), nil
   120  }
   121  
   122  // Unpatch removes any monkey patches on target
   123  // returns whether target was patched in the first place
   124  func Unpatch(origin interface{}) bool {
   125  	return unpatchValue(reflect.ValueOf(origin).Pointer())
   126  }
   127  
   128  // UnpatchInstanceMethod removes the patch on methodName of the target
   129  // returns whether it was patched in the first place
   130  func UnpatchInstanceMethod(originType reflect.Type, methodName string) bool {
   131  	m, ok := originType.MethodByName(methodName)
   132  	if !ok {
   133  		logger.Debugf(fmt.Sprintf("unknown method %s", methodName))
   134  		return false
   135  	}
   136  
   137  	return unpatchValue(m.Func.Pointer())
   138  }
   139  
   140  // UnpatchAll removes all applied monkey patches
   141  func UnpatchAll() {
   142  	for target, p := range patches {
   143  		p.unpatch()
   144  		delete(patches, target)
   145  	}
   146  }
   147  
   148  // unpatchValue removes a monkeypatch from the specified function
   149  // returns whether the function was patched in the first place
   150  func unpatchValue(origin uintptr) bool {
   151  	p, ok := patches[origin]
   152  	if !ok {
   153  		return false
   154  	}
   155  
   156  	p.unpatch()
   157  	delete(patches, origin)
   158  	return true
   159  }