github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build amd64
    16  // +build amd64
    17  
    18  package usertrap
    19  
    20  import (
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  
    25  	"golang.org/x/sys/unix"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/context"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/hostarch"
    28  	"github.com/nicocha30/gvisor-ligolo/pkg/marshal/primitive"
    29  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/arch"
    30  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel"
    31  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/memmap"
    32  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    33  	"github.com/nicocha30/gvisor-ligolo/pkg/usermem"
    34  )
    35  
    36  // trapNR is the maximum number of traps what can fit in the trap table.
    37  const trapNR = 256
    38  
    39  // trapSize is the size of one trap.
    40  const trapSize = 80
    41  
    42  var (
    43  	// jmpInst is the binary code of "jmp *addr".
    44  	jmpInst          = [7]byte{0xff, 0x24, 0x25, 0, 0, 0, 0}
    45  	jmpInstOpcodeLen = 3
    46  	// faultInst is the single byte invalid instruction.
    47  	faultInst = [1]byte{0x6}
    48  	// faultInstOffset is the offset of the syscall instruction.
    49  	faultInstOffset = uintptr(5)
    50  )
    51  
    52  type memoryManager interface {
    53  	usermem.IO
    54  	MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error)
    55  	FindVMAByName(ar hostarch.AddrRange, hint string) (hostarch.Addr, uint64, error)
    56  }
    57  
// State represents the current state of the trap table.
//
// +stateify savable
type State struct {
	// mu guards nextTrap, tableAddr and the table contents: PatchSyscall
	// holds it for writing, HandleFault and PreFork/PostFork for reading.
	// It is not saved.
	mu        sync.RWMutex `state:"nosave"`
	// nextTrap is the index of the next unused trap slot, or zero if the
	// table has not been located or mapped yet for this State.
	nextTrap  uint32
	// tableAddr is the address of the trap table in the application
	// address space (valid once nextTrap is non-zero).
	tableAddr hostarch.Addr
}
    66  
    67  // New returns the new state structure.
    68  func New() *State {
    69  	return &State{}
    70  }
    71  
// header is stored at the start of the trap table (slot 0) and records the
// index of the next unused trap, so that a State created after fork or
// restore can continue allocating where the previous one stopped.
//
// +marshal
type header struct {
	// nextTrap is the index of the next unused trap slot.
	nextTrap uint32
}
    76  
    77  func (s *State) trapAddr(trap uint32) hostarch.Addr {
    78  	return s.tableAddr + hostarch.Addr(trapSize*trap)
    79  }
    80  
    81  // newTrapLocked allocates a new trap entry.
    82  //
    83  // Preconditions: s.mu must be locked.
    84  func (s *State) newTrapLocked(ctx context.Context, mm memoryManager) (hostarch.Addr, error) {
    85  	var hdr header
    86  	task := kernel.TaskFromContext(ctx)
    87  	if task == nil {
    88  		return 0, fmt.Errorf("no task found")
    89  	}
    90  
    91  	// s.nextTrap is zero if it isn't initialized. Here are three cases
    92  	// when this can happen:
    93  	//	* A usertrap vma has not been mapped yet.
    94  	//	* The address space has been forked.
    95  	//	* The address space has been restored.
    96  	// nextTrap is saved on the usertrap vma to handle the third and second
    97  	// cases.
    98  	if s.nextTrap == 0 {
    99  		addr, off, err := mm.FindVMAByName(trapTableAddrRange, tableHint)
   100  		if off != 0 {
   101  			return 0, fmt.Errorf("the usertrap vma has been overmounted")
   102  		}
   103  		if err != nil {
   104  			// The usertrap table has not been mapped yet.
   105  			addr := hostarch.Addr(rand.Int63n(int64(trapTableAddrRange.Length()-trapTableSize))).RoundDown() + trapTableAddrRange.Start
   106  			ctx.Debugf("Map a usertrap vma at %x", addr)
   107  			if err := loadUsertrap(ctx, mm, addr); err != nil {
   108  				return 0, err
   109  			}
   110  			// The first cell in the table is used to save an index of a
   111  			// next unused trap.
   112  			s.nextTrap = 1
   113  			s.tableAddr = addr
   114  		} else if _, err := hdr.CopyIn(task.OwnCopyContext(usermem.IOOpts{AddressSpaceActive: false}), addr); err != nil {
   115  			return 0, err
   116  		} else {
   117  			// Read an index of a next unused trap.
   118  			s.nextTrap = hdr.nextTrap
   119  			s.tableAddr = addr
   120  		}
   121  	}
   122  	ctx.Debugf("Allocate a new trap: %p %d", s, s.nextTrap)
   123  	if s.nextTrap >= trapNR {
   124  		ctx.Warningf("No space in the trap table")
   125  		return 0, fmt.Errorf("no space in the trap table")
   126  	}
   127  	trap := s.nextTrap
   128  	s.nextTrap++
   129  
   130  	// An entire trap has to be on the same page to avoid memory faults.
   131  	addr := s.trapAddr(trap)
   132  	if addr/hostarch.PageSize != (addr+trapSize)/hostarch.PageSize {
   133  		trap = s.nextTrap
   134  		s.nextTrap++
   135  	}
   136  	hdr = header{
   137  		nextTrap: s.nextTrap,
   138  	}
   139  	if _, err := hdr.CopyOut(task.OwnCopyContext(usermem.IOOpts{IgnorePermissions: true}), s.tableAddr); err != nil {
   140  		return 0, err
   141  	}
   142  	return s.trapAddr(trap), nil
   143  }
   144  
   145  // trapTableAddrRange is the range where a trap table can be placed.
   146  //
   147  // The value has to be below 2GB and the high two bytes has to be an invalid
   148  // instruction.  In case of 0x60000, the high two bytes is 0x6. This is "push
   149  // es" in x86 and the bad instruction on x64.
   150  var trapTableAddrRange = hostarch.AddrRange{Start: 0x60000, End: 0x70000}
   151  
   152  const (
   153  	trapTableSize = hostarch.Addr(trapNR * trapSize)
   154  
   155  	tableHint = "[usertrap]"
   156  )
   157  
   158  // LoadUsertrap maps the usertrap table into the address space.
   159  func loadUsertrap(ctx context.Context, mm memoryManager, addr hostarch.Addr) error {
   160  	size, _ := hostarch.Addr(trapTableSize).RoundUp()
   161  	// Force is true because Addr is below MinUserAddress.
   162  	_, err := mm.MMap(ctx, memmap.MMapOpts{
   163  		Force:     true,
   164  		Unmap:     true,
   165  		Fixed:     true,
   166  		Addr:      addr,
   167  		Length:    uint64(size),
   168  		Private:   true,
   169  		Hint:      tableHint,
   170  		MLockMode: memmap.MLockEager,
   171  		Perms: hostarch.AccessType{
   172  			Write:   false,
   173  			Read:    true,
   174  			Execute: true,
   175  		},
   176  		MaxPerms: hostarch.AccessType{
   177  			Write:   true,
   178  			Read:    true,
   179  			Execute: true,
   180  		},
   181  	})
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	return nil
   187  }
   188  
// PatchSyscall changes the syscall instruction into a function call.
//
// It assumes the current IP points just past a "mov $sysno, %eax; syscall"
// sequence (5 + 2 bytes), and reads the len(jmpInst) bytes that precede IP.
// If they still start with the mov opcode 0xb8, a trap is allocated for the
// current syscall number and the sequence is overwritten with
// "jmp *trapAddr". If the first byte is anything else (for example, another
// thread has already patched this site), nothing is written.
//
// Returns true if the thread has to be restarted. In that case
// ac.RestartSyscall has been called and IP has been moved back to the start
// of the (possibly already patched) sequence. Returns (false, nil) without
// patching the code if a new trap could not be allocated. Returns an error if
// there is no task in ctx or application memory could not be read or written.
func (s *State) PatchSyscall(ctx context.Context, ac *arch.Context64, mm memoryManager) (bool, error) {
	task := kernel.TaskFromContext(ctx)
	if task == nil {
		return false, fmt.Errorf("no task found")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	sysno := ac.SyscallNo()
	// Start of the 7-byte "mov $sysno, %eax; syscall" sequence.
	patchAddr := ac.IP() - uintptr(len(jmpInst))

	prevCode := make([]uint8, len(jmpInst))
	if _, err := primitive.CopyUint8SliceIn(task.OwnCopyContext(usermem.IOOpts{AddressSpaceActive: false}), hostarch.Addr(patchAddr), prevCode); err != nil {
		return false, err
	}

	// Check that another thread has not patched this syscall yet.
	// 0xb8 is the first byte of "mov sysno, %eax".
	if prevCode[0] == uint8(0xb8) {
		ctx.Debugf("Found the pattern at ip %x:sysno %d", patchAddr, sysno)

		trapAddr, err := s.addTrapLocked(ctx, ac, mm, uint32(sysno))
		if trapAddr == 0 || err != nil {
			// Best effort: leave the code unpatched and do not restart.
			ctx.Warningf("Failed to add a new trap: %v", err)
			return false, nil
		}

		// Replace "mov sysno, %eax; syscall" with "jmp trapAddr".
		newCode := make([]uint8, len(jmpInst))
		copy(newCode[:jmpInstOpcodeLen], jmpInst[:jmpInstOpcodeLen])
		binary.LittleEndian.PutUint32(newCode[jmpInstOpcodeLen:], uint32(trapAddr))

		ctx.Debugf("Apply the binary patch addr %x trap addr %x (%v -> %v)", patchAddr, trapAddr, prevCode, newCode)

		ignorePermContext := task.OwnCopyContext(usermem.IOOpts{IgnorePermissions: true})

		// The patch can't be applied atomically, so we need to
		// guarantee that at every moment other threads read a valid
		// set of instructions, or detect an inconsistent state and
		// restart the patched code.
		//
		// A subtle aspect is the address at which the user trap table
		// is mapped, which is inside [0x60000, 0x70000). Bits 16-23 of
		// every trap address are therefore 0x06, an invalid opcode. In
		// the jmp encoding that address byte lands exactly at
		// faultInstOffset, where the first byte of the syscall
		// instruction used to be. That's why it is fine to overwrite
		// all the bytes but the first one in the second step: that
		// location still holds 0x06 afterwards, and any thread reading
		// the instructions will still fault at the same place.
		//
		// Another subtle aspect is that the second step is done using a
		// regular non-atomic write. A thread decoding the mov
		// instruction could therefore read a garbage value of the
		// immediate operand of the "mov sysno, %eax" instruction. That
		// doesn't matter, since the first byte, which contains the
		// opcode, is not changed. Also, the thread will fault on the
		// 0x06 right after and be restarted with the patched code, so
		// a mov reading a garbage immediate operand doesn't affect
		// correctness.

		// The patch is applied in three steps:
		//
		// The first step is to replace the first byte of the syscall
		// instruction by one-byte invalid instruction (0x06), so that
		// other threads which have passed the mov instruction fault on
		// the invalid instruction and restart a patched code.
		faultInstB := primitive.ByteSlice(faultInst[:])
		if _, err := faultInstB.CopyOut(ignorePermContext, hostarch.Addr(patchAddr+faultInstOffset)); err != nil {
			return false, err
		}
		// The second step is to replace all bytes except the first one
		// which is the opcode of the mov instruction, so that the first
		// five bytes remain "mov XXX, %eax".
		if _, err := primitive.CopyUint8SliceOut(ignorePermContext, hostarch.Addr(patchAddr+1), newCode[1:]); err != nil {
			return false, err
		}
		// The final step is to replace the first byte of the patch.
		// After this point, all threads will read the valid jmp
		// instruction.
		if _, err := primitive.CopyUint8SliceOut(ignorePermContext, hostarch.Addr(patchAddr), newCode[0:1]); err != nil {
			return false, err
		}
	}
	// Re-run from the start of the sequence so the thread goes through the
	// jmp (whether this thread or another one applied the patch).
	ac.RestartSyscall()
	ac.SetIP(patchAddr)
	return true, nil
}
   280  
   281  // HandleFault handles a fault on a patched syscall instruction.
   282  //
   283  // When we replace a system call with a function call, we replace two
   284  // instructions with one instruction. This means that here can be a thread
   285  // which called the first instruction, then another thread applied a binary
   286  // patch and the first thread calls the second instruction.
   287  //
   288  // To handle this case, the function call (jmp) instruction is constructed so
   289  // that the first byte of the syscall instruction is changed with the one-byte
   290  // invalid instruction (0x6).  And in case of the race, the first thread will
   291  // fault on the invalid instruction and HandleFault will restart the function
   292  // call.
   293  func (s *State) HandleFault(ctx context.Context, ac *arch.Context64, mm memoryManager) error {
   294  	task := kernel.TaskFromContext(ctx)
   295  	if task == nil {
   296  		return fmt.Errorf("no task found")
   297  	}
   298  
   299  	s.mu.RLock()
   300  	defer s.mu.RUnlock()
   301  
   302  	code := make([]uint8, len(jmpInst))
   303  	ip := ac.IP() - faultInstOffset
   304  	if _, err := primitive.CopyUint8SliceIn(task.OwnCopyContext(usermem.IOOpts{AddressSpaceActive: false}), hostarch.Addr(ip), code); err != nil {
   305  		return err
   306  	}
   307  
   308  	for i := 0; i < jmpInstOpcodeLen; i++ {
   309  		if code[i] != jmpInst[i] {
   310  			return nil
   311  		}
   312  	}
   313  	for i := 0; i < len(faultInst); i++ {
   314  		if code[i+int(faultInstOffset)] != faultInst[i] {
   315  			return nil
   316  		}
   317  	}
   318  
   319  	regs := &ac.StateData().Regs
   320  	if regs.Rax == uint64(unix.SYS_RESTART_SYSCALL) {
   321  		// restart_syscall is usually set by the Sentry to restart a
   322  		// system call after interruption by a stop signal. The Sentry
   323  		// sets RAX and moves RIP back on the size of the syscall
   324  		// instruction.
   325  		//
   326  		// RAX can't be set to SYS_RESTART_SYSCALL due to a race with
   327  		// injecting a function call, because neither of the two first
   328  		// bytes are equal to proper bytes of jmpInst.
   329  		regs.Orig_rax = regs.Rax
   330  		regs.Rip += arch.SyscallWidth
   331  		return ErrFaultSyscall
   332  	}
   333  
   334  	ac.SetIP(ip)
   335  	return ErrFaultRestart
   336  }
   337  
   338  // PreFork locks the trap table for reading. This call guarantees that the trap
   339  // table will not be changed before the next PostFork call.
   340  // +checklocksacquireread:s.mu
   341  func (s *State) PreFork() {
   342  	s.mu.RLock()
   343  }
   344  
   345  // PostFork unlocks the trap table.
   346  // +checklocksreleaseread:s.mu
   347  func (s *State) PostFork() {
   348  	s.mu.RUnlock()
   349  }