github.com/Rookout/GoSDK@v0.1.48/pkg/services/instrumentation/callback/callback.go (about)

     1  package callback
     2  
     3  import (
     4  	"fmt"
     5  	"runtime"
     6  	"runtime/debug"
     7  	"sync"
     8  	"unsafe"
     9  
    10  	"github.com/Rookout/GoSDK/pkg/augs"
    11  	"github.com/Rookout/GoSDK/pkg/locations_set"
    12  	"github.com/Rookout/GoSDK/pkg/logger"
    13  	"github.com/Rookout/GoSDK/pkg/rookoutErrors"
    14  	"github.com/Rookout/GoSDK/pkg/services/collection"
    15  	"github.com/Rookout/GoSDK/pkg/services/collection/go_id"
    16  	"github.com/Rookout/GoSDK/pkg/services/collection/registers"
    17  	"github.com/Rookout/GoSDK/pkg/services/go_runtime"
    18  	"github.com/Rookout/GoSDK/pkg/services/instrumentation/binary_info"
    19  	"github.com/Rookout/GoSDK/pkg/utils"
    20  )
    21  
    22  
    23  func getContext() uintptr
    24  
    25  var BinaryInfo *binary_info.BinaryInfo
    26  var locationsSet *locations_set.LocationsSet
    27  var triggerChan chan bool
    28  
    29  func SetBinaryInfo(binaryInfoIn *binary_info.BinaryInfo) {
    30  	BinaryInfo = binaryInfoIn
    31  }
    32  
    33  func SetLocationsSet(locationsSetIn *locations_set.LocationsSet) {
    34  	locationsSet = locationsSetIn
    35  }
    36  
    37  func SetTriggerChan(triggerChanIn chan bool) {
    38  	triggerChan = triggerChanIn
    39  }
    40  
    41  type BreakpointInfo struct {
    42  	Stacktrace []collection.Stackframe
    43  	regs       registers.Registers
    44  }
    45  
    46  func collectStacktrace(regs *registers.OnStackRegisters, g go_runtime.GPtr, maxStacktrace int) []collection.Stackframe {
    47  	if maxStacktrace == 0 {
    48  		maxStacktrace = 1 
    49  	}
    50  
    51  	pcs := make([]uintptr, maxStacktrace)
    52  	
    53  	frameCount := go_runtime.Callers(uintptr(regs.PC()), uintptr(regs.SP()), g, pcs)
    54  	pcs = pcs[:frameCount]
    55  	stacktrace := make([]collection.Stackframe, 0, frameCount)
    56  	frames := runtime.CallersFrames(pcs)
    57  
    58  	more := true
    59  	frame := runtime.Frame{}
    60  	for i := 0; i < frameCount; i++ {
    61  		if !more {
    62  			
    63  			logger.Logger().Warningf("Expected more frames but more is false: %d/%d\n", i, frameCount)
    64  			break
    65  		}
    66  
    67  		frame, more = frames.Next()
    68  		
    69  		
    70  		
    71  		
    72  		
    73  		stacktrace = append(stacktrace,
    74  			collection.Stackframe{
    75  				File:     frame.File,
    76  				Line:     frame.Line,
    77  				Function: frame.Func.Name(),
    78  			})
    79  	}
    80  
    81  	return stacktrace
    82  }
    83  
    84  func printBytesAt(sp, count uint64, prefixes []string) uint64 {
    85  	for i := uint64(0); i < count; i++ {
    86  		//goland:noinspection GoVetUnsafePointer
    87  		stackValue := *(*uint64)(unsafe.Pointer(uintptr(sp)))
    88  		fmt.Printf("0x%016x:\t0x%016x (%s) \n", sp, stackValue, prefixes[i])
    89  		sp = sp - 8
    90  	}
    91  	return sp
    92  }
    93  
    94  func printStack(stackRegs registers.OnStackRegisters) {
    95  	stackPtr := stackRegs.SP()
    96  
    97  	fmt.Printf("BP: 0x%016x\n", stackRegs.BP())
    98  	fmt.Printf("SP: 0x%016x\n", stackRegs.SP())
    99  
   100  	fmt.Println("Native:")
   101  	stackPtr = printBytesAt(stackPtr, 4, []string{"idk", "flags", "rdi", "rdx"})
   102  	stackPtr -= 240
   103  	stackPtr = printBytesAt(stackPtr, 11, []string{"rbx", "rax", "rcx", "rsi", "r8", "r9", "r10", "r11", "rbp", "rdi", "retval"})
   104  	fmt.Println("Golang:")
   105  	if runtime.Version() == "go1.16" {
   106  		stackPtr = printBytesAt(stackPtr, 1, []string{"idk", "idk"})
   107  	} else {
   108  		stackPtr = printBytesAt(stackPtr, 2, []string{"idk", "idk"})
   109  	}
   110  	stackPtr = printBytesAt(stackPtr, 12, []string{"rsp", "rbp", "tls", "rip", "r11", "r10", "r9", "r8", "rsi", "rcx", "rax", "rbx"})
   111  	stackPtr = printBytesAt(stackPtr, 2, []string{"rdx", "rdi"})
   112  }
   113  
   114  //go:linkname systemstack runtime.systemstack
   115  func systemstack(func())
   116  
   117  //go:nosplit
   118  func Callback() {
   119  	context := getContext()
   120  	g := go_runtime.Getg()
   121  
   122  	goCollect(context, g)
   123  }
   124  
   125  //go:nosplit
   126  func goCollect(context uintptr, g go_runtime.GPtr) {
   127  	var waitChan chan struct{}
   128  
   129  	triggerChan <- true
   130  
   131  	systemstack(func() {
   132  		waitChan = make(chan struct{})
   133  
   134  		
   135  		go func() {
   136  			defer func() {
   137  				waitChan <- struct{}{}
   138  				triggerChan <- false
   139  				if v := recover(); v != nil {
   140  					if utils.OnPanicFunc != nil {
   141  						utils.OnPanicFunc(rookoutErrors.NewUnknownError(v))
   142  					}
   143  
   144  					return
   145  				}
   146  			}()
   147  			debug.SetPanicOnFault(true)
   148  
   149  			collectBreakpoint(context, g)
   150  		}()
   151  	})
   152  
   153  	<-waitChan
   154  }
   155  
   156  func collectBreakpoint(context uintptr, g go_runtime.GPtr) {
   157  	regs := registers.NewOnStackRegisters(context)
   158  	bpInstance, ok := locationsSet.FindBreakpointByAddr(regs.PC())
   159  	if !ok {
   160  		file, line, function := BinaryInfo.PCToLine(regs.PC())
   161  		var functionName string
   162  		if function != nil {
   163  			functionName = function.Name
   164  		}
   165  
   166  		logger.Logger().Errorf("Breakpoint triggered on unknown address, 0x%x (%s:%d - %s)", regs.PC(), file, line, functionName)
   167  		return
   168  	}
   169  
   170  	goid := go_id.GetGoID(g)
   171  	stacktrace := collectStacktrace(regs, g, bpInstance.Breakpoint.Stacktrace)
   172  	
   173  	stacktrace[0].Line = bpInstance.Breakpoint.Line
   174  
   175  	bpInfo := &BreakpointInfo{
   176  		Stacktrace: stacktrace,
   177  		regs:       regs,
   178  	}
   179  	reportBreakpoint(bpInstance, bpInfo, goid)
   180  }
   181  
   182  func reportBreakpoint(bpInstance *augs.BreakpointInstance, bpInfo *BreakpointInfo, goid int) {
   183  	bp := bpInstance.Breakpoint
   184  	locations, exists := locationsSet.FindLocationsByBreakpointName(bp.Name)
   185  	if !exists {
   186  		logger.Logger().Errorf("Breakpoint %s (on %s:%d - 0x%x) triggered but the breakpoint doesn't exist.", bp.Name, bp.File, bp.Line, bpInfo.regs.PC())
   187  		return
   188  	}
   189  
   190  	wg := sync.WaitGroup{}
   191  	wg.Add(len(locations))
   192  	for i := range locations {
   193  		utils.CreateGoroutine(func(i int) func() {
   194  			return func() { 
   195  				defer func() {
   196  					if r := recover(); r != nil {
   197  						locations[i].SetError(rookoutErrors.NewUnknownError(r))
   198  					}
   199  				}()
   200  
   201  				defer wg.Done()
   202  
   203  				collectionService, err := collection.NewCollectionService(bpInfo.regs, BinaryInfo.PointerSize, bpInfo.Stacktrace, bpInstance.VariableLocators, goid)
   204  				if err != nil {
   205  					logger.Logger().WithError(err).Errorf("failed to report breakpoint info")
   206  				}
   207  
   208  				locations[i].GetAug().Execute(collectionService)
   209  			}
   210  		}(i))
   211  	}
   212  	wg.Wait()
   213  }