github.com/tencent/goom@v1.0.1/internal/proxy/func.go (about)

     1  // Package proxy 封装了给各种类型的代理(或较 patch)中间层
     2  // 负责比如外部传如类型校验、私有函数名转换成 uintptr、trampoline 初始化、并发 proxy 等
     3  package proxy
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  
    10  	"github.com/tencent/goom/internal/bytecode"
    11  	"github.com/tencent/goom/internal/logger"
    12  	"github.com/tencent/goom/internal/patch"
    13  	"github.com/tencent/goom/internal/unexports"
    14  )
    15  
    16  // Func 通过函数生成代理函数
    17  // @param funcDef 原始函数定义
    18  // @param proxyFunc 代理函数实现
    19  // @param originFunc 跳板函数即代理后的原始函数定义(值为 nil 时,使用公共的跳板函数, 不为 nil 时使用指定的跳板函数)
    20  func Func(funcDef interface{}, proxyFunc, trampolineFunc interface{}) (*patch.Guard, error) {
    21  	if e := checkTrampolineFunc(trampolineFunc); e != nil {
    22  		return nil, e
    23  	}
    24  
    25  	logger.Info("start func proxy funcDef=", funcDef)
    26  	// 添加函数 hook
    27  	patchGuard, err := patch.Trampoline(
    28  		reflect.Indirect(reflect.ValueOf(funcDef)).Interface(), proxyFunc, trampolineFunc)
    29  	if err != nil {
    30  		logger.Error("func proxy fail funcDef=", funcDef, ":", err)
    31  		return nil, err
    32  	}
    33  
    34  	// 构造原先方法实例值
    35  	logger.Debug("origin ptr is:", fmt.Sprintf("0x%x", patchGuard.FixOriginFunc()))
    36  	if bytecode.IsValidPtr(trampolineFunc) {
    37  		_, err = unexports.CreateFuncForCodePtr(trampolineFunc, patchGuard.FixOriginFunc())
    38  		if err != nil {
    39  			logger.Error("func proxy fail funcDef=", funcDef, ":", err)
    40  			patchGuard.Unpatch()
    41  			return nil, err
    42  		}
    43  	}
    44  
    45  	logger.Debug("func proxy ok funcDef=", funcDef)
    46  	return patchGuard, nil
    47  }
    48  
    49  // FuncName 通过函数名生成代理函数
    50  // @param genCallableMethod 函数名称
    51  // @param proxyFunc 代理函数实现
    52  // @param trampolineFunc 跳板函数,即代理后的原始函数定义;跳板函数的签名必须和原函数一致,值不能为空
    53  func FuncName(funcName string, proxyFunc interface{}, trampolineFunc interface{}) (*patch.Guard, error) {
    54  	if e := checkTrampolineFunc(trampolineFunc); e != nil {
    55  		return nil, e
    56  	}
    57  	originFuncPtr, err := unexports.FindFuncByName(funcName)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	logger.Info("start funcName proxy genCallableMethod=", funcName)
    63  	// 添加函数 hook
    64  	patchGuard, err := patch.PtrTrampoline(originFuncPtr, proxyFunc, trampolineFunc)
    65  	if err != nil {
    66  		logger.Error("funcName proxy fail genCallableMethod=", funcName, ":", err)
    67  		return nil, err
    68  	}
    69  
    70  	// 构造原先方法实例值
    71  	logger.Debug("origin ptr is:", fmt.Sprintf("0x%x", patchGuard.FixOriginFunc()))
    72  	logger.Info("funcName proxy[trampoline] ok, genCallableMethod=", funcName)
    73  	return patchGuard, nil
    74  }
    75  
    76  // Method 通过方法生成代理方法
    77  // @param target 类型
    78  // @param methodName 方法名
    79  // @param proxyFunc 代理函数实现
    80  // @param trampolineFunc 跳板函数即代理后的原始方法定义(值为 nil 时,使用公共的跳板函数, 不为 nil 时使用指定的跳板函数)
    81  func Method(target reflect.Type, methodName string, proxyFunc,
    82  	trampolineFunc interface{}) (*patch.Guard, error) {
    83  	if e := checkTrampolineFunc(trampolineFunc); e != nil {
    84  		return nil, e
    85  	}
    86  
    87  	logger.Info("start method proxy genCallableMethod=", target, ".", methodName)
    88  	// 添加函数 hook
    89  	patchGuard, err := patch.InstanceMethodTrampoline(target, methodName, proxyFunc, trampolineFunc)
    90  	if err != nil {
    91  		logger.Error("method proxy fail type=", target, "methodName=", methodName, ":", err)
    92  		return nil, err
    93  	}
    94  
    95  	// 构造原先方法实例值
    96  	logger.Debug("origin ptr is:", fmt.Sprintf("0x%x", patchGuard.FixOriginFunc()))
    97  	if bytecode.IsValidPtr(trampolineFunc) {
    98  		_, err = unexports.CreateFuncForCodePtr(trampolineFunc, patchGuard.FixOriginFunc())
    99  		if err != nil {
   100  			logger.Error("method proxy fail method=", target, ".", methodName, ":", err)
   101  			patchGuard.Unpatch()
   102  			return nil, err
   103  		}
   104  	}
   105  
   106  	logger.Debug("method proxy ok genCallableMethod=", target, ".", methodName)
   107  	return patchGuard, nil
   108  }
   109  
   110  // checkTrampolineFunc 检测 TrampolineFunc 类型
   111  func checkTrampolineFunc(trampolineFunc interface{}) error {
   112  	if trampolineFunc != nil {
   113  		if reflect.ValueOf(trampolineFunc).Kind() != reflect.Func &&
   114  			reflect.ValueOf(trampolineFunc).Elem().Kind() != reflect.Func {
   115  			return errors.New("trampoline func must be a exported func")
   116  		}
   117  	}
   118  	return nil
   119  }