github.com/slackhq/nebula@v1.9.0/sshd/server.go (about)

     1  package sshd
     2  
import (
	"bytes"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/armon/go-radix"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
    14  
    15  type SSHServer struct {
    16  	config *ssh.ServerConfig
    17  	l      *logrus.Entry
    18  
    19  	certChecker *ssh.CertChecker
    20  
    21  	// Map of user -> authorized keys
    22  	trustedKeys map[string]map[string]bool
    23  	trustedCAs  []ssh.PublicKey
    24  
    25  	// List of available commands
    26  	helpCommand *Command
    27  	commands    *radix.Tree
    28  	listener    net.Listener
    29  
    30  	// Locks the conns/counter to avoid concurrent map access
    31  	connsLock sync.Mutex
    32  	conns     map[int]*session
    33  	counter   int
    34  }
    35  
    36  // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
    37  func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
    38  
    39  	s := &SSHServer{
    40  		trustedKeys: make(map[string]map[string]bool),
    41  		l:           l,
    42  		commands:    radix.New(),
    43  		conns:       make(map[int]*session),
    44  	}
    45  
    46  	cc := ssh.CertChecker{
    47  		IsUserAuthority: func(auth ssh.PublicKey) bool {
    48  			for _, ca := range s.trustedCAs {
    49  				if bytes.Equal(ca.Marshal(), auth.Marshal()) {
    50  					return true
    51  				}
    52  			}
    53  
    54  			return false
    55  		},
    56  		UserKeyFallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
    57  			pk := string(pubKey.Marshal())
    58  			fp := ssh.FingerprintSHA256(pubKey)
    59  
    60  			tk, ok := s.trustedKeys[c.User()]
    61  			if !ok {
    62  				return nil, fmt.Errorf("unknown user %s", c.User())
    63  			}
    64  
    65  			_, ok = tk[pk]
    66  			if !ok {
    67  				return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
    68  			}
    69  
    70  			return &ssh.Permissions{
    71  				// Record the public key used for authentication.
    72  				Extensions: map[string]string{
    73  					"fp":   fp,
    74  					"user": c.User(),
    75  				},
    76  			}, nil
    77  
    78  		},
    79  	}
    80  
    81  	s.config = &ssh.ServerConfig{
    82  		PublicKeyCallback: cc.Authenticate,
    83  		//TODO: AuthLogCallback: s.authAttempt,
    84  		//TODO: version string
    85  		ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
    86  	}
    87  
    88  	s.RegisterCommand(&Command{
    89  		Name:             "help",
    90  		ShortDescription: "prints available commands or help <command> for specific usage info",
    91  		Callback: func(a interface{}, args []string, w StringWriter) error {
    92  			return helpCallback(s.commands, args, w)
    93  		},
    94  	})
    95  
    96  	return s, nil
    97  }
    98  
    99  func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
   100  	private, err := ssh.ParsePrivateKey(hostPrivateKey)
   101  	if err != nil {
   102  		return fmt.Errorf("failed to parse private key: %s", err)
   103  	}
   104  
   105  	s.config.AddHostKey(private)
   106  	return nil
   107  }
   108  
   109  func (s *SSHServer) ClearTrustedCAs() {
   110  	s.trustedCAs = []ssh.PublicKey{}
   111  }
   112  
   113  func (s *SSHServer) ClearAuthorizedKeys() {
   114  	s.trustedKeys = make(map[string]map[string]bool)
   115  }
   116  
   117  // AddTrustedCA adds a trusted CA for user certificates
   118  func (s *SSHServer) AddTrustedCA(pubKey string) error {
   119  	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	s.trustedCAs = append(s.trustedCAs, pk)
   125  	s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
   126  	return nil
   127  }
   128  
   129  // AddAuthorizedKey adds an ssh public key for a user
   130  func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
   131  	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	tk, ok := s.trustedKeys[user]
   137  	if !ok {
   138  		tk = make(map[string]bool)
   139  		s.trustedKeys[user] = tk
   140  	}
   141  
   142  	tk[string(pk.Marshal())] = true
   143  	s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
   144  	return nil
   145  }
   146  
   147  // RegisterCommand adds a command that can be run by a user, by default only `help` is available
   148  func (s *SSHServer) RegisterCommand(c *Command) {
   149  	s.commands.Insert(c.Name, c)
   150  }
   151  
   152  // Run begins listening and accepting connections
   153  func (s *SSHServer) Run(addr string) error {
   154  	var err error
   155  	s.listener, err = net.Listen("tcp", addr)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	s.l.WithField("sshListener", addr).Info("SSH server is listening")
   161  
   162  	// Run loops until there is an error
   163  	s.run()
   164  	s.closeSessions()
   165  
   166  	s.l.Info("SSH server stopped listening")
   167  	// We don't return an error because run logs for us
   168  	return nil
   169  }
   170  
   171  func (s *SSHServer) run() {
   172  	for {
   173  		c, err := s.listener.Accept()
   174  		if err != nil {
   175  			if !errors.Is(err, net.ErrClosed) {
   176  				s.l.WithError(err).Warn("Error in listener, shutting down")
   177  			}
   178  			return
   179  		}
   180  
   181  		conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
   182  		fp := ""
   183  		if conn != nil {
   184  			fp = conn.Permissions.Extensions["fp"]
   185  		}
   186  
   187  		if err != nil {
   188  			l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
   189  			if conn != nil {
   190  				l = l.WithField("sshUser", conn.User())
   191  				conn.Close()
   192  			}
   193  			if fp != "" {
   194  				l = l.WithField("sshFingerprint", fp)
   195  			}
   196  			l.Warn("failed to handshake")
   197  			continue
   198  		}
   199  
   200  		l := s.l.WithField("sshUser", conn.User())
   201  		l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
   202  
   203  		session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
   204  		s.connsLock.Lock()
   205  		s.counter++
   206  		counter := s.counter
   207  		s.conns[counter] = session
   208  		s.connsLock.Unlock()
   209  
   210  		go ssh.DiscardRequests(reqs)
   211  		go func() {
   212  			<-session.exitChan
   213  			s.l.WithField("id", counter).Debug("closing conn")
   214  			s.connsLock.Lock()
   215  			delete(s.conns, counter)
   216  			s.connsLock.Unlock()
   217  		}()
   218  	}
   219  }
   220  
   221  func (s *SSHServer) Stop() {
   222  	// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
   223  	if s.listener != nil {
   224  		if err := s.listener.Close(); err != nil {
   225  			s.l.WithError(err).Warn("Failed to close the sshd listener")
   226  		}
   227  	}
   228  }
   229  
   230  func (s *SSHServer) closeSessions() {
   231  	s.connsLock.Lock()
   232  	for _, c := range s.conns {
   233  		c.Close()
   234  	}
   235  	s.connsLock.Unlock()
   236  }