github.com/cloudwego/frugal@v0.1.15/internal/atm/pgen/pgen_amd64_test.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package pgen
    18  
    19  import (
    20      `fmt`
    21      `runtime`
    22      `strings`
    23      `testing`
    24      `unsafe`
    25  
    26      `github.com/cloudwego/iasm/x86_64`
    27      `github.com/cloudwego/frugal/internal/atm/hir`
    28      `github.com/cloudwego/frugal/internal/atm/rtx`
    29      `github.com/cloudwego/frugal/internal/loader`
    30      `github.com/cloudwego/frugal/internal/rt`
    31      `github.com/davecgh/go-spew/spew`
    32      `github.com/stretchr/testify/require`
    33      `golang.org/x/arch/x86/x86asm`
    34  )
    35  
    36  const (
    37      _MaxByte = 10
    38  )
    39  
    40  func symlookup(addr uint64) (string, uint64) {
    41      fp := runtime.FuncForPC(uintptr(addr))
    42      if fp != nil {
    43          ent := uint64(fp.Entry())
    44          if addr == ent {
    45              return fmt.Sprintf("%#x{%s}", addr, fp.Name()), ent
    46          }
    47          return fmt.Sprintf("%#x{%s+%#x}", addr, fp.Name(), addr - ent), ent
    48      }
    49      if addr == uint64(uintptr(rtx.V_pWriteBarrier)) {
    50          return fmt.Sprintf("%#x{runtime.writeBarrier}", addr), addr
    51      }
    52      return "", 0
    53  }
    54  
    55  func disasm(orig uintptr, c []byte) {
    56      var pc int
    57      for pc < len(c) {
    58          i, err := x86asm.Decode(c[pc:], 64)
    59          if err != nil {
    60              panic(err)
    61          }
    62          dis := x86asm.GNUSyntax(i, uint64(pc) + uint64(orig), symlookup)
    63          fmt.Printf("0x%08x : ", pc + int(orig))
    64          for x := 0; x < i.Len; x++ {
    65              if x != 0 && x % _MaxByte == 0 {
    66                  fmt.Printf("\n           : ")
    67              }
    68              fmt.Printf(" %02x", c[pc + x])
    69              if x == _MaxByte - 1 {
    70                  fmt.Printf("    %s", dis)
    71              }
    72          }
    73          if i.Len < _MaxByte {
    74              fmt.Printf("%s    %s", strings.Repeat(" ", (_MaxByte - i.Len) * 3), dis)
    75          }
    76          fmt.Printf("\n")
    77          pc += i.Len
    78      }
    79  }
    80  
    81  type TestIface interface {
    82      Bar(x int, y int) int
    83      Foo(x int, y int) int
    84  }
    85  
    86  var (
    87      hfunc *hir.CallHandle
    88      hmeth *hir.CallHandle
    89      cfunc uintptr
    90  )
    91  
    92  var (
    93      testfn = hir.RegisterGCall(testemu_pfunc, func(ctx hir.CallContext) {
    94          var v0 struct {P unsafe.Pointer; L uint64}
    95          var v1 struct {P unsafe.Pointer; L uint64}
    96          var v2 struct {P unsafe.Pointer; L uint64}
    97          if !ctx.Verify("*i*i*i", "*i*i") {
    98              panic("invalid testemu_pfunc call")
    99          }
   100          v0.P = ctx.Ap(0)
   101          v0.L = ctx.Au(1)
   102          v1.P = ctx.Ap(2)
   103          v1.L = ctx.Au(3)
   104          v2.P = ctx.Ap(4)
   105          v2.L = ctx.Au(5)
   106          r0, r1 := testemu_pfunc(
   107              *(*string)(unsafe.Pointer(&v0)),
   108              *(*string)(unsafe.Pointer(&v1)),
   109              *(*string)(unsafe.Pointer(&v2)),
   110          )
   111          ctx.Ru(1, uint64(len(r0)))
   112          ctx.Ru(3, uint64(len(r1)))
   113          ctx.Rp(0, *(*unsafe.Pointer)(unsafe.Pointer(&r0)))
   114          ctx.Rp(2, *(*unsafe.Pointer)(unsafe.Pointer(&r1)))
   115      })
   116  )
   117  
   118  func init() {
   119      hfunc = hir.RegisterCCall(unsafe.Pointer(&cfunc), nil)
   120      hmeth = hir.RegisterICall(rt.GetMethod((*TestIface)(nil), "Foo"), nil)
   121  }
   122  
   123  func testemu_pfunc(a string, b string, c string) (d string, e string) {
   124      d = a + b
   125      e = b + c
   126      return
   127  }
   128  
   129  func TestPGen_Generate(t *testing.T) {
   130      p := hir.CreateBuilder()
   131      p.IQ(0, hir.R0)
   132      p.IQ(1, hir.R1)
   133      p.IQ(2, hir.R2)
   134      p.MOVP(hir.Pn, hir.P0)
   135      p.MOVP(hir.Pn, hir.P1)
   136      p.MOVP(hir.Pn, hir.P2)
   137      p.BREAK()
   138      p.BREAK()
   139      p.BREAK()
   140      p.BREAK()
   141      p.BREAK()
   142      p.CCALL(hfunc).A0(hir.R0).A1(hir.R1).A2(hir.R2).R0(hir.R0)
   143      p.BREAK()
   144      p.BREAK()
   145      p.BREAK()
   146      p.BREAK()
   147      p.BREAK()
   148      p.GCALL(testfn).A0(hir.P0).A1(hir.R0).A2(hir.P1).A3(hir.R1).A4(hir.P2).A5(hir.R2).R0(hir.P0).R1(hir.R0).R2(hir.P1).R3(hir.R1)
   149      p.BREAK()
   150      p.BREAK()
   151      p.BREAK()
   152      p.BREAK()
   153      p.BREAK()
   154      p.ICALL(hir.P0, hir.P1, hmeth).A0(hir.R0).A1(hir.R1).R0(hir.R2)
   155      p.BREAK()
   156      p.BREAK()
   157      p.BREAK()
   158      p.BREAK()
   159      p.BREAK()
   160      p.BCOPY(hir.P1, hir.R1, hir.P0)
   161      p.BREAK()
   162      p.BREAK()
   163      p.BREAK()
   164      p.BREAK()
   165      p.BREAK()
   166      p.RET()
   167      g := CreateCodeGen(func(){})
   168      c := g.Generate(p.Build(), 0)
   169      disasm(0, c.Code)
   170  }
   171  
   172  type ifacetest interface {
   173      Foo(int) int
   174  }
   175  
   176  type ifacetesttype int
   177  func (self ifacetesttype) Foo(v int) int {
   178      runtime.GC()
   179      println("iface Foo(), self is", self, ", v is", v)
   180      return int(self) + v
   181  }
   182  
   183  func gcalltestfn(a int) (int, int, int) {
   184      runtime.GC()
   185      println("a is", a)
   186      return a + 100, a + 200, a + 300
   187  }
   188  
   189  func mkccalltestfn() unsafe.Pointer {
   190      var asm x86_64.Assembler
   191      err := asm.Assemble(`
   192          movq    %rdi, %rax
   193          addq    $10087327, %rax
   194          ret
   195      `)
   196      if err != nil {
   197          panic(err)
   198      }
   199      p := loader.Loader(asm.Code()).Load("_ccalltestfn", rt.Frame{})
   200      return *(*unsafe.Pointer)(p)
   201  }
   202  
   203  func TestPGen_FunctionCall(t *testing.T) {
   204      var s ifacetest
   205      var i ifacetesttype = 123456
   206      s = i
   207      m := hir.RegisterICall(rt.GetMethod((*ifacetest)(nil), "Foo"), nil)
   208      c := hir.RegisterCCall(mkccalltestfn(), nil)
   209      h := hir.RegisterGCall(gcalltestfn, nil)
   210      p := hir.CreateBuilder()
   211      e := *(*rt.GoIface)(unsafe.Pointer(&s))
   212      p.IP(e.Itab, hir.P0)
   213      p.IP(e.Value, hir.P1)
   214      p.LDAQ(0, hir.R0)
   215      p.GCALL(h).A0(hir.R0).R0(hir.R1).R1(hir.R2).R2(hir.R3)
   216      p.ADD(hir.R1, hir.R2, hir.R1)
   217      p.ADD(hir.R2, hir.R3, hir.R2)
   218      p.ICALL(hir.P0, hir.P1, m).A0(hir.R3).R0(hir.R4)
   219      p.ADDI(hir.R4, 10000000, hir.R3)
   220      p.CCALL(c).A0(hir.R3).R0(hir.R4)
   221      p.RET().R0(hir.R1).R1(hir.R2).R2(hir.R4)
   222      g := CreateCodeGen((func(int) (int, int, int))(nil))
   223      r := g.Generate(p.Build(), 0)
   224      spew.Dump(r.Frame)
   225      v := loader.Loader(r.Code).Load("_test_gcall", r.Frame)
   226      disasm(*(*uintptr)(v), r.Code)
   227      f := *(*func(int) (int, int, int))(unsafe.Pointer(&v))
   228      x, y, z := f(123)
   229      println("f(123) is", x, y, z)
   230      require.Equal(t, 546, x)
   231      require.Equal(t, 746, y)
   232      require.Equal(t, 20211206, z)
   233  }