github.com/jlowellwofford/u-root@v1.0.0/cmds/sshd/sshd.go (about)

     1  // Copyright 2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  	"syscall"
    17  
    18  	"github.com/u-root/u-root/pkg/pty"
    19  	"golang.org/x/crypto/ssh"
    20  )
    21  
    22  // The ssh package does not define these things so we will
    23  type (
    24  	ptyReq struct {
    25  		TERM   string //TERM environment variable value (e.g., vt100)
    26  		Col    uint32
    27  		Row    uint32
    28  		Xpixel uint32
    29  		Ypixel uint32
    30  		Modes  string //encoded terminal modes
    31  	}
    32  	execReq struct {
    33  		Command string
    34  	}
    35  	exitStatusReq struct {
    36  		ExitStatus uint32
    37  	}
    38  )
    39  
    40  var (
    41  	shells  = [...]string{"bash", "zsh", "rush"}
    42  	shell   = "/bin/sh"
    43  	debug   = flag.Bool("d", false, "Enable debug prints")
    44  	keys    = flag.String("keys", "authorized_keys", "Path to the authorized_keys file")
    45  	privkey = flag.String("privatekey", "id_rsa", "Path of private key")
    46  	ip      = flag.String("ip", "0.0.0.0", "ip address to listen on")
    47  	port    = flag.String("port", "2022", "port to listen on")
    48  	dprintf = func(string, ...interface{}) {}
    49  )
    50  
    51  // start a command
    52  // TODO: use /etc/passwd, but the Go support for that is incomplete
    53  func runCommand(c ssh.Channel, p *pty.Pty, cmd string, args ...string) error {
    54  	var ps *os.ProcessState
    55  	defer c.Close()
    56  
    57  	if p != nil {
    58  		log.Printf("Executing PTY command %s %v", cmd, args)
    59  		p.Command(cmd, args...)
    60  		if err := p.C.Start(); err != nil {
    61  			dprintf("Failed to execute: %v", err)
    62  			return err
    63  		}
    64  		defer p.C.Wait()
    65  		go io.Copy(p.Ptm, c)
    66  		go io.Copy(c, p.Ptm)
    67  		ps, _ = p.C.Process.Wait()
    68  	} else {
    69  		e := exec.Command(cmd, args...)
    70  		e.Stdin, e.Stdout, e.Stderr = c, c, c
    71  		log.Printf("Executing non-PTY command %s %v", cmd, args)
    72  		if err := e.Start(); err != nil {
    73  			dprintf("Failed to execute: %v", err)
    74  			return err
    75  		}
    76  		ps, _ = e.Process.Wait()
    77  	}
    78  
    79  	var ws syscall.WaitStatus
    80  	ws = ps.Sys().(syscall.WaitStatus)
    81  	if ws.Signaled() {
    82  		// TOOD(bluecmd): If somebody wants we can send exit-signal to return
    83  		// information about signal termination, but leave it until somebody needs
    84  		// it.
    85  	}
    86  	if ws.Exited() {
    87  		code := uint32(ws.ExitStatus())
    88  		dprintf("Exit status %v", code)
    89  		c.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq{code}))
    90  	}
    91  	return nil
    92  }
    93  
    94  func newPTY(b []byte) (*pty.Pty, error) {
    95  	ptyReq := &ptyReq{}
    96  	err := ssh.Unmarshal(b, ptyReq)
    97  	dprintf("newPTY: %q", ptyReq)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	p, err := pty.New()
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	ws, err := p.TTY.GetWinSize()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	ws.Row = uint16(ptyReq.Row)
   110  	ws.Ypixel = uint16(ptyReq.Ypixel)
   111  	ws.Col = uint16(ptyReq.Col)
   112  	ws.Xpixel = uint16(ptyReq.Xpixel)
   113  	dprintf("newPTY: Set winsizes to %v", ws)
   114  	if err := p.TTY.SetWinSize(ws); err != nil {
   115  		return nil, err
   116  	}
   117  	dprintf("newPTY: set TERM to %q", ptyReq.TERM)
   118  	if err := os.Setenv("TERM", ptyReq.TERM); err != nil {
   119  		return nil, err
   120  	}
   121  	return p, nil
   122  }
   123  
   124  func init() {
   125  	for _, s := range shells {
   126  		if _, err := exec.LookPath(s); err == nil {
   127  			shell = s
   128  		}
   129  	}
   130  }
   131  
   132  func session(chans <-chan ssh.NewChannel) {
   133  	var p *pty.Pty
   134  	// Service the incoming Channel channel.
   135  	for newChannel := range chans {
   136  		// Channels have a type, depending on the application level
   137  		// protocol intended. In the case of a shell, the type is
   138  		// "session" and ServerShell may be used to present a simple
   139  		// terminal interface.
   140  		if newChannel.ChannelType() != "session" {
   141  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
   142  			continue
   143  		}
   144  		channel, requests, err := newChannel.Accept()
   145  		if err != nil {
   146  			log.Printf("Could not accept channel: %v", err)
   147  			continue
   148  		}
   149  
   150  		// Sessions have out-of-band requests such as "shell",
   151  		// "pty-req" and "env".  Here we handle only the
   152  		// "shell" request.
   153  		go func(in <-chan *ssh.Request) {
   154  			for req := range in {
   155  				dprintf("Request %v", req.Type)
   156  				switch req.Type {
   157  				case "shell":
   158  					err := runCommand(channel, p, shell)
   159  					req.Reply(true, []byte(fmt.Sprintf("%v", err)))
   160  				case "exec":
   161  					e := &execReq{}
   162  					if err := ssh.Unmarshal(req.Payload, e); err != nil {
   163  						log.Printf("sshd: %v", err)
   164  						break
   165  					}
   166  					// Execute command using user's shell. This is what OpenSSH does
   167  					// so it's the least surprising to the user.
   168  					err := runCommand(channel, p, shell, "-c", e.Command)
   169  					req.Reply(true, []byte(fmt.Sprintf("%v", err)))
   170  				case "pty-req":
   171  					p, err = newPTY(req.Payload)
   172  					req.Reply(err == nil, nil)
   173  				default:
   174  					log.Printf("Not handling req %v %q", req, string(req.Payload))
   175  					req.Reply(false, nil)
   176  				}
   177  			}
   178  		}(requests)
   179  
   180  	}
   181  }
   182  
   183  func main() {
   184  	flag.Parse()
   185  	if *debug {
   186  		dprintf = log.Printf
   187  	}
   188  	// Public key authentication is done by comparing
   189  	// the public key of a received connection
   190  	// with the entries in the authorized_keys file.
   191  	authorizedKeysBytes, err := ioutil.ReadFile(*keys)
   192  	if err != nil {
   193  		log.Fatal(err)
   194  	}
   195  
   196  	authorizedKeysMap := map[string]bool{}
   197  	for len(authorizedKeysBytes) > 0 {
   198  		pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
   199  		if err != nil {
   200  			log.Fatal(err)
   201  		}
   202  
   203  		authorizedKeysMap[string(pubKey.Marshal())] = true
   204  		authorizedKeysBytes = rest
   205  	}
   206  
   207  	// An SSH server is represented by a ServerConfig, which holds
   208  	// certificate details and handles authentication of ServerConns.
   209  	config := &ssh.ServerConfig{
   210  		// Remove to disable public key auth.
   211  		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   212  			if authorizedKeysMap[string(pubKey.Marshal())] {
   213  				return &ssh.Permissions{
   214  					// Record the public key used for authentication.
   215  					Extensions: map[string]string{
   216  						"pubkey-fp": ssh.FingerprintSHA256(pubKey),
   217  					},
   218  				}, nil
   219  			}
   220  			return nil, fmt.Errorf("unknown public key for %q", c.User())
   221  		},
   222  	}
   223  
   224  	privateBytes, err := ioutil.ReadFile(*privkey)
   225  	if err != nil {
   226  		log.Fatal(err)
   227  	}
   228  
   229  	private, err := ssh.ParsePrivateKey(privateBytes)
   230  	if err != nil {
   231  		log.Fatal(err)
   232  	}
   233  
   234  	config.AddHostKey(private)
   235  
   236  	// Once a ServerConfig has been configured, connections can be
   237  	// accepted.
   238  	listener, err := net.Listen("tcp", *ip+":"+*port)
   239  	if err != nil {
   240  		log.Fatal(err)
   241  	}
   242  	for {
   243  		nConn, err := listener.Accept()
   244  		if err != nil {
   245  			log.Printf("failed to accept incoming connection: %s", err)
   246  			continue
   247  		}
   248  
   249  		// Before use, a handshake must be performed on the incoming
   250  		// net.Conn.
   251  		conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
   252  		if err != nil {
   253  			log.Printf("failed to handshake: %v", err)
   254  			continue
   255  		}
   256  		log.Printf("%v logged in with key %s", conn.RemoteAddr(), conn.Permissions.Extensions["pubkey-fp"])
   257  
   258  		// The incoming Request channel must be serviced.
   259  		go ssh.DiscardRequests(reqs)
   260  
   261  		go session(chans)
   262  	}
   263  }