github.com/NeowayLabs/nash@v0.2.2-0.20200127205349-a227041ffd50/internal/sh/rfork_linux.go (about)

     1  // +build linux
     2  
     3  // nash provides the execution engine
     4  package sh
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"os/exec"
    12  	"strconv"
    13  	"syscall"
    14  	"time"
    15  
    16  	"github.com/madlambda/nash/ast"
    17  )
    18  
    19  func getProcAttrs(flags uintptr) *syscall.SysProcAttr {
    20  	uid := os.Getuid()
    21  	gid := os.Getgid()
    22  
    23  	sysproc := &syscall.SysProcAttr{
    24  		Cloneflags: flags,
    25  	}
    26  
    27  	if (flags & syscall.CLONE_NEWUSER) == syscall.CLONE_NEWUSER {
    28  		sysproc.UidMappings = []syscall.SysProcIDMap{
    29  			{
    30  				ContainerID: 0,
    31  				HostID:      uid,
    32  				Size:        1,
    33  			},
    34  		}
    35  
    36  		sysproc.GidMappings = []syscall.SysProcIDMap{
    37  			{
    38  				ContainerID: 0,
    39  				HostID:      gid,
    40  				Size:        1,
    41  			},
    42  		}
    43  	}
    44  
    45  	return sysproc
    46  }
    47  
    48  func dialRc(sockpath string) (net.Conn, error) {
    49  	retries := 0
    50  
    51  retryRforkDial:
    52  	client, err := net.Dial("unix", sockpath)
    53  
    54  	if err != nil {
    55  		if retries < 3 {
    56  			retries++
    57  			time.Sleep(time.Duration(retries) * time.Second)
    58  			goto retryRforkDial
    59  		}
    60  	}
    61  
    62  	return client, err
    63  }
    64  
    65  // executeRfork executes the calling program again but passing
    66  // a new name for the process on os.Args[0] and passing an unix
    67  // socket file to communicate to.
    68  func (sh *Shell) executeRfork(rfork *ast.RforkNode) error {
    69  	var (
    70  		tr               *ast.Tree
    71  		i                int
    72  		nashClient       net.Conn
    73  		copyOut, copyErr bool
    74  	)
    75  
    76  	if sh.stdout != os.Stdout {
    77  		copyOut = true
    78  	}
    79  
    80  	if sh.stderr != os.Stderr {
    81  		copyErr = true
    82  	}
    83  
    84  	if sh.nashdPath == "" {
    85  		return fmt.Errorf("Nashd not set")
    86  	}
    87  
    88  	unixfile := "/tmp/nash." + randRunes(4) + ".sock"
    89  
    90  	cmd := exec.Cmd{
    91  		Path: sh.nashdPath,
    92  		Args: append([]string{"-nashd-"}, "-noinit", "-addr", unixfile),
    93  		Env:  buildenv(sh.Environ()),
    94  	}
    95  
    96  	arg := rfork.Arg()
    97  
    98  	forkFlags, err := getflags(arg.Value())
    99  
   100  	if err != nil {
   101  		return err
   102  	}
   103  
   104  	cmd.SysProcAttr = getProcAttrs(forkFlags)
   105  
   106  	stdoutDone := make(chan bool)
   107  	stderrDone := make(chan bool)
   108  
   109  	var (
   110  		stdout, stderr io.ReadCloser
   111  	)
   112  
   113  	if copyOut {
   114  		stdout, err = cmd.StdoutPipe()
   115  
   116  		if err != nil {
   117  			return err
   118  		}
   119  	} else {
   120  		cmd.Stdout = sh.stdout
   121  		close(stdoutDone)
   122  	}
   123  
   124  	if copyErr {
   125  		stderr, err = cmd.StderrPipe()
   126  
   127  		if err != nil {
   128  			return err
   129  		}
   130  	} else {
   131  		cmd.Stderr = sh.stderr
   132  		close(stderrDone)
   133  	}
   134  
   135  	cmd.Stdin = sh.stdin
   136  
   137  	err = cmd.Start()
   138  
   139  	if err != nil {
   140  		return err
   141  	}
   142  
   143  	if copyOut {
   144  		go func() {
   145  			defer close(stdoutDone)
   146  
   147  			io.Copy(sh.stdout, stdout)
   148  		}()
   149  	}
   150  
   151  	if copyErr {
   152  		go func() {
   153  			defer close(stderrDone)
   154  
   155  			io.Copy(sh.stderr, stderr)
   156  		}()
   157  	}
   158  
   159  	nashClient, err = dialRc(unixfile)
   160  
   161  	defer nashClient.Close()
   162  
   163  	tr = rfork.Tree()
   164  
   165  	if tr == nil || tr.Root == nil {
   166  		return fmt.Errorf("Rfork with no sub block")
   167  	}
   168  
   169  	for i = 0; i < len(tr.Root.Nodes); i++ {
   170  		var (
   171  			n, status int
   172  		)
   173  
   174  		node := tr.Root.Nodes[i]
   175  		data := []byte(node.String() + "\n")
   176  
   177  		n, err = nashClient.Write(data)
   178  
   179  		if err != nil || n != len(data) {
   180  			return fmt.Errorf("RPC call failed: Err: %v, bytes written: %d", err, n)
   181  		}
   182  
   183  		// read response
   184  
   185  		var response [1024]byte
   186  		n, err = nashClient.Read(response[:])
   187  
   188  		if err != nil {
   189  			break
   190  		}
   191  
   192  		status, err = strconv.Atoi(string(response[0:n]))
   193  
   194  		if err != nil {
   195  			err = fmt.Errorf("Invalid status: %s", string(response[0:n]))
   196  			break
   197  		}
   198  
   199  		if status != 0 {
   200  			err = fmt.Errorf("nash: Exited with status %d", status)
   201  			break
   202  		}
   203  	}
   204  
   205  	// we're done with rfork daemon
   206  	nashClient.Write([]byte("quit"))
   207  
   208  	<-stdoutDone
   209  	<-stderrDone
   210  
   211  	err2 := cmd.Wait()
   212  
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	if err2 != nil {
   218  		return err2
   219  	}
   220  
   221  	return nil
   222  }
   223  
   224  func getflags(flags string) (uintptr, error) {
   225  	var (
   226  		lflags uintptr
   227  	)
   228  
   229  	for i := 0; i < len(flags); i++ {
   230  		switch flags[i] {
   231  		case 'c':
   232  			lflags |= (syscall.CLONE_NEWUSER |
   233  				syscall.CLONE_NEWPID |
   234  				syscall.CLONE_NEWNET |
   235  				syscall.CLONE_NEWNS |
   236  				syscall.CLONE_NEWUTS |
   237  				syscall.CLONE_NEWIPC)
   238  		case 'u':
   239  			lflags |= syscall.CLONE_NEWUSER
   240  		case 'p':
   241  			lflags |= syscall.CLONE_NEWPID
   242  		case 'n':
   243  			lflags |= syscall.CLONE_NEWNET
   244  		case 'm':
   245  			lflags |= syscall.CLONE_NEWNS
   246  		case 's':
   247  			lflags |= syscall.CLONE_NEWUTS
   248  		case 'i':
   249  			lflags |= syscall.CLONE_NEWIPC
   250  		default:
   251  			return 0, fmt.Errorf("Wrong rfork flag: %c", flags[i])
   252  		}
   253  	}
   254  
   255  	if lflags == 0 {
   256  		return 0, fmt.Errorf("Rfork requires some flag")
   257  	}
   258  
   259  	return lflags, nil
   260  }