github.com/kaiya/goutils@v1.0.1-0.20230226104005-4ae4a4dc3688/hotpatch/hotpatch.go (about)

     1  package hotpatch
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"runtime"
     7  	"syscall"
     8  	"unsafe"
     9  )
    10  
    11  type Patches struct {
    12  	originals    map[uintptr][]byte
    13  	values       map[reflect.Value]reflect.Value
    14  	valueHolders map[reflect.Value]reflect.Value
    15  }
    16  
    17  func create() *Patches {
    18  	return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
    19  }
    20  
    21  // Create New Patches
    22  func NewPatches() *Patches {
    23  	return create()
    24  }
    25  
    26  func ApplyFunc(target, double interface{}) *Patches {
    27  	return create().ApplyFunc(target, double)
    28  }
    29  
    30  func (p *Patches) ApplyFunc(target, double interface{}) *Patches {
    31  	t := reflect.ValueOf(target)
    32  	d := reflect.ValueOf(double)
    33  	return p.ApplyCore(t, d)
    34  }
    35  
    36  // Apply Core function
    37  func (p *Patches) ApplyCore(target, double reflect.Value) *Patches {
    38  	check(target, double)
    39  	assTarget := *(*uintptr)(getPointer(target))
    40  	if _, ok := p.originals[assTarget]; ok {
    41  		panic("patch has been existed")
    42  	}
    43  
    44  	p.valueHolders[double] = double
    45  	original := replace(assTarget, uintptr(getPointer(double)))
    46  	p.originals[assTarget] = original
    47  	return p
    48  }
    49  
    50  // Reset Patches
    51  func (p *Patches) Reset() {
    52  	for target, bytes := range p.originals {
    53  		modifyBinary(target, bytes)
    54  		delete(p.originals, target)
    55  	}
    56  
    57  	for target, variable := range p.values {
    58  		target.Elem().Set(variable)
    59  	}
    60  }
    61  
    62  func replace(target, double uintptr) []byte {
    63  	code := buildJmpDirective(double)
    64  	bytes := entryAddress(target, len(code))
    65  	original := make([]byte, len(bytes))
    66  	copy(original, bytes)
    67  	modifyBinary(target, code)
    68  	return original
    69  }
    70  
    71  func buildJmpDirective(double uintptr) []byte {
    72  	d0 := byte(double)
    73  	d1 := byte(double >> 8)
    74  	d2 := byte(double >> 16)
    75  	d3 := byte(double >> 24)
    76  	d4 := byte(double >> 32)
    77  	d5 := byte(double >> 40)
    78  	d6 := byte(double >> 48)
    79  	d7 := byte(double >> 56)
    80  
    81  	return []byte{
    82  		0x48, 0xBA, d0, d1, d2, d3, d4, d5, d6, d7, // MOV rdx, double
    83  		0xFF, 0x22, // JMP [rdx]
    84  	}
    85  }
    86  
    87  // 关键函数:重写目标函数
    88  func modifyBinary(target uintptr, bytes []byte) {
    89  	function := entryAddress(target, len(bytes))
    90  
    91  	page := entryAddress(pageStart(target), syscall.Getpagesize())
    92  	var err error
    93  	var port int
    94  	if runtime.GOOS == "darwin" {
    95  		port = syscall.PROT_READ | syscall.PROT_WRITE
    96  	} else {
    97  		port = syscall.PROT_READ | syscall.PROT_WRITE | syscall.PROT_EXEC
    98  	}
    99  	err = syscall.Mprotect(page, port)
   100  	if err != nil {
   101  		panic(err)
   102  	}
   103  	copy(function, bytes)
   104  
   105  	if runtime.GOOS == "darwin" {
   106  		port = syscall.PROT_READ
   107  	} else {
   108  		port = syscall.PROT_READ | syscall.PROT_EXEC
   109  	}
   110  	err = syscall.Mprotect(page, port)
   111  	if err != nil {
   112  		panic(err)
   113  	}
   114  }
   115  
   116  func entryAddress(p uintptr, l int) []byte {
   117  	return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
   118  }
   119  
   120  func pageStart(ptr uintptr) uintptr {
   121  	return ptr & ^(uintptr(syscall.Getpagesize() - 1))
   122  }
   123  
   124  func getPointer(v reflect.Value) unsafe.Pointer {
   125  	return (*funcValue)(unsafe.Pointer(&v)).p
   126  }
   127  
   128  type funcValue struct {
   129  	_ uintptr
   130  	p unsafe.Pointer
   131  }
   132  
   133  func check(target, double reflect.Value) {
   134  	if target.Kind() != reflect.Func {
   135  		panic("target is not a func")
   136  	}
   137  
   138  	if double.Kind() != reflect.Func {
   139  		panic("double is not a func")
   140  	}
   141  
   142  	if target.Type() != double.Type() {
   143  		panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
   144  	}
   145  }