github.com/criyle/go-sandbox@v0.10.3/ptracer/tracer_track_linux.go (about)

     1  package ptracer
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"runtime"
     7  	"time"
     8  
     9  	unix "golang.org/x/sys/unix"
    10  
    11  	"github.com/criyle/go-sandbox/runner"
    12  )
    13  
    14  // Trace start and traces all child process by runner in the calling goroutine
    15  // parameter done used to cancel work, start is used notify child starts
    16  func (t *Tracer) Trace(c context.Context) (result runner.Result) {
    17  	// ptrace is thread based (kernel proc)
    18  	runtime.LockOSThread()
    19  	defer runtime.UnlockOSThread()
    20  
    21  	// Start the runner
    22  	pgid, err := t.Runner.Start()
    23  	t.Handler.Debug("tracer started: ", pgid, err)
    24  	if err != nil {
    25  		t.Handler.Debug("start tracee failed: ", err)
    26  		result.Status = runner.StatusRunnerError
    27  		result.Error = err.Error()
    28  		return
    29  	}
    30  	return t.trace(c, pgid)
    31  }
    32  
    33  func (t *Tracer) trace(c context.Context, pgid int) (result runner.Result) {
    34  	cc, cancel := context.WithCancel(c)
    35  	defer cancel()
    36  
    37  	// handle cancellation
    38  	go func() {
    39  		<-cc.Done()
    40  		killAll(pgid)
    41  	}()
    42  
    43  	sTime := time.Now()
    44  	ph := newPtraceHandle(t, pgid)
    45  
    46  	// handler potential panic and tle
    47  	// also ensure processes was well terminated
    48  	defer func() {
    49  		if err := recover(); err != nil {
    50  			t.Handler.Debug("panic: ", err)
    51  			result.Status = runner.StatusRunnerError
    52  			result.Error = fmt.Sprintf("%v", err)
    53  		}
    54  		// kill all tracee upon return
    55  		killAll(pgid)
    56  		collectZombie(pgid)
    57  		if !ph.fTime.IsZero() {
    58  			result.SetUpTime = ph.fTime.Sub(sTime)
    59  			result.RunningTime = time.Since(ph.fTime)
    60  		}
    61  	}()
    62  
    63  	// ptrace pool loop
    64  	for {
    65  		var (
    66  			wstatus unix.WaitStatus // wait4 wait status
    67  			rusage  unix.Rusage     // wait4 rusage
    68  			pid     int             // store pid of wait4 result
    69  			err     error
    70  		)
    71  		if ph.execved {
    72  			// Wait for all child in the process group
    73  			pid, err = unix.Wait4(-pgid, &wstatus, unix.WALL, &rusage)
    74  		} else {
    75  			// Ensure the process have called setpgid
    76  			pid, err = unix.Wait4(pgid, &wstatus, unix.WALL, &rusage)
    77  		}
    78  		if err == unix.EINTR {
    79  			t.Handler.Debug("wait4 EINTR")
    80  			continue
    81  		}
    82  		if err != nil {
    83  			t.Handler.Debug("wait4 failed: ", err)
    84  			result.Status = runner.StatusRunnerError
    85  			result.Error = err.Error()
    86  			return
    87  		}
    88  		t.Handler.Debug("------ ", pid, " ------")
    89  
    90  		// update rusage
    91  		if pid == pgid {
    92  			userTime, userMem, curStatus := t.checkUsage(rusage)
    93  			result.Status = curStatus
    94  			result.Time = userTime
    95  			result.Memory = userMem
    96  			if curStatus != runner.StatusNormal {
    97  				return
    98  			}
    99  		}
   100  
   101  		status, exitStatus, errStr, finished := ph.handle(pid, wstatus)
   102  		if finished || status != runner.StatusNormal {
   103  			result.Status = status
   104  			result.ExitStatus = exitStatus
   105  			result.Error = errStr
   106  			return
   107  		}
   108  	}
   109  }
   110  
   111  func (t *Tracer) checkUsage(rusage unix.Rusage) (time.Duration, runner.Size, runner.Status) {
   112  	status := runner.StatusNormal
   113  	// update resource usage and check against limits
   114  	userTime := time.Duration(rusage.Utime.Nano()) // ns
   115  	userMem := runner.Size(rusage.Maxrss << 10)    // bytes
   116  
   117  	// check tle / mle
   118  	if userTime > t.Limit.TimeLimit {
   119  		status = runner.StatusTimeLimitExceeded
   120  	}
   121  	if userMem > t.Limit.MemoryLimit {
   122  		status = runner.StatusMemoryLimitExceeded
   123  	}
   124  	return userTime, userMem, status
   125  }
   126  
   127  type ptraceHandle struct {
   128  	*Tracer
   129  	pgid    int
   130  	traced  map[int]bool
   131  	execved bool
   132  	fTime   time.Time
   133  }
   134  
   135  func newPtraceHandle(t *Tracer, pgid int) *ptraceHandle {
   136  	return &ptraceHandle{t, pgid, make(map[int]bool), false, time.Time{}}
   137  }
   138  
   139  func (ph *ptraceHandle) handle(pid int, wstatus unix.WaitStatus) (status runner.Status, exitStatus int, errStr string, finished bool) {
   140  	status = runner.StatusNormal
   141  	// check process status
   142  	switch {
   143  	case wstatus.Exited():
   144  		delete(ph.traced, pid)
   145  		ph.Handler.Debug("process exited: ", pid, wstatus.ExitStatus())
   146  		if pid == ph.pgid {
   147  			finished = true
   148  			if ph.execved {
   149  				exitStatus = wstatus.ExitStatus()
   150  				if exitStatus == 0 {
   151  					status = runner.StatusNormal
   152  				} else {
   153  					status = runner.StatusNonzeroExitStatus
   154  				}
   155  				return
   156  			}
   157  			status = runner.StatusRunnerError
   158  			errStr = "child process exit before execve"
   159  			return
   160  		}
   161  
   162  	case wstatus.Signaled():
   163  		sig := wstatus.Signal()
   164  		ph.Handler.Debug("ptrace signaled: ", sig)
   165  		if pid == ph.pgid {
   166  			delete(ph.traced, pid)
   167  			switch sig {
   168  			case unix.SIGXCPU, unix.SIGKILL:
   169  				status = runner.StatusTimeLimitExceeded
   170  			case unix.SIGXFSZ:
   171  				status = runner.StatusOutputLimitExceeded
   172  			case unix.SIGSYS:
   173  				status = runner.StatusDisallowedSyscall
   174  			default:
   175  				status = runner.StatusSignalled
   176  			}
   177  			exitStatus = int(sig)
   178  			return
   179  		}
   180  		unix.PtraceCont(pid, int(sig))
   181  
   182  	case wstatus.Stopped():
   183  		// Set option if the process is newly forked
   184  		if !ph.traced[pid] {
   185  			ph.Handler.Debug("set ptrace option for", pid)
   186  			ph.traced[pid] = true
   187  			// Ptrace set option valid if the tracee is stopped
   188  			if err := setPtraceOption(pid); err != nil {
   189  				status = runner.StatusRunnerError
   190  				errStr = err.Error()
   191  				return
   192  			}
   193  		}
   194  
   195  		stopSig := wstatus.StopSignal()
   196  		// Check stop signal, if trap then check seccomp
   197  		switch stopSig {
   198  		case unix.SIGTRAP:
   199  			switch trapCause := wstatus.TrapCause(); trapCause {
   200  			case unix.PTRACE_EVENT_SECCOMP:
   201  				if ph.execved {
   202  					// give the customized handle for syscall
   203  					err := ph.handleTrap(pid)
   204  					if err != nil {
   205  						status = runner.StatusDisallowedSyscall
   206  						errStr = err.Error()
   207  						return
   208  					}
   209  				} else {
   210  					ph.Handler.Debug("ptrace seccomp before execve (should be the execve syscall)")
   211  				}
   212  
   213  			case unix.PTRACE_EVENT_CLONE:
   214  				ph.Handler.Debug("ptrace stop clone")
   215  			case unix.PTRACE_EVENT_VFORK:
   216  				ph.Handler.Debug("ptrace stop vfork")
   217  			case unix.PTRACE_EVENT_FORK:
   218  				ph.Handler.Debug("ptrace stop fork")
   219  			case unix.PTRACE_EVENT_EXEC:
   220  				// forked tracee have successfully called execve
   221  				if !ph.execved {
   222  					ph.fTime = time.Now()
   223  					ph.execved = true
   224  				}
   225  				ph.Handler.Debug("ptrace stop exec")
   226  
   227  			default:
   228  				ph.Handler.Debug("ptrace unexpected trap cause: ", trapCause)
   229  			}
   230  			unix.PtraceCont(pid, 0)
   231  			return
   232  
   233  		// check if cpu rlimit hit
   234  		case unix.SIGXCPU:
   235  			status = runner.StatusTimeLimitExceeded
   236  		case unix.SIGXFSZ:
   237  			status = runner.StatusOutputLimitExceeded
   238  		}
   239  		if status != runner.StatusNormal {
   240  			return
   241  		}
   242  		// Likely encountered SIGSEGV (segment violation)
   243  		// Or compiler child exited
   244  		if stopSig != unix.SIGSTOP {
   245  			ph.Handler.Debug("ptrace unexpected stop signal: ", stopSig)
   246  		}
   247  		ph.Handler.Debug("ptrace stopped")
   248  		unix.PtraceCont(pid, int(stopSig))
   249  	}
   250  	return
   251  }
   252  
   253  // handleTrap handles the seccomp trap including the custom handle
   254  func (ph *ptraceHandle) handleTrap(pid int) error {
   255  	ph.Handler.Debug("seccomp traced")
   256  	// msg, err := unix.PtraceGetEventMsg(pid)
   257  	// if err != nil {
   258  	// 	t.Handler.Debug("PtraceGetEventMsg failed:", err)
   259  	// 	return err
   260  	// }
   261  	if ph.Handler != nil {
   262  		ctx, err := getTrapContext(pid)
   263  		if err != nil {
   264  			return err
   265  		}
   266  		act := ph.Handler.Handle(ctx)
   267  
   268  		switch act {
   269  		case TraceBan:
   270  			// Set the syscallno to -1 and return value into register to skip syscall.
   271  			// https://www.kernel.org/doc/Documentation/prctl/pkg/seccomp_filter.txt
   272  			return ctx.skipSyscall()
   273  
   274  		case TraceKill:
   275  			return runner.StatusDisallowedSyscall
   276  		}
   277  	}
   278  	return nil
   279  }
   280  
   281  // set Ptrace option that set up seccomp, exit kill and all mult-process actions
   282  func setPtraceOption(pid int) error {
   283  	const ptraceFlags = unix.PTRACE_O_TRACESECCOMP | unix.PTRACE_O_EXITKILL | unix.PTRACE_O_TRACEFORK |
   284  		unix.PTRACE_O_TRACECLONE | unix.PTRACE_O_TRACEEXEC | unix.PTRACE_O_TRACEVFORK
   285  	return unix.PtraceSetOptions(pid, ptraceFlags)
   286  }
   287  
   288  // kill all tracee according to pids
   289  func killAll(pgid int) {
   290  	unix.Kill(-pgid, unix.SIGKILL)
   291  }
   292  
   293  // collect died child processes
   294  func collectZombie(pgid int) {
   295  	var wstatus unix.WaitStatus
   296  	// collect zombies
   297  	for {
   298  		if _, err := unix.Wait4(-pgid, &wstatus, unix.WALL|unix.WNOHANG, nil); err != unix.EINTR && err != nil {
   299  			break
   300  		}
   301  	}
   302  }