github.com/stealthrocket/wzprof@v0.2.1-0.20230830205924-5fa86be5e5b3/python.go (about)

     1  package wzprof
     2  
     3  import (
     4  	"debug/dwarf"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"path/filepath"
     8  	"strings"
     9  	"unsafe"
    10  
    11  	"github.com/tetratelabs/wazero"
    12  	"github.com/tetratelabs/wazero/api"
    13  	"github.com/tetratelabs/wazero/experimental"
    14  )
    15  
    16  const (
    17  	runtimeAddrName = "_PyRuntime"
    18  	versionAddrName = "Py_Version"
    19  )
    20  
    21  func supportedPython(wasmbin []byte) bool {
    22  	p, err := newDwarfParserFromBin(wasmbin)
    23  	if err != nil {
    24  		return false
    25  	}
    26  
    27  	versionAddr := pythonAddress(p, versionAddrName)
    28  	if versionAddr == 0 {
    29  		return false
    30  	}
    31  
    32  	data := wasmdataSection(wasmbin)
    33  	if data == nil {
    34  		return false
    35  	}
    36  
    37  	var versionhex uint32
    38  	d := newDataIterator(data)
    39  	for {
    40  		vaddr, seg := d.Next()
    41  		if seg == nil || vaddr > int64(versionAddr) {
    42  			break
    43  		}
    44  
    45  		end := vaddr + int64(len(seg))
    46  		if int64(versionAddr)+4 >= end {
    47  			continue
    48  		}
    49  
    50  		offset := int64(versionAddr) - vaddr
    51  		versionhex = binary.LittleEndian.Uint32(seg[offset:])
    52  		break
    53  	}
    54  
    55  	// see cpython patchlevel.h
    56  	major := (versionhex >> 24) & 0xFF
    57  	minor := (versionhex >> 16) & 0xFF
    58  	return major == 3 && minor == 11
    59  }
    60  
    61  func preparePython(mod wazero.CompiledModule) (*python, error) {
    62  	p, err := newDwarfparser(mod)
    63  	if err != nil {
    64  		return nil, fmt.Errorf("could not build dwarf parser: %w", err)
    65  	}
    66  	runtimeAddr := pythonAddress(p, runtimeAddrName)
    67  	if runtimeAddr == 0 {
    68  		return nil, fmt.Errorf("could not find python runtime address")
    69  	}
    70  	return &python{
    71  		pyrtaddr: ptr32(runtimeAddr),
    72  	}, nil
    73  }
    74  
    75  func pythonAddress(p dwarfparser, name string) uint32 {
    76  	for {
    77  		ent, err := p.r.Next()
    78  		if err != nil || ent == nil {
    79  			break
    80  		}
    81  		if ent.Tag != dwarf.TagVariable {
    82  			continue
    83  		}
    84  		n, _ := ent.Val(dwarf.AttrName).(string)
    85  		if n != name {
    86  			continue
    87  		}
    88  		return getDwarfLocationAddress(ent)
    89  	}
    90  	return 0
    91  }
    92  
// python holds the per-module state needed to walk CPython interpreter
// frames in the guest's linear memory.
type python struct {
	// pyrtaddr is the guest address of the _PyRuntime global, resolved from
	// DWARF by preparePython and used by Stackiter as the walk's root.
	pyrtaddr ptr32
}
    96  
    97  func getDwarfLocationAddress(ent *dwarf.Entry) uint32 {
    98  	f := ent.AttrField(dwarf.AttrLocation)
    99  	if f == nil {
   100  		return 0
   101  	}
   102  	if f.Class != dwarf.ClassExprLoc {
   103  		panic(fmt.Errorf("invalid location class: %s", f.Class))
   104  	}
   105  	const DW_OP_addr = 0x3
   106  	loc := f.Val.([]byte)
   107  	if len(loc) == 0 || loc[0] != DW_OP_addr {
   108  		panic(fmt.Errorf("unexpected address format: %X", loc))
   109  	}
   110  	return binary.LittleEndian.Uint32(loc[1:])
   111  }
   112  
   113  // Padding of fields in various CPython structs. They are calculated
   114  // by writing a function in any CPython module, and executing it with
   115  // wazero.
   116  //
   117  // TODO: look into using CGO and #import<Python.h> to generate them
   118  // instead.
   119  const (
   120  	// _PyRuntimeState.
   121  	padTstateCurrentInRT = 360
   122  	// PyThreadState.
   123  	padCframeInThreadState = 40
   124  	// _PyCFrame.
   125  	padCurrentFrameInCFrame = 4
   126  	// _PyInterpreterFrame.
   127  	padPreviousInFrame  = 24
   128  	padCodeInFrame      = 16
   129  	padPrevInstrInFrame = 28
   130  	padOwnerInFrame     = 37
   131  	// PyCodeObject.
   132  	padFilenameInCodeObject       = 80
   133  	padNameInCodeObject           = 84
   134  	padCodeAdaptiveInCodeObject   = 116
   135  	padFirstlinenoInCodeObject    = 48
   136  	padLinearrayInCodeObject      = 104
   137  	padLinetableInCodeObject      = 92
   138  	padFirstTraceableInCodeObject = 108
   139  	sizeCodeUnit                  = 2
   140  	// PyASCIIObject.
   141  	padStateInAsciiObject  = 16
   142  	padLengthInAsciiObject = 8
   143  	sizeAsciiObject        = 24
   144  	// PyBytesObject.
   145  	padSvalInBytesObject = 16
   146  	padSizeInBytesObject = 8
   147  	// Enum constants.
   148  	enumCodeLocation1         = 11
   149  	enumCodeLocation2         = 12
   150  	enumCodeLocationNoCol     = 13
   151  	enumCodeLocationLong      = 14
   152  	enumFrameOwnedByGenerator = 1
   153  )
   154  
   155  func (p *python) Locations(fn experimental.InternalFunction, pc experimental.ProgramCounter) (uint64, []location) {
   156  	call := fn.(pyfuncall)
   157  
   158  	loc := location{
   159  		File:       call.file,
   160  		Line:       int64(call.line),
   161  		Column:     0, // TODO
   162  		Inlined:    false,
   163  		HumanName:  call.name,
   164  		StableName: call.file + "." + call.name,
   165  	}
   166  
   167  	return uint64(call.addr), []location{loc}
   168  }
   169  
   170  func (p *python) Stackiter(mod api.Module, def api.FunctionDefinition, wasmsi experimental.StackIterator) experimental.StackIterator {
   171  	m := mod.Memory()
   172  	tsp := deref[ptr32](m, p.pyrtaddr+padTstateCurrentInRT)
   173  	cframep := deref[ptr32](m, tsp+padCframeInThreadState)
   174  	framep := deref[ptr32](m, cframep+padCurrentFrameInCFrame)
   175  
   176  	return &pystackiter{
   177  		namedbg: def.DebugName(),
   178  		mem:     m,
   179  		framep:  framep,
   180  	}
   181  }
   182  
// pystackiter walks the linked list of CPython interpreter frames in guest
// memory, starting from the frame Stackiter found to be current and following
// each frame's "previous" link.
type pystackiter struct {
	// namedbg is def.DebugName() of the wasm function passed to Stackiter;
	// it is not read by the methods in this file.
	namedbg string
	// mem is the guest memory that all frame pointers refer to.
	mem     api.Memory
	// started records whether Next has been called at least once.
	started bool
	framep  ptr32 // _PyInterpreterFrame*
}
   189  
   190  func (p *pystackiter) Next() bool {
   191  	if !p.started {
   192  		p.started = true
   193  		return p.framep != 0
   194  	}
   195  
   196  	oldframe := p.framep
   197  	p.framep = deref[ptr32](p.mem, p.framep+padPreviousInFrame)
   198  	if oldframe == p.framep {
   199  		p.framep = 0
   200  		return false
   201  	}
   202  	return p.framep != 0
   203  }
   204  
   205  func (p *pystackiter) ProgramCounter() experimental.ProgramCounter {
   206  	return experimental.ProgramCounter(deref[uint32](p.mem, p.framep+padPrevInstrInFrame))
   207  }
   208  
   209  func (p *pystackiter) Function() experimental.InternalFunction {
   210  	codep := deref[ptr32](p.mem, p.framep+padCodeInFrame)
   211  	line, _ := lineForFrame(p.mem, p.framep, codep)
   212  	file := derefPyUnicodeUtf8(p.mem, codep+padFilenameInCodeObject)
   213  	name := derefPyUnicodeUtf8(p.mem, codep+padNameInCodeObject)
   214  	return pyfuncall{
   215  		file: file,
   216  		name: functionName(file, name),
   217  		addr: deref[uint32](p.mem, p.framep+padPrevInstrInFrame),
   218  		line: line,
   219  	}
   220  }
   221  
   222  func functionName(path, function string) string {
   223  	mod := ""
   224  	const frozenPrefix = "<frozen "
   225  	if strings.HasPrefix(path, frozenPrefix) {
   226  		mod = path[len(frozenPrefix) : len(path)-1]
   227  	} else {
   228  		if strings.HasSuffix(path, "__init__.py") {
   229  			path = filepath.Dir(path)
   230  		}
   231  		file := filepath.Base(path)
   232  		mod = file[:len(file)-len(filepath.Ext(file))]
   233  	}
   234  
   235  	if function == "<module>" {
   236  		return mod
   237  	}
   238  	return mod + "." + function
   239  }
   240  
// Parameters is not implemented for Python frames; calling it always panics.
func (p *pystackiter) Parameters() []uint64 {
	panic("TODO parameters()")
}
   244  
// pyfuncall represents a specific place in the Python source where a
// function call occurred. pystackiter.Function produces one per frame and
// returns it as an experimental.InternalFunction. Through Definition it also
// serves as its own api.FunctionDefinition. Only the name-related methods
// carry real data; the others return placeholders or panic.
type pyfuncall struct {
	// file is the code object's filename.
	file string
	// name is the module-qualified name built by functionName.
	name string
	// line is the source line reported by lineForFrame.
	line int32
	// addr is the frame's prev-instruction value, reported by Locations.
	addr uint32

	// Embedded only to satisfy api.FunctionDefinition. pystackiter.Function
	// leaves it nil, so calling any interface method not overridden below
	// would panic with a nil dereference.
	api.FunctionDefinition // required for WazeroOnly
}

// Definition returns the call itself, which also implements
// api.FunctionDefinition.
func (f pyfuncall) Definition() api.FunctionDefinition {
	return f
}

// SourceOffsetForPC has no meaning for a Python call site and always panics.
func (f pyfuncall) SourceOffsetForPC(pc experimental.ProgramCounter) uint64 {
	panic("does not make sense")
}

// ModuleName returns a fixed placeholder.
func (f pyfuncall) ModuleName() string {
	return "<unknown>" // TODO
}

// Index returns a fixed placeholder, not a real function index.
func (f pyfuncall) Index() uint32 {
	return 42 // TODO
}

// Import is not implemented and panics.
func (f pyfuncall) Import() (string, string, bool) {
	panic("implement me")
}

// ExportNames is not implemented and panics.
func (f pyfuncall) ExportNames() []string {
	panic("implement me")
}

// Name returns the module-qualified Python function name.
func (f pyfuncall) Name() string {
	return f.name
}

// DebugName returns the same module-qualified name as Name.
func (f pyfuncall) DebugName() string {
	return f.name
}

// GoFunction returns nil: a Python call site is not a Go host function.
func (f pyfuncall) GoFunction() interface{} {
	return nil
}

// ParamTypes is not implemented and panics.
func (f pyfuncall) ParamTypes() []api.ValueType {
	panic("implement me")
}

// ParamNames is not implemented and panics.
func (f pyfuncall) ParamNames() []string {
	panic("implement me")
}

// ResultTypes is not implemented and panics.
func (f pyfuncall) ResultTypes() []api.ValueType {
	panic("implement me")
}

// ResultNames is not implemented and panics.
func (f pyfuncall) ResultNames() []string {
	panic("implement me")
}
   307  
   308  // Return the utf8 encoding of a PyUnicode object. It is a
   309  // re-implementation of PyUnicode_AsUTF8. The bytes are copied from
   310  // the vmem, so the returned string is safe to use.
   311  func pyUnicodeUTf8(m vmem, p ptr32) string {
   312  	statep := p + padStateInAsciiObject
   313  	state := deref[uint8](m, statep)
   314  	compact := state&(1<<5) > 0
   315  	ascii := state&(1<<6) > 0
   316  	if !compact || !ascii {
   317  		panic("only support ascii-compact utf8 representation")
   318  	}
   319  
   320  	length := deref[int32](m, p+padLengthInAsciiObject)
   321  	bytes := derefArray[byte](m, p+sizeAsciiObject, uint32(length))
   322  	return unsafe.String(unsafe.SliceData(bytes), len(bytes))
   323  }
   324  
   325  func derefPyUnicodeUtf8(m vmem, p ptr32) string {
   326  	x := deref[ptr32](m, p)
   327  	return pyUnicodeUTf8(m, x)
   328  }
   329  
   330  func lineForFrame(m vmem, framep, codep ptr32) (int32, bool) {
   331  	codestart := codep + padCodeAdaptiveInCodeObject
   332  	previnstr := deref[ptr32](m, framep+padPrevInstrInFrame)
   333  	firstlineno := deref[int32](m, codep+padFirstlinenoInCodeObject)
   334  
   335  	if previnstr < codestart {
   336  		return firstlineno, false
   337  	}
   338  
   339  	linearray := deref[ptr32](m, codep+padLinearrayInCodeObject)
   340  	if linearray != 0 {
   341  		panic("can't handle code sections with line arrays")
   342  	}
   343  
   344  	codebytes := deref[ptr32](m, codep+padLinetableInCodeObject)
   345  	if codebytes == 0 {
   346  		panic("code section must have a linetable")
   347  	}
   348  
   349  	length := deref[int32](m, codebytes+padSizeInBytesObject)
   350  	linetable := codebytes + padSvalInBytesObject
   351  	addrq := int32(previnstr - codestart)
   352  
   353  	lo_next := linetable             // pointer to the current byte in the line table
   354  	limit := lo_next + ptr32(length) // pointer to the end of the linetable
   355  	ar_end := int32(0)               // offset into the code section
   356  	computed_line := firstlineno     // current known line number
   357  	ar_line := int32(-1)             // line for the current bytecode
   358  
   359  	for ar_end <= addrq && lo_next < limit {
   360  		lineDelta := int32(0)
   361  		ptr := lo_next
   362  
   363  		entry := deref[uint8](m, ptr)
   364  		code := (entry >> 3) & 15
   365  		switch code {
   366  		case enumCodeLocation1:
   367  			lineDelta = 1
   368  		case enumCodeLocation2:
   369  			lineDelta = 2
   370  		case enumCodeLocationNoCol, enumCodeLocationLong:
   371  			lineDelta = pysvarint(m, ptr+1)
   372  		}
   373  
   374  		computed_line += lineDelta
   375  
   376  		if (entry >> 3) == 0x1F {
   377  			ar_line = -1
   378  		} else {
   379  			ar_line = computed_line
   380  		}
   381  
   382  		ar_end += (int32(entry&7) + 1) * sizeCodeUnit
   383  
   384  		lo_next++
   385  		for lo_next < limit && (deref[uint8](m, lo_next)&128 == 0) {
   386  			lo_next++
   387  		}
   388  	}
   389  
   390  	return ar_line, true
   391  }
   392  
   393  // Python-specific implementation of protobuf signed varints. However
   394  // it only uses 7 bits, as python uses the most significant bit to
   395  // store whether an entry starts on that byte.
   396  func pysvarint(m vmem, p ptr32) int32 {
   397  	read := deref[uint8](m, p)
   398  	val := uint32(read & 63)
   399  	shift := 0
   400  	for read&64 > 0 {
   401  		read = deref[uint8](m, p)
   402  		p++
   403  		shift += 6
   404  		val |= uint32(read&63) << shift
   405  	}
   406  
   407  	x := int32(val >> 1)
   408  	if val&1 > 0 {
   409  		x = -x
   410  	}
   411  	return x
   412  }