github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/cmds/core/sshd/sshd.go (about)

     1  // Copyright 2018-2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  
    17  	"github.com/u-root/u-root/pkg/pty"
    18  	"golang.org/x/crypto/ssh"
    19  )
    20  
    21  // The ssh package does not define these things so we will
    22  type (
    23  	ptyReq struct {
    24  		TERM   string //TERM environment variable value (e.g., vt100)
    25  		Col    uint32
    26  		Row    uint32
    27  		Xpixel uint32
    28  		Ypixel uint32
    29  		Modes  string //encoded terminal modes
    30  	}
    31  	execReq struct {
    32  		Command string
    33  	}
    34  	exitStatusReq struct {
    35  		ExitStatus uint32
    36  	}
    37  )
    38  
    39  var (
    40  	debug   = flag.Bool("d", false, "Enable debug prints")
    41  	keys    = flag.String("keys", "authorized_keys", "Path to the authorized_keys file")
    42  	privkey = flag.String("privatekey", "id_rsa", "Path of private key")
    43  	ip      = flag.String("ip", "0.0.0.0", "ip address to listen on")
    44  	port    = flag.String("port", "2022", "port to listen on")
    45  	dprintf = func(string, ...interface{}) {}
    46  )
    47  
    48  // start a command
    49  // TODO: use /etc/passwd, but the Go support for that is incomplete
    50  func runCommand(c ssh.Channel, p *pty.Pty, cmd string, args ...string) error {
    51  	var ps *os.ProcessState
    52  	defer c.Close()
    53  
    54  	if p != nil {
    55  		log.Printf("Executing PTY command %s %v", cmd, args)
    56  		p.Command(cmd, args...)
    57  		if err := p.C.Start(); err != nil {
    58  			dprintf("Failed to execute: %v", err)
    59  			return err
    60  		}
    61  		defer p.C.Wait()
    62  		go io.Copy(p.Ptm, c)
    63  		go io.Copy(c, p.Ptm)
    64  		ps, _ = p.C.Process.Wait()
    65  	} else {
    66  		e := exec.Command(cmd, args...)
    67  		e.Stdin, e.Stdout, e.Stderr = c, c, c
    68  		log.Printf("Executing non-PTY command %s %v", cmd, args)
    69  		if err := e.Start(); err != nil {
    70  			dprintf("Failed to execute: %v", err)
    71  			return err
    72  		}
    73  		ps, _ = e.Process.Wait()
    74  	}
    75  
    76  	// TODO(bluecmd): If somebody wants we can send exit-signal to return
    77  	// information about signal termination, but leave it until somebody needs
    78  	// it.
    79  	// if ws.Signaled() {
    80  	// }
    81  	if ps.Exited() {
    82  		code := uint32(ps.ExitCode())
    83  		dprintf("Exit status %v", code)
    84  		c.SendRequest("exit-status", false, ssh.Marshal(exitStatusReq{code}))
    85  	}
    86  	return nil
    87  }
    88  
    89  func newPTY(b []byte) (*pty.Pty, error) {
    90  	ptyReq := &ptyReq{}
    91  	err := ssh.Unmarshal(b, ptyReq)
    92  	dprintf("newPTY: %q", ptyReq)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	p, err := pty.New()
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	ws, err := p.TTY.GetWinSize()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	ws.Row = uint16(ptyReq.Row)
   105  	ws.Ypixel = uint16(ptyReq.Ypixel)
   106  	ws.Col = uint16(ptyReq.Col)
   107  	ws.Xpixel = uint16(ptyReq.Xpixel)
   108  	dprintf("newPTY: Set winsizes to %v", ws)
   109  	if err := p.TTY.SetWinSize(ws); err != nil {
   110  		return nil, err
   111  	}
   112  	dprintf("newPTY: set TERM to %q", ptyReq.TERM)
   113  	if err := os.Setenv("TERM", ptyReq.TERM); err != nil {
   114  		return nil, err
   115  	}
   116  	return p, nil
   117  }
   118  
   119  func init() {
   120  	for _, s := range shells {
   121  		if _, err := exec.LookPath(s); err == nil {
   122  			shell = s
   123  		}
   124  	}
   125  }
   126  
   127  func session(chans <-chan ssh.NewChannel) {
   128  	var p *pty.Pty
   129  	// Service the incoming Channel channel.
   130  	for newChannel := range chans {
   131  		// Channels have a type, depending on the application level
   132  		// protocol intended. In the case of a shell, the type is
   133  		// "session" and ServerShell may be used to present a simple
   134  		// terminal interface.
   135  		if newChannel.ChannelType() != "session" {
   136  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
   137  			continue
   138  		}
   139  		channel, requests, err := newChannel.Accept()
   140  		if err != nil {
   141  			log.Printf("Could not accept channel: %v", err)
   142  			continue
   143  		}
   144  
   145  		// Sessions have out-of-band requests such as "shell",
   146  		// "pty-req" and "env".  Here we handle only the
   147  		// "shell" request.
   148  		go func(in <-chan *ssh.Request) {
   149  			for req := range in {
   150  				dprintf("Request %v", req.Type)
   151  				switch req.Type {
   152  				case "shell":
   153  					err := runCommand(channel, p, shell)
   154  					req.Reply(true, []byte(fmt.Sprintf("%v", err)))
   155  				case "exec":
   156  					e := &execReq{}
   157  					if err := ssh.Unmarshal(req.Payload, e); err != nil {
   158  						log.Printf("sshd: %v", err)
   159  						break
   160  					}
   161  					// Execute command using user's shell. This is what OpenSSH does
   162  					// so it's the least surprising to the user.
   163  					err := runCommand(channel, p, shell, "-c", e.Command)
   164  					req.Reply(true, []byte(fmt.Sprintf("%v", err)))
   165  				case "pty-req":
   166  					p, err = newPTY(req.Payload)
   167  					req.Reply(err == nil, nil)
   168  				default:
   169  					log.Printf("Not handling req %v %q", req, string(req.Payload))
   170  					req.Reply(false, nil)
   171  				}
   172  			}
   173  		}(requests)
   174  
   175  	}
   176  }
   177  
   178  func main() {
   179  	flag.Parse()
   180  	if *debug {
   181  		dprintf = log.Printf
   182  	}
   183  	// Public key authentication is done by comparing
   184  	// the public key of a received connection
   185  	// with the entries in the authorized_keys file.
   186  	authorizedKeysBytes, err := ioutil.ReadFile(*keys)
   187  	if err != nil {
   188  		log.Fatal(err)
   189  	}
   190  
   191  	authorizedKeysMap := map[string]bool{}
   192  	for len(authorizedKeysBytes) > 0 {
   193  		pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
   194  		if err != nil {
   195  			log.Fatal(err)
   196  		}
   197  
   198  		authorizedKeysMap[string(pubKey.Marshal())] = true
   199  		authorizedKeysBytes = rest
   200  	}
   201  
   202  	// An SSH server is represented by a ServerConfig, which holds
   203  	// certificate details and handles authentication of ServerConns.
   204  	config := &ssh.ServerConfig{
   205  		// Remove to disable public key auth.
   206  		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   207  			if authorizedKeysMap[string(pubKey.Marshal())] {
   208  				return &ssh.Permissions{
   209  					// Record the public key used for authentication.
   210  					Extensions: map[string]string{
   211  						"pubkey-fp": ssh.FingerprintSHA256(pubKey),
   212  					},
   213  				}, nil
   214  			}
   215  			return nil, fmt.Errorf("unknown public key for %q", c.User())
   216  		},
   217  	}
   218  
   219  	privateBytes, err := ioutil.ReadFile(*privkey)
   220  	if err != nil {
   221  		log.Fatal(err)
   222  	}
   223  
   224  	private, err := ssh.ParsePrivateKey(privateBytes)
   225  	if err != nil {
   226  		log.Fatal(err)
   227  	}
   228  
   229  	config.AddHostKey(private)
   230  
   231  	// Once a ServerConfig has been configured, connections can be
   232  	// accepted.
   233  	listener, err := net.Listen("tcp", net.JoinHostPort(*ip, *port))
   234  	if err != nil {
   235  		log.Fatal(err)
   236  	}
   237  	for {
   238  		nConn, err := listener.Accept()
   239  		if err != nil {
   240  			log.Printf("failed to accept incoming connection: %s", err)
   241  			continue
   242  		}
   243  
   244  		// Before use, a handshake must be performed on the incoming
   245  		// net.Conn.
   246  		conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
   247  		if err != nil {
   248  			log.Printf("failed to handshake: %v", err)
   249  			continue
   250  		}
   251  		log.Printf("%v logged in with key %s", conn.RemoteAddr(), conn.Permissions.Extensions["pubkey-fp"])
   252  
   253  		// The incoming Request channel must be serviced.
   254  		go ssh.DiscardRequests(reqs)
   255  
   256  		go session(chans)
   257  	}
   258  }