github.com/slackhq/nebula@v1.9.0/sshd/session.go (about)

     1  package sshd
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"strings"
     7  
     8  	"github.com/anmitsu/go-shlex"
     9  	"github.com/armon/go-radix"
    10  	"github.com/sirupsen/logrus"
    11  	"golang.org/x/crypto/ssh"
    12  	"golang.org/x/crypto/ssh/terminal"
    13  )
    14  
// session holds the state of one SSH connection served by this package.
type session struct {
	// l is the logger used for everything this session reports.
	l *logrus.Entry
	// c is the underlying SSH server connection; Close closes it.
	c *ssh.ServerConn
	// term is the interactive terminal. It is nil until the first "shell"
	// request is handled and a terminal is created for that channel.
	term *terminal.Terminal
	// commands is the command tree used for dispatch and tab completion.
	commands *radix.Tree
	// exitChan receives true every time Close is called. It is created
	// unbuffered, so each send blocks until a receiver takes the value.
	exitChan chan bool
}
    22  
    23  func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
    24  	s := &session{
    25  		commands: radix.NewFromMap(commands.ToMap()),
    26  		l:        l,
    27  		c:        conn,
    28  		exitChan: make(chan bool),
    29  	}
    30  
    31  	s.commands.Insert("logout", &Command{
    32  		Name:             "logout",
    33  		ShortDescription: "Ends the current session",
    34  		Callback: func(a interface{}, args []string, w StringWriter) error {
    35  			s.Close()
    36  			return nil
    37  		},
    38  	})
    39  
    40  	go s.handleChannels(chans)
    41  	return s
    42  }
    43  
    44  func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
    45  	for newChannel := range chans {
    46  		if newChannel.ChannelType() != "session" {
    47  			s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
    48  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
    49  			continue
    50  		}
    51  
    52  		channel, requests, err := newChannel.Accept()
    53  		if err != nil {
    54  			s.l.WithError(err).Warn("could not accept channel")
    55  			continue
    56  		}
    57  
    58  		go s.handleRequests(requests, channel)
    59  	}
    60  }
    61  
    62  func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
    63  	for req := range in {
    64  		var err error
    65  		//TODO: maybe support window sizing?
    66  		switch req.Type {
    67  		case "shell":
    68  			if s.term == nil {
    69  				s.term = s.createTerm(channel)
    70  				err = req.Reply(true, nil)
    71  			} else {
    72  				err = req.Reply(false, nil)
    73  			}
    74  
    75  		case "pty-req":
    76  			err = req.Reply(true, nil)
    77  
    78  		case "window-change":
    79  			err = req.Reply(true, nil)
    80  
    81  		case "exec":
    82  			var payload = struct{ Value string }{}
    83  			cErr := ssh.Unmarshal(req.Payload, &payload)
    84  			if cErr != nil {
    85  				req.Reply(false, nil)
    86  				return
    87  			}
    88  
    89  			req.Reply(true, nil)
    90  			s.dispatchCommand(payload.Value, &stringWriter{channel})
    91  
    92  			//TODO: Fix error handling and report the proper status back
    93  			status := struct{ Status uint32 }{uint32(0)}
    94  			//TODO: I think this is how we shut down a shell as well?
    95  			channel.SendRequest("exit-status", false, ssh.Marshal(status))
    96  			channel.Close()
    97  			return
    98  
    99  		default:
   100  			s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
   101  			err = req.Reply(false, nil)
   102  		}
   103  
   104  		if err != nil {
   105  			s.l.WithError(err).Info("Error handling ssh session requests")
   106  			s.Close()
   107  			return
   108  		}
   109  	}
   110  }
   111  
   112  func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
   113  	//TODO: PS1 with nebula cert name
   114  	term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
   115  	term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
   116  		// key 9 is tab
   117  		if key == 9 {
   118  			cmds := matchCommand(s.commands, line)
   119  			if len(cmds) == 1 {
   120  				return cmds[0] + " ", len(cmds[0]) + 1, true
   121  			}
   122  
   123  			sort.Strings(cmds)
   124  			term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
   125  		}
   126  
   127  		return "", 0, false
   128  	}
   129  
   130  	go s.handleInput(channel)
   131  	return term
   132  }
   133  
   134  func (s *session) handleInput(channel ssh.Channel) {
   135  	defer s.Close()
   136  	w := &stringWriter{w: s.term}
   137  	for {
   138  		line, err := s.term.ReadLine()
   139  		if err != nil {
   140  			//TODO: log
   141  			break
   142  		}
   143  
   144  		s.dispatchCommand(line, w)
   145  	}
   146  }
   147  
   148  func (s *session) dispatchCommand(line string, w StringWriter) {
   149  	args, err := shlex.Split(line, true)
   150  	if err != nil {
   151  		//todo: LOG IT
   152  		return
   153  	}
   154  
   155  	if len(args) == 0 {
   156  		dumpCommands(s.commands, w)
   157  		return
   158  	}
   159  
   160  	c, err := lookupCommand(s.commands, args[0])
   161  	if err != nil {
   162  		//TODO: handle the error
   163  		return
   164  	}
   165  
   166  	if c == nil {
   167  		err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
   168  		//TODO: log error
   169  		_ = err
   170  
   171  		dumpCommands(s.commands, w)
   172  		return
   173  	}
   174  
   175  	if checkHelpArgs(args) {
   176  		s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
   177  		return
   178  	}
   179  
   180  	err = execCommand(c, args[1:], w)
   181  	if err != nil {
   182  		//TODO: log the error
   183  	}
   184  	return
   185  }
   186  
   187  func (s *session) Close() {
   188  	s.c.Close()
   189  	s.exitChan <- true
   190  }