github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/observability/tracing/ssh/ssh_test.go (about) 1 // Copyright 2022 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ssh 16 17 import ( 18 "context" 19 "crypto/rand" 20 "crypto/rsa" 21 "crypto/subtle" 22 "crypto/x509" 23 "encoding/json" 24 "encoding/pem" 25 "errors" 26 "net" 27 "testing" 28 29 "github.com/gravitational/trace" 30 "github.com/stretchr/testify/require" 31 "go.opentelemetry.io/otel" 32 "go.opentelemetry.io/otel/propagation" 33 sdktrace "go.opentelemetry.io/otel/sdk/trace" 34 "golang.org/x/crypto/ssh" 35 36 "github.com/gravitational/teleport/api/observability/tracing" 37 ) 38 39 const testPayload = "test" 40 41 type server struct { 42 listener net.Listener 43 config *ssh.ServerConfig 44 handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) 45 46 cSigner ssh.Signer 47 hSigner ssh.Signer 48 } 49 50 func (s *server) Run(errC chan error) { 51 for { 52 conn, err := s.listener.Accept() 53 if err != nil { 54 if !errors.Is(err, net.ErrClosed) { 55 errC <- err 56 } 57 return 58 } 59 60 go func() { 61 sconn, chans, reqs, err := ssh.NewServerConn(conn, s.config) 62 if err != nil { 63 errC <- err 64 return 65 } 66 s.handler(sconn, chans, reqs) 67 }() 68 } 69 } 70 71 func (s *server) Stop() error { 72 return s.listener.Close() 73 } 74 75 func generateSigner(t *testing.T) ssh.Signer { 76 private, err := rsa.GenerateKey(rand.Reader, 2048) 77 require.NoError(t, err) 78 79 block := &pem.Block{ 80 Type: "RSA PRIVATE KEY", 81 Bytes: x509.MarshalPKCS1PrivateKey(private), 82 } 83 84 privatePEM := pem.EncodeToMemory(block) 85 signer, err := ssh.ParsePrivateKey(privatePEM) 86 require.NoError(t, err) 87 88 return signer 89 } 90 91 func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { 92 conn, err := net.Dial("tcp", s.listener.Addr().String()) 93 require.NoError(t, err) 94 95 sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ 96 Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)}, 97 HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()), 98 }) 99 require.NoError(t, err) 100 101 return sconn, nc, r 102 } 103 104 func newServer(t *testing.T, tracingCap tracingCapability, handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)) *server { 105 listener, err := net.Listen("tcp", "localhost:0") 106 require.NoError(t, err) 107 108 cSigner := generateSigner(t) 109 hSigner := generateSigner(t) 110 111 version := "SSH-2.0-Teleport" 112 if tracingCap != tracingSupported { 113 version = "SSH-2.0" 114 } 115 116 config := &ssh.ServerConfig{ 117 NoClientAuth: true, 118 ServerVersion: version, 119 } 120 config.AddHostKey(hSigner) 121 122 srv := &server{ 123 listener: listener, 124 config: config, 125 handler: handler, 126 cSigner: cSigner, 127 hSigner: hSigner, 128 } 129 130 t.Cleanup(func() { require.NoError(t, srv.Stop()) }) 131 132 return srv 133 } 134 135 type handler struct { 136 tracingSupported tracingCapability 137 errChan chan error 138 ctx context.Context 139 } 140 141 func (h handler) handle(sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { 142 for { 143 select { 144 case <-h.ctx.Done(): 145 return 146 case req := <-reqs: 147 if req == nil { 148 return 149 } 150 151 h.requestHandler(req) 152 153 case ch := <-chans: 154 if ch == nil { 155 return 156 } 157 158 h.channelHandler(ch) 159 } 160 } 161 } 162 163 func (h handler) requestHandler(req *ssh.Request) { 164 switch { 165 case req.Type == "test": 166 defer func() { 167 if req.WantReply { 168 if err := req.Reply(true, nil); err != nil { 169 h.errChan <- err 170 } 171 } 172 }() 173 174 default: 175 if err := req.Reply(false, nil); err != nil { 176 h.errChan <- err 177 } 178 } 179 } 180 181 func (h handler) channelHandler(ch ssh.NewChannel) { 182 switch ch.ChannelType() { 183 case "session": 184 switch h.tracingSupported { 185 case tracingUnsupported: 186 if subtle.ConstantTimeCompare(ch.ExtraData(), []byte(testPayload)) == 1 { 187 h.errChan <- errors.New("payload mismatch") 188 } 189 case tracingSupported: 190 var envelope Envelope 191 if err := json.Unmarshal(ch.ExtraData(), &envelope); err != nil { 192 h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") 193 ch.Accept() 194 return 195 } 196 if len(envelope.PropagationContext) <= 0 { 197 h.errChan <- errors.New("empty propagation context") 198 ch.Accept() 199 return 200 } 201 if len(envelope.Payload) > 0 { 202 h.errChan <- errors.New("payload mismatch") 203 ch.Accept() 204 return 205 } 206 } 207 208 _, chReqs, err := ch.Accept() 209 if err != nil { 210 h.errChan <- trace.Wrap(err, "failed to accept channel") 211 return 212 } 213 214 go func() { 215 for { 216 select { 217 case <-h.ctx.Done(): 218 return 219 case req := <-chReqs: 220 switch req.Type { 221 case "subsystem": 222 h.subsystemHandler(req) 223 } 224 } 225 } 226 }() 227 default: 228 if err := ch.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil { 229 h.errChan <- trace.Wrap(err, "failed to reject channel") 230 } 231 } 232 } 233 234 type subsystemRequestMsg struct { 235 Subsystem string 236 } 237 238 func (h handler) subsystemHandler(req *ssh.Request) { 239 defer func() { 240 if req.WantReply { 241 if err := req.Reply(true, nil); err != nil { 242 h.errChan <- err 243 } 244 } 245 }() 246 247 switch h.tracingSupported { 248 case tracingUnsupported: 249 var msg subsystemRequestMsg 250 if err := ssh.Unmarshal(req.Payload, &msg); err != nil { 251 h.errChan <- trace.Wrap(err, "failed to unmarshal payload") 252 return 253 } 254 255 if msg.Subsystem != "test" { 256 h.errChan <- errors.New("received wrong subsystem") 257 } 258 case tracingSupported: 259 var envelope Envelope 260 if err := json.Unmarshal(req.Payload, &envelope); err != nil { 261 h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") 262 return 263 } 264 if len(envelope.PropagationContext) <= 0 { 265 h.errChan <- errors.New("empty propagation context") 266 return 267 } 268 269 var msg subsystemRequestMsg 270 if err := ssh.Unmarshal(envelope.Payload, &msg); err != nil { 271 h.errChan <- trace.Wrap(err, "failed to unmarshal payload") 272 return 273 } 274 if msg.Subsystem != "test" { 275 h.errChan <- errors.New("received wrong subsystem") 276 return 277 } 278 default: 279 if err := req.Reply(false, nil); err != nil { 280 h.errChan <- err 281 } 282 } 283 } 284 285 func TestClient(t *testing.T) { 286 cases := []struct { 287 name string 288 tracingSupported tracingCapability 289 }{ 290 { 291 name: "server supports tracing", 292 tracingSupported: tracingSupported, 293 }, 294 { 295 name: "server does not support tracing", 296 tracingSupported: tracingSupported, 297 }, 298 } 299 300 for _, tt := range cases { 301 t.Run(tt.name, func(t *testing.T) { 302 ctx, cancel := context.WithCancel(context.Background()) 303 t.Cleanup(cancel) 304 305 errChan := make(chan error, 5) 306 307 handler := handler{ 308 tracingSupported: tt.tracingSupported, 309 errChan: errChan, 310 ctx: ctx, 311 } 312 313 srv := newServer(t, tt.tracingSupported, handler.handle) 314 go srv.Run(errChan) 315 316 tp := sdktrace.NewTracerProvider() 317 conn, chans, reqs := srv.GetClient(t) 318 client := NewClient( 319 conn, 320 chans, 321 reqs, 322 tracing.WithTracerProvider(tp), 323 tracing.WithTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})), 324 ) 325 require.Equal(t, tt.tracingSupported, client.capability) 326 327 ctx, span := tp.Tracer("test").Start(context.Background(), "test") 328 ok, resp, err := client.SendRequest(ctx, "test", true, []byte("test")) 329 span.End() 330 require.True(t, ok) 331 require.Empty(t, resp) 332 require.NoError(t, err) 333 334 select { 335 case err := <-errChan: 336 require.NoError(t, err) 337 default: 338 } 339 340 session, err := client.NewSession(ctx) 341 require.NoError(t, err) 342 require.NotNil(t, session) 343 344 select { 345 case err := <-errChan: 346 require.NoError(t, err) 347 default: 348 } 349 350 require.NoError(t, session.RequestSubsystem(ctx, "test")) 351 352 select { 353 case err := <-errChan: 354 require.NoError(t, err) 355 default: 356 } 357 }) 358 } 359 } 360 361 func TestWrapPayload(t *testing.T) { 362 testPayload := []byte("test") 363 364 nonRecordingCtx, nonRecordingSpan := otel.GetTracerProvider().Tracer("non-recording").Start(context.Background(), "test") 365 nonRecordingSpan.End() 366 367 emptyCtx, emptySpan := sdktrace.NewTracerProvider().Tracer("empty-trace-context").Start(context.Background(), "test") 368 t.Cleanup(func() { emptySpan.End() }) 369 370 recordingCtx, recordingSpan := sdktrace.NewTracerProvider().Tracer("recording").Start(context.Background(), "test") 371 t.Cleanup(func() { recordingSpan.End() }) 372 cases := []struct { 373 name string 374 ctx context.Context 375 supported tracingCapability 376 propagator propagation.TextMapPropagator 377 payloadAssertion require.ComparisonAssertionFunc 378 }{ 379 { 380 name: "unsupported returns provided payload", 381 ctx: recordingCtx, 382 supported: tracingUnsupported, 383 payloadAssertion: require.Equal, 384 }, 385 { 386 387 name: "non-recording spans aren't propagated", 388 supported: tracingSupported, 389 ctx: nonRecordingCtx, 390 payloadAssertion: require.Equal, 391 }, 392 { 393 name: "empty trace context is not propagated", 394 supported: tracingSupported, 395 ctx: emptyCtx, 396 payloadAssertion: require.Equal, 397 }, 398 { 399 name: "recording spans are propagated", 400 supported: tracingSupported, 401 ctx: recordingCtx, 402 propagator: propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}), 403 payloadAssertion: func(t require.TestingT, i interface{}, i2 interface{}, i3 ...interface{}) { 404 payload, ok := i2.([]byte) 405 require.True(t, ok) 406 407 require.NotEqual(t, testPayload, payload) 408 409 var envelope Envelope 410 require.NoError(t, json.Unmarshal(payload, &envelope)) 411 require.Equal(t, testPayload, envelope.Payload) 412 require.NotEmpty(t, envelope.PropagationContext) 413 }, 414 }, 415 } 416 417 for _, tt := range cases { 418 t.Run(tt.name, func(t *testing.T) { 419 if tt.propagator == nil { 420 tt.propagator = otel.GetTextMapPropagator() 421 } 422 payload := wrapPayload(tt.ctx, tt.supported, tt.propagator, testPayload) 423 tt.payloadAssertion(t, testPayload, payload) 424 }) 425 } 426 }