github.com/tencent/goom@v1.0.1/internal/patch/patch.go (about) 1 // Package patch 对不同类型的函数、方法、未导出函数、进行hook 2 package patch 3 4 import ( 5 "errors" 6 "reflect" 7 "runtime" 8 "sync" 9 10 "github.com/tencent/goom/internal/bytecode" 11 "github.com/tencent/goom/internal/logger" 12 ) 13 14 var ( 15 // patches 缓存 16 patches = make(map[uintptr]*patch) 17 // lock patches 缓存的读写锁定 18 patchesLock = sync.Mutex{} 19 ) 20 21 // lock 锁定 patches map 和内存指令读写 22 func lock() { 23 patchesLock.Lock() 24 } 25 26 // unlock 解锁 27 func unlock() { 28 patchesLock.Unlock() 29 } 30 31 // patch 一个可以 Apply 的 patch 32 type patch struct { 33 origin interface{} // 原始函数,即要mock的目标函数, 相对于代理函数来说叫原始函数 34 replacement interface{} // 代理函数 35 trampoline interface{} // 跳板函数 36 37 originValue reflect.Value 38 replacementValue reflect.Value 39 40 // 指针管理 41 originPtr uintptr 42 replacementPtr uintptr 43 trampolinePtr uintptr 44 fixOriginPtr uintptr 45 46 originBytes []byte 47 jumpBytes []byte 48 49 guard *Guard 50 } 51 52 // patchValue 对 value 进行应用代理 53 func (p *patch) patch() error { 54 p.originValue = reflect.ValueOf(p.origin) 55 p.replacementValue = reflect.ValueOf(p.replacement) 56 return p.patchValue() 57 } 58 59 // patchValue 对 value 进行应用代理 60 func (p *patch) patchValue() error { 61 SignatureEquals(p.originValue.Type(), p.replacementValue.Type()) 62 return p.unsafePatchValue() 63 } 64 65 // unsafePatchValue 不做类型检查 66 func (p *patch) unsafePatchValue() error { 67 if p.originValue.Kind() != reflect.Func { 68 return errors.New("target has to be a ExportFunc") 69 } 70 if p.replacementValue.Kind() != reflect.Func { 71 return errors.New("replacementValue has to be a ExportFunc") 72 } 73 originPointer := p.originValue.Pointer() 74 p.originPtr = originPointer 75 76 // fix for generics variants 77 funcName := runtime.FuncForPC(originPointer).Name() 78 if IsGenericsFunc(funcName) { 79 innerPointer, err := bytecode.GetInnerFunc(64, originPointer) 80 if err == nil && innerPointer != 0 { 81 p.originPtr = innerPointer 82 } 83 } 84 return p.unsafePatchPtr() 85 } 86 87 // unsafePatchPtr 不做类型检查 88 func (p *patch) unsafePatchPtr() error { 89 replacementPointer := p.replacementValue.Pointer() 90 p.replacementPtr = replacementPointer 91 if p.trampoline != nil { 92 trampolinePtr, err := bytecode.GetTrampolinePtr(p.trampoline) 93 if err != nil { 94 return err 95 } 96 p.trampolinePtr = trampolinePtr 97 } 98 return p.replaceFunc() 99 } 100 101 // replaceFunc 替换函数 102 func (p *patch) replaceFunc() error { 103 lock() 104 defer unlock() 105 106 if _, ok := patches[p.originPtr]; ok { 107 unpatchValue(p.originPtr) 108 } 109 patches[p.originPtr] = p 110 111 replacementInAddr := (uintptr)(bytecode.GetPtr(p.replacementValue)) 112 jumpData, err := genJumpData(p.originPtr, replacementInAddr, p.replacementPtr) 113 if err != nil { 114 if errors.Unwrap(err) == errAlreadyPatch { 115 if pc, ok := patches[p.originPtr]; ok { 116 bytecode.PrintInstf("origin bytes", pc.originPtr, pc.originBytes, logger.WarningLevel) 117 } 118 } 119 return err 120 } 121 p.jumpBytes = jumpData 122 123 originBytes, err := checkAndReadOriginBytes(p.originPtr, len(jumpData)) 124 if err != nil { 125 return err 126 } 127 p.originBytes = originBytes 128 129 // 是否修复指令 130 if p.trampolinePtr > 0 { 131 fixOriginPtr, err := fixOrigin(p.originPtr, p.trampolinePtr, len(jumpData)) 132 if err != nil { 133 return err 134 } 135 p.fixOriginPtr = fixOriginPtr 136 } 137 138 return nil 139 } 140 141 // unpatch do unpatch by uint ptr 142 func (p *patch) unpatch() { 143 p.Guard().Unpatch() 144 } 145 146 // Guard 获取 PatchGuard 147 func (p *patch) Guard() *Guard { 148 if p.guard != nil { 149 return p.guard 150 } 151 p.guard = &Guard{ 152 origin: p.originPtr, 153 originBytes: p.originBytes, 154 jumpBytes: p.jumpBytes, 155 fixOriginPtr: p.fixOriginPtr, 156 applied: false, 157 } 158 return p.guard 159 }