github.com/tencent/goom@v1.0.1/internal/patch/patch.go (about)

     1  // Package patch 对不同类型的函数、方法、未导出函数、进行hook
     2  package patch
     3  
     4  import (
     5  	"errors"
     6  	"reflect"
     7  	"runtime"
     8  	"sync"
     9  
    10  	"github.com/tencent/goom/internal/bytecode"
    11  	"github.com/tencent/goom/internal/logger"
    12  )
    13  
    14  var (
    15  	// patches 缓存
    16  	patches = make(map[uintptr]*patch)
    17  	// lock patches 缓存的读写锁定
    18  	patchesLock = sync.Mutex{}
    19  )
    20  
    21  // lock 锁定 patches map 和内存指令读写
    22  func lock() {
    23  	patchesLock.Lock()
    24  }
    25  
    26  // unlock 解锁
    27  func unlock() {
    28  	patchesLock.Unlock()
    29  }
    30  
    31  // patch 一个可以 Apply 的 patch
    32  type patch struct {
    33  	origin      interface{} // 原始函数,即要mock的目标函数, 相对于代理函数来说叫原始函数
    34  	replacement interface{} // 代理函数
    35  	trampoline  interface{} // 跳板函数
    36  
    37  	originValue      reflect.Value
    38  	replacementValue reflect.Value
    39  
    40  	// 指针管理
    41  	originPtr      uintptr
    42  	replacementPtr uintptr
    43  	trampolinePtr  uintptr
    44  	fixOriginPtr   uintptr
    45  
    46  	originBytes []byte
    47  	jumpBytes   []byte
    48  
    49  	guard *Guard
    50  }
    51  
    52  // patchValue 对 value 进行应用代理
    53  func (p *patch) patch() error {
    54  	p.originValue = reflect.ValueOf(p.origin)
    55  	p.replacementValue = reflect.ValueOf(p.replacement)
    56  	return p.patchValue()
    57  }
    58  
    59  // patchValue 对 value 进行应用代理
    60  func (p *patch) patchValue() error {
    61  	SignatureEquals(p.originValue.Type(), p.replacementValue.Type())
    62  	return p.unsafePatchValue()
    63  }
    64  
    65  // unsafePatchValue 不做类型检查
    66  func (p *patch) unsafePatchValue() error {
    67  	if p.originValue.Kind() != reflect.Func {
    68  		return errors.New("target has to be a ExportFunc")
    69  	}
    70  	if p.replacementValue.Kind() != reflect.Func {
    71  		return errors.New("replacementValue has to be a ExportFunc")
    72  	}
    73  	originPointer := p.originValue.Pointer()
    74  	p.originPtr = originPointer
    75  
    76  	// fix for generics variants
    77  	funcName := runtime.FuncForPC(originPointer).Name()
    78  	if IsGenericsFunc(funcName) {
    79  		innerPointer, err := bytecode.GetInnerFunc(64, originPointer)
    80  		if err == nil && innerPointer != 0 {
    81  			p.originPtr = innerPointer
    82  		}
    83  	}
    84  	return p.unsafePatchPtr()
    85  }
    86  
    87  // unsafePatchPtr 不做类型检查
    88  func (p *patch) unsafePatchPtr() error {
    89  	replacementPointer := p.replacementValue.Pointer()
    90  	p.replacementPtr = replacementPointer
    91  	if p.trampoline != nil {
    92  		trampolinePtr, err := bytecode.GetTrampolinePtr(p.trampoline)
    93  		if err != nil {
    94  			return err
    95  		}
    96  		p.trampolinePtr = trampolinePtr
    97  	}
    98  	return p.replaceFunc()
    99  }
   100  
   101  // replaceFunc 替换函数
   102  func (p *patch) replaceFunc() error {
   103  	lock()
   104  	defer unlock()
   105  
   106  	if _, ok := patches[p.originPtr]; ok {
   107  		unpatchValue(p.originPtr)
   108  	}
   109  	patches[p.originPtr] = p
   110  
   111  	replacementInAddr := (uintptr)(bytecode.GetPtr(p.replacementValue))
   112  	jumpData, err := genJumpData(p.originPtr, replacementInAddr, p.replacementPtr)
   113  	if err != nil {
   114  		if errors.Unwrap(err) == errAlreadyPatch {
   115  			if pc, ok := patches[p.originPtr]; ok {
   116  				bytecode.PrintInstf("origin bytes", pc.originPtr, pc.originBytes, logger.WarningLevel)
   117  			}
   118  		}
   119  		return err
   120  	}
   121  	p.jumpBytes = jumpData
   122  
   123  	originBytes, err := checkAndReadOriginBytes(p.originPtr, len(jumpData))
   124  	if err != nil {
   125  		return err
   126  	}
   127  	p.originBytes = originBytes
   128  
   129  	// 是否修复指令
   130  	if p.trampolinePtr > 0 {
   131  		fixOriginPtr, err := fixOrigin(p.originPtr, p.trampolinePtr, len(jumpData))
   132  		if err != nil {
   133  			return err
   134  		}
   135  		p.fixOriginPtr = fixOriginPtr
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  // unpatch do unpatch by uint ptr
   142  func (p *patch) unpatch() {
   143  	p.Guard().Unpatch()
   144  }
   145  
   146  // Guard 获取 PatchGuard
   147  func (p *patch) Guard() *Guard {
   148  	if p.guard != nil {
   149  		return p.guard
   150  	}
   151  	p.guard = &Guard{
   152  		origin:       p.originPtr,
   153  		originBytes:  p.originBytes,
   154  		jumpBytes:    p.jumpBytes,
   155  		fixOriginPtr: p.fixOriginPtr,
   156  		applied:      false,
   157  	}
   158  	return p.guard
   159  }