github.com/madlambda/nash@v0.2.2-0.20230113003044-f2284521680b/internal/sh/rfork_linux.go (about) 1 // +build linux 2 3 // nash provides the execution engine 4 package sh 5 6 import ( 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "os/exec" 12 "strconv" 13 "syscall" 14 "time" 15 16 "github.com/madlambda/nash/ast" 17 ) 18 19 func getProcAttrs(flags uintptr) *syscall.SysProcAttr { 20 uid := os.Getuid() 21 gid := os.Getgid() 22 23 sysproc := &syscall.SysProcAttr{ 24 Cloneflags: flags, 25 } 26 27 if (flags & syscall.CLONE_NEWUSER) == syscall.CLONE_NEWUSER { 28 sysproc.UidMappings = []syscall.SysProcIDMap{ 29 { 30 ContainerID: 0, 31 HostID: uid, 32 Size: 1, 33 }, 34 } 35 36 sysproc.GidMappings = []syscall.SysProcIDMap{ 37 { 38 ContainerID: 0, 39 HostID: gid, 40 Size: 1, 41 }, 42 } 43 } 44 45 return sysproc 46 } 47 48 func dialRc(sockpath string) (net.Conn, error) { 49 retries := 0 50 51 retryRforkDial: 52 client, err := net.Dial("unix", sockpath) 53 54 if err != nil { 55 if retries < 3 { 56 retries++ 57 time.Sleep(time.Duration(retries) * time.Second) 58 goto retryRforkDial 59 } 60 } 61 62 return client, err 63 } 64 65 // executeRfork executes the calling program again but passing 66 // a new name for the process on os.Args[0] and passing an unix 67 // socket file to communicate to. 68 func (sh *Shell) executeRfork(rfork *ast.RforkNode) error { 69 var ( 70 tr *ast.Tree 71 i int 72 nashClient net.Conn 73 copyOut, copyErr bool 74 ) 75 76 if sh.stdout != os.Stdout { 77 copyOut = true 78 } 79 80 if sh.stderr != os.Stderr { 81 copyErr = true 82 } 83 84 if sh.nashdPath == "" { 85 return fmt.Errorf("Nashd not set") 86 } 87 88 unixfile := "/tmp/nash." + randRunes(4) + ".sock" 89 90 cmd := exec.Cmd{ 91 Path: sh.nashdPath, 92 Args: append([]string{"-nashd-"}, "-noinit", "-addr", unixfile), 93 Env: buildenv(sh.Environ()), 94 } 95 96 arg := rfork.Arg() 97 98 forkFlags, err := getflags(arg.Value()) 99 100 if err != nil { 101 return err 102 } 103 104 cmd.SysProcAttr = getProcAttrs(forkFlags) 105 106 stdoutDone := make(chan bool) 107 stderrDone := make(chan bool) 108 109 var ( 110 stdout, stderr io.ReadCloser 111 ) 112 113 if copyOut { 114 stdout, err = cmd.StdoutPipe() 115 116 if err != nil { 117 return err 118 } 119 } else { 120 cmd.Stdout = sh.stdout 121 close(stdoutDone) 122 } 123 124 if copyErr { 125 stderr, err = cmd.StderrPipe() 126 127 if err != nil { 128 return err 129 } 130 } else { 131 cmd.Stderr = sh.stderr 132 close(stderrDone) 133 } 134 135 cmd.Stdin = sh.stdin 136 137 err = cmd.Start() 138 139 if err != nil { 140 return err 141 } 142 143 if copyOut { 144 go func() { 145 defer close(stdoutDone) 146 147 io.Copy(sh.stdout, stdout) 148 }() 149 } 150 151 if copyErr { 152 go func() { 153 defer close(stderrDone) 154 155 io.Copy(sh.stderr, stderr) 156 }() 157 } 158 159 nashClient, err = dialRc(unixfile) 160 161 defer nashClient.Close() 162 163 tr = rfork.Tree() 164 165 if tr == nil || tr.Root == nil { 166 return fmt.Errorf("Rfork with no sub block") 167 } 168 169 for i = 0; i < len(tr.Root.Nodes); i++ { 170 var ( 171 n, status int 172 ) 173 174 node := tr.Root.Nodes[i] 175 data := []byte(node.String() + "\n") 176 177 n, err = nashClient.Write(data) 178 179 if err != nil || n != len(data) { 180 return fmt.Errorf("RPC call failed: Err: %v, bytes written: %d", err, n) 181 } 182 183 // read response 184 185 var response [1024]byte 186 n, err = nashClient.Read(response[:]) 187 188 if err != nil { 189 break 190 } 191 192 status, err = strconv.Atoi(string(response[0:n])) 193 194 if err != nil { 195 err = fmt.Errorf("Invalid status: %s", string(response[0:n])) 196 break 197 } 198 199 if status != 0 { 200 err = fmt.Errorf("nash: Exited with status %d", status) 201 break 202 } 203 } 204 205 // we're done with rfork daemon 206 nashClient.Write([]byte("quit")) 207 208 <-stdoutDone 209 <-stderrDone 210 211 err2 := cmd.Wait() 212 213 if err != nil { 214 return err 215 } 216 217 if err2 != nil { 218 return err2 219 } 220 221 return nil 222 } 223 224 func getflags(flags string) (uintptr, error) { 225 var ( 226 lflags uintptr 227 ) 228 229 for i := 0; i < len(flags); i++ { 230 switch flags[i] { 231 case 'c': 232 lflags |= (syscall.CLONE_NEWUSER | 233 syscall.CLONE_NEWPID | 234 syscall.CLONE_NEWNET | 235 syscall.CLONE_NEWNS | 236 syscall.CLONE_NEWUTS | 237 syscall.CLONE_NEWIPC) 238 case 'u': 239 lflags |= syscall.CLONE_NEWUSER 240 case 'p': 241 lflags |= syscall.CLONE_NEWPID 242 case 'n': 243 lflags |= syscall.CLONE_NEWNET 244 case 'm': 245 lflags |= syscall.CLONE_NEWNS 246 case 's': 247 lflags |= syscall.CLONE_NEWUTS 248 case 'i': 249 lflags |= syscall.CLONE_NEWIPC 250 default: 251 return 0, fmt.Errorf("Wrong rfork flag: %c", flags[i]) 252 } 253 } 254 255 if lflags == 0 { 256 return 0, fmt.Errorf("Rfork requires some flag") 257 } 258 259 return lflags, nil 260 }