github.com/slackhq/nebula@v1.9.0/sshd/server.go (about) 1 package sshd 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "net" 8 "sync" 9 10 "github.com/armon/go-radix" 11 "github.com/sirupsen/logrus" 12 "golang.org/x/crypto/ssh" 13 ) 14 15 type SSHServer struct { 16 config *ssh.ServerConfig 17 l *logrus.Entry 18 19 certChecker *ssh.CertChecker 20 21 // Map of user -> authorized keys 22 trustedKeys map[string]map[string]bool 23 trustedCAs []ssh.PublicKey 24 25 // List of available commands 26 helpCommand *Command 27 commands *radix.Tree 28 listener net.Listener 29 30 // Locks the conns/counter to avoid concurrent map access 31 connsLock sync.Mutex 32 conns map[int]*session 33 counter int 34 } 35 36 // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen 37 func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { 38 39 s := &SSHServer{ 40 trustedKeys: make(map[string]map[string]bool), 41 l: l, 42 commands: radix.New(), 43 conns: make(map[int]*session), 44 } 45 46 cc := ssh.CertChecker{ 47 IsUserAuthority: func(auth ssh.PublicKey) bool { 48 for _, ca := range s.trustedCAs { 49 if bytes.Equal(ca.Marshal(), auth.Marshal()) { 50 return true 51 } 52 } 53 54 return false 55 }, 56 UserKeyFallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { 57 pk := string(pubKey.Marshal()) 58 fp := ssh.FingerprintSHA256(pubKey) 59 60 tk, ok := s.trustedKeys[c.User()] 61 if !ok { 62 return nil, fmt.Errorf("unknown user %s", c.User()) 63 } 64 65 _, ok = tk[pk] 66 if !ok { 67 return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp) 68 } 69 70 return &ssh.Permissions{ 71 // Record the public key used for authentication. 72 Extensions: map[string]string{ 73 "fp": fp, 74 "user": c.User(), 75 }, 76 }, nil 77 78 }, 79 } 80 81 s.config = &ssh.ServerConfig{ 82 PublicKeyCallback: cc.Authenticate, 83 //TODO: AuthLogCallback: s.authAttempt, 84 //TODO: version string 85 ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), 86 } 87 88 s.RegisterCommand(&Command{ 89 Name: "help", 90 ShortDescription: "prints available commands or help <command> for specific usage info", 91 Callback: func(a interface{}, args []string, w StringWriter) error { 92 return helpCallback(s.commands, args, w) 93 }, 94 }) 95 96 return s, nil 97 } 98 99 func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error { 100 private, err := ssh.ParsePrivateKey(hostPrivateKey) 101 if err != nil { 102 return fmt.Errorf("failed to parse private key: %s", err) 103 } 104 105 s.config.AddHostKey(private) 106 return nil 107 } 108 109 func (s *SSHServer) ClearTrustedCAs() { 110 s.trustedCAs = []ssh.PublicKey{} 111 } 112 113 func (s *SSHServer) ClearAuthorizedKeys() { 114 s.trustedKeys = make(map[string]map[string]bool) 115 } 116 117 // AddTrustedCA adds a trusted CA for user certificates 118 func (s *SSHServer) AddTrustedCA(pubKey string) error { 119 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) 120 if err != nil { 121 return err 122 } 123 124 s.trustedCAs = append(s.trustedCAs, pk) 125 s.l.WithField("sshKey", pubKey).Info("Trusted CA key") 126 return nil 127 } 128 129 // AddAuthorizedKey adds an ssh public key for a user 130 func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { 131 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) 132 if err != nil { 133 return err 134 } 135 136 tk, ok := s.trustedKeys[user] 137 if !ok { 138 tk = make(map[string]bool) 139 s.trustedKeys[user] = tk 140 } 141 142 tk[string(pk.Marshal())] = true 143 s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") 144 return nil 145 } 146 147 // RegisterCommand adds a command that can be run by a user, by default only `help` is available 148 func (s *SSHServer) RegisterCommand(c *Command) { 149 s.commands.Insert(c.Name, c) 150 } 151 152 // Run begins listening and accepting connections 153 func (s *SSHServer) Run(addr string) error { 154 var err error 155 s.listener, err = net.Listen("tcp", addr) 156 if err != nil { 157 return err 158 } 159 160 s.l.WithField("sshListener", addr).Info("SSH server is listening") 161 162 // Run loops until there is an error 163 s.run() 164 s.closeSessions() 165 166 s.l.Info("SSH server stopped listening") 167 // We don't return an error because run logs for us 168 return nil 169 } 170 171 func (s *SSHServer) run() { 172 for { 173 c, err := s.listener.Accept() 174 if err != nil { 175 if !errors.Is(err, net.ErrClosed) { 176 s.l.WithError(err).Warn("Error in listener, shutting down") 177 } 178 return 179 } 180 181 conn, chans, reqs, err := ssh.NewServerConn(c, s.config) 182 fp := "" 183 if conn != nil { 184 fp = conn.Permissions.Extensions["fp"] 185 } 186 187 if err != nil { 188 l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) 189 if conn != nil { 190 l = l.WithField("sshUser", conn.User()) 191 conn.Close() 192 } 193 if fp != "" { 194 l = l.WithField("sshFingerprint", fp) 195 } 196 l.Warn("failed to handshake") 197 continue 198 } 199 200 l := s.l.WithField("sshUser", conn.User()) 201 l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") 202 203 session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) 204 s.connsLock.Lock() 205 s.counter++ 206 counter := s.counter 207 s.conns[counter] = session 208 s.connsLock.Unlock() 209 210 go ssh.DiscardRequests(reqs) 211 go func() { 212 <-session.exitChan 213 s.l.WithField("id", counter).Debug("closing conn") 214 s.connsLock.Lock() 215 delete(s.conns, counter) 216 s.connsLock.Unlock() 217 }() 218 } 219 } 220 221 func (s *SSHServer) Stop() { 222 // Close the listener, this will cause all session to terminate as well, see SSHServer.Run 223 if s.listener != nil { 224 if err := s.listener.Close(); err != nil { 225 s.l.WithError(err).Warn("Failed to close the sshd listener") 226 } 227 } 228 } 229 230 func (s *SSHServer) closeSessions() { 231 s.connsLock.Lock() 232 for _, c := range s.conns { 233 c.Close() 234 } 235 s.connsLock.Unlock() 236 }