github.com/criyle/go-sandbox@v0.10.3/runner/unshare/run_linux.go (about)

     1  package unshare
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"time"
     8  
     9  	"golang.org/x/sys/unix"
    10  
    11  	"github.com/criyle/go-sandbox/pkg/forkexec"
    12  	"github.com/criyle/go-sandbox/runner"
    13  )
    14  
    15  const (
    16  	// UnshareFlags is flags used to create namespaces except NET and IPC
    17  	UnshareFlags = unix.CLONE_NEWNS | unix.CLONE_NEWPID | unix.CLONE_NEWUSER | unix.CLONE_NEWUTS | unix.CLONE_NEWCGROUP
    18  )
    19  
    20  // Run starts the unshared process
    21  func (r *Runner) Run(c context.Context) (result runner.Result) {
    22  	ch := &forkexec.Runner{
    23  		Args:       r.Args,
    24  		Env:        r.Env,
    25  		ExecFile:   r.ExecFile,
    26  		RLimits:    r.RLimits,
    27  		Files:      r.Files,
    28  		WorkDir:    r.WorkDir,
    29  		Seccomp:    r.Seccomp.SockFprog(),
    30  		NoNewPrivs: true,
    31  		CloneFlags: UnshareFlags,
    32  		Mounts:     r.Mounts,
    33  		HostName:   r.HostName,
    34  		DomainName: r.DomainName,
    35  		PivotRoot:  r.Root,
    36  		DropCaps:   true,
    37  		SyncFunc:   r.SyncFunc,
    38  
    39  		UnshareCgroupAfterSync: true,
    40  	}
    41  
    42  	var (
    43  		wstatus unix.WaitStatus // wait4 wait status
    44  		rusage  unix.Rusage     // wait4 rusage
    45  		status  = runner.StatusNormal
    46  		sTime   = time.Now() // start time
    47  		fTime   time.Time    // finish time for setup
    48  	)
    49  
    50  	// Start the runner
    51  	pgid, err := ch.Start()
    52  	r.println("Starts: ", pgid, err)
    53  	if err != nil {
    54  		result.Status = runner.StatusRunnerError
    55  		result.Error = err.Error()
    56  		return
    57  	}
    58  
    59  	ctx, cancel := context.WithCancel(c)
    60  	defer cancel()
    61  
    62  	// handle cancel
    63  	go func() {
    64  		<-ctx.Done()
    65  		killAll(pgid)
    66  	}()
    67  
    68  	// kill all tracee upon return
    69  	defer func() {
    70  		killAll(pgid)
    71  		collectZombie(pgid)
    72  		result.SetUpTime = fTime.Sub(sTime)
    73  		result.RunningTime = time.Since(fTime)
    74  	}()
    75  
    76  	fTime = time.Now()
    77  	for {
    78  		_, err := unix.Wait4(pgid, &wstatus, 0, &rusage)
    79  		if err == unix.EINTR {
    80  			continue
    81  		}
    82  		r.println("wait4: ", wstatus)
    83  		if err != nil {
    84  			result.Status = runner.StatusRunnerError
    85  			result.Error = err.Error()
    86  			return
    87  		}
    88  
    89  		// update resource usage and check against limits
    90  		userTime := time.Duration(rusage.Utime.Nano()) // ns
    91  		userMem := runner.Size(rusage.Maxrss << 10)    // bytes
    92  
    93  		// check tle / mle
    94  		if userTime > r.Limit.TimeLimit {
    95  			status = runner.StatusTimeLimitExceeded
    96  		}
    97  		if userMem > r.Limit.MemoryLimit {
    98  			status = runner.StatusMemoryLimitExceeded
    99  		}
   100  		result = runner.Result{
   101  			Status: status,
   102  			Time:   userTime,
   103  			Memory: userMem,
   104  		}
   105  		if status != runner.StatusNormal {
   106  			return
   107  		}
   108  
   109  		switch {
   110  		case wstatus.Exited():
   111  			result.Status = runner.StatusNormal
   112  			result.ExitStatus = wstatus.ExitStatus()
   113  			if result.ExitStatus != 0 {
   114  				result.Status = runner.StatusNonzeroExitStatus
   115  			}
   116  			return
   117  
   118  		case wstatus.Signaled():
   119  			sig := wstatus.Signal()
   120  			switch sig {
   121  			case unix.SIGXCPU, unix.SIGKILL:
   122  				status = runner.StatusTimeLimitExceeded
   123  			case unix.SIGXFSZ:
   124  				status = runner.StatusOutputLimitExceeded
   125  			case unix.SIGSYS:
   126  				status = runner.StatusDisallowedSyscall
   127  			default:
   128  				status = runner.StatusSignalled
   129  			}
   130  			result.Status = status
   131  			result.ExitStatus = int(sig)
   132  			return
   133  		}
   134  	}
   135  }
   136  
   137  // kill all tracee according to pids
   138  func killAll(pgid int) {
   139  	unix.Kill(-pgid, unix.SIGKILL)
   140  }
   141  
   142  // collect died child processes
   143  func collectZombie(pgid int) {
   144  	var wstatus unix.WaitStatus
   145  	for {
   146  		if _, err := unix.Wait4(-pgid, &wstatus, unix.WALL|unix.WNOHANG, nil); err != unix.EINTR && err != nil {
   147  			break
   148  		}
   149  	}
   150  }
   151  
   152  func (r *Runner) println(v ...interface{}) {
   153  	if r.ShowDetails {
   154  		fmt.Fprintln(os.Stderr, v...)
   155  	}
   156  }