github.com/bytedance/mockey@v1.2.10/internal/monkey/patch.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package monkey
    18  
    19  import (
    20  	"reflect"
    21  
    22  	"github.com/bytedance/mockey/internal/monkey/common"
    23  	"github.com/bytedance/mockey/internal/monkey/fn"
    24  	"github.com/bytedance/mockey/internal/monkey/inst"
    25  	"github.com/bytedance/mockey/internal/monkey/mem"
    26  	"github.com/bytedance/mockey/internal/tool"
    27  )
    28  
    29  // Patch is a context that holds the address and original codes of the patched function.
    30  type Patch struct {
    31  	size int
    32  	code []byte
    33  	base uintptr
    34  }
    35  
    36  // Unpatch restores the patched function to the original function.
    37  func (p *Patch) Unpatch() {
    38  	mem.WriteWithSTW(p.base, p.code[:p.size])
    39  	common.ReleasePage(p.code)
    40  }
    41  
    42  // PatchValue replace the target function with a hook function, and stores the target function in the proxy function
    43  // for future restore. Target and hook are values of function. Proxy is a value of proxy function pointer.
    44  func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch {
    45  	tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
    46  	tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
    47  
    48  	targetAddr := target.Pointer()
    49  	if generic {
    50  		// we assume that generic call/bl op is located in first 200 bytes of codes from targetAddr
    51  		targetAddr = inst.GetGenericJumpAddr(targetAddr, 10000)
    52  	}
    53  	// The first few bytes of the target function code
    54  	const bufSize = 64
    55  	targetCodeBuf := common.BytesOf(targetAddr, bufSize)
    56  	// construct the branch instruction, i.e. jump to the hook function
    57  	hookCode := inst.BranchInto(common.PtrAt(hook))
    58  	// search the cutting point of the target code, i.e. the minimum length of full instructions that is longer than the hookCode
    59  	cuttingIdx := inst.Disassemble(targetCodeBuf, len(hookCode), !unsafe)
    60  
    61  	// construct the proxy code
    62  	proxyCode := common.AllocatePage()
    63  	// save the original code before the cutting point
    64  	copy(proxyCode, targetCodeBuf[:cuttingIdx])
    65  	// construct the branch instruction, i.e. jump to the cutting point
    66  	copy(proxyCode[cuttingIdx:], inst.BranchTo(targetAddr+uintptr(cuttingIdx)))
    67  	// inject the proxy code to the proxy function
    68  	fn.InjectInto(proxy, proxyCode)
    69  
    70  	tool.DebugPrintf("PatchValue: hook code len(%v), cuttingIdx(%v)\n", len(hookCode), cuttingIdx)
    71  
    72  	// replace target function codes before the cutting point
    73  	mem.WriteWithSTW(targetAddr, hookCode)
    74  
    75  	return &Patch{base: targetAddr, code: proxyCode, size: cuttingIdx}
    76  }
    77  
    78  func PatchFunc(fn, hook, proxy interface{}, unsafe bool) *Patch {
    79  	vv := reflect.ValueOf(fn)
    80  	tool.Assert(vv.Kind() == reflect.Func, "'%v' is not a function", fn)
    81  	return PatchValue(vv, reflect.ValueOf(hook), reflect.ValueOf(proxy), unsafe, false)
    82  }