github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/observability/tracing/ssh/ssh_test.go (about)

     1  // Copyright 2022 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ssh
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/subtle"
    22  	"crypto/x509"
    23  	"encoding/json"
    24  	"encoding/pem"
    25  	"errors"
    26  	"net"
    27  	"testing"
    28  
    29  	"github.com/gravitational/trace"
    30  	"github.com/stretchr/testify/require"
    31  	"go.opentelemetry.io/otel"
    32  	"go.opentelemetry.io/otel/propagation"
    33  	sdktrace "go.opentelemetry.io/otel/sdk/trace"
    34  	"golang.org/x/crypto/ssh"
    35  
    36  	"github.com/gravitational/teleport/api/observability/tracing"
    37  )
    38  
    39  const testPayload = "test"
    40  
    41  type server struct {
    42  	listener net.Listener
    43  	config   *ssh.ServerConfig
    44  	handler  func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)
    45  
    46  	cSigner ssh.Signer
    47  	hSigner ssh.Signer
    48  }
    49  
    50  func (s *server) Run(errC chan error) {
    51  	for {
    52  		conn, err := s.listener.Accept()
    53  		if err != nil {
    54  			if !errors.Is(err, net.ErrClosed) {
    55  				errC <- err
    56  			}
    57  			return
    58  		}
    59  
    60  		go func() {
    61  			sconn, chans, reqs, err := ssh.NewServerConn(conn, s.config)
    62  			if err != nil {
    63  				errC <- err
    64  				return
    65  			}
    66  			s.handler(sconn, chans, reqs)
    67  		}()
    68  	}
    69  }
    70  
    71  func (s *server) Stop() error {
    72  	return s.listener.Close()
    73  }
    74  
    75  func generateSigner(t *testing.T) ssh.Signer {
    76  	private, err := rsa.GenerateKey(rand.Reader, 2048)
    77  	require.NoError(t, err)
    78  
    79  	block := &pem.Block{
    80  		Type:  "RSA PRIVATE KEY",
    81  		Bytes: x509.MarshalPKCS1PrivateKey(private),
    82  	}
    83  
    84  	privatePEM := pem.EncodeToMemory(block)
    85  	signer, err := ssh.ParsePrivateKey(privatePEM)
    86  	require.NoError(t, err)
    87  
    88  	return signer
    89  }
    90  
    91  func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) {
    92  	conn, err := net.Dial("tcp", s.listener.Addr().String())
    93  	require.NoError(t, err)
    94  
    95  	sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{
    96  		Auth:            []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)},
    97  		HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()),
    98  	})
    99  	require.NoError(t, err)
   100  
   101  	return sconn, nc, r
   102  }
   103  
   104  func newServer(t *testing.T, tracingCap tracingCapability, handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)) *server {
   105  	listener, err := net.Listen("tcp", "localhost:0")
   106  	require.NoError(t, err)
   107  
   108  	cSigner := generateSigner(t)
   109  	hSigner := generateSigner(t)
   110  
   111  	version := "SSH-2.0-Teleport"
   112  	if tracingCap != tracingSupported {
   113  		version = "SSH-2.0"
   114  	}
   115  
   116  	config := &ssh.ServerConfig{
   117  		NoClientAuth:  true,
   118  		ServerVersion: version,
   119  	}
   120  	config.AddHostKey(hSigner)
   121  
   122  	srv := &server{
   123  		listener: listener,
   124  		config:   config,
   125  		handler:  handler,
   126  		cSigner:  cSigner,
   127  		hSigner:  hSigner,
   128  	}
   129  
   130  	t.Cleanup(func() { require.NoError(t, srv.Stop()) })
   131  
   132  	return srv
   133  }
   134  
   135  type handler struct {
   136  	tracingSupported tracingCapability
   137  	errChan          chan error
   138  	ctx              context.Context
   139  }
   140  
   141  func (h handler) handle(sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
   142  	for {
   143  		select {
   144  		case <-h.ctx.Done():
   145  			return
   146  		case req := <-reqs:
   147  			if req == nil {
   148  				return
   149  			}
   150  
   151  			h.requestHandler(req)
   152  
   153  		case ch := <-chans:
   154  			if ch == nil {
   155  				return
   156  			}
   157  
   158  			h.channelHandler(ch)
   159  		}
   160  	}
   161  }
   162  
   163  func (h handler) requestHandler(req *ssh.Request) {
   164  	switch {
   165  	case req.Type == "test":
   166  		defer func() {
   167  			if req.WantReply {
   168  				if err := req.Reply(true, nil); err != nil {
   169  					h.errChan <- err
   170  				}
   171  			}
   172  		}()
   173  
   174  	default:
   175  		if err := req.Reply(false, nil); err != nil {
   176  			h.errChan <- err
   177  		}
   178  	}
   179  }
   180  
   181  func (h handler) channelHandler(ch ssh.NewChannel) {
   182  	switch ch.ChannelType() {
   183  	case "session":
   184  		switch h.tracingSupported {
   185  		case tracingUnsupported:
   186  			if subtle.ConstantTimeCompare(ch.ExtraData(), []byte(testPayload)) == 1 {
   187  				h.errChan <- errors.New("payload mismatch")
   188  			}
   189  		case tracingSupported:
   190  			var envelope Envelope
   191  			if err := json.Unmarshal(ch.ExtraData(), &envelope); err != nil {
   192  				h.errChan <- trace.Wrap(err, "failed to unmarshal envelope")
   193  				ch.Accept()
   194  				return
   195  			}
   196  			if len(envelope.PropagationContext) <= 0 {
   197  				h.errChan <- errors.New("empty propagation context")
   198  				ch.Accept()
   199  				return
   200  			}
   201  			if len(envelope.Payload) > 0 {
   202  				h.errChan <- errors.New("payload mismatch")
   203  				ch.Accept()
   204  				return
   205  			}
   206  		}
   207  
   208  		_, chReqs, err := ch.Accept()
   209  		if err != nil {
   210  			h.errChan <- trace.Wrap(err, "failed to accept channel")
   211  			return
   212  		}
   213  
   214  		go func() {
   215  			for {
   216  				select {
   217  				case <-h.ctx.Done():
   218  					return
   219  				case req := <-chReqs:
   220  					switch req.Type {
   221  					case "subsystem":
   222  						h.subsystemHandler(req)
   223  					}
   224  				}
   225  			}
   226  		}()
   227  	default:
   228  		if err := ch.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil {
   229  			h.errChan <- trace.Wrap(err, "failed to reject channel")
   230  		}
   231  	}
   232  }
   233  
   234  type subsystemRequestMsg struct {
   235  	Subsystem string
   236  }
   237  
   238  func (h handler) subsystemHandler(req *ssh.Request) {
   239  	defer func() {
   240  		if req.WantReply {
   241  			if err := req.Reply(true, nil); err != nil {
   242  				h.errChan <- err
   243  			}
   244  		}
   245  	}()
   246  
   247  	switch h.tracingSupported {
   248  	case tracingUnsupported:
   249  		var msg subsystemRequestMsg
   250  		if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
   251  			h.errChan <- trace.Wrap(err, "failed to unmarshal payload")
   252  			return
   253  		}
   254  
   255  		if msg.Subsystem != "test" {
   256  			h.errChan <- errors.New("received wrong subsystem")
   257  		}
   258  	case tracingSupported:
   259  		var envelope Envelope
   260  		if err := json.Unmarshal(req.Payload, &envelope); err != nil {
   261  			h.errChan <- trace.Wrap(err, "failed to unmarshal envelope")
   262  			return
   263  		}
   264  		if len(envelope.PropagationContext) <= 0 {
   265  			h.errChan <- errors.New("empty propagation context")
   266  			return
   267  		}
   268  
   269  		var msg subsystemRequestMsg
   270  		if err := ssh.Unmarshal(envelope.Payload, &msg); err != nil {
   271  			h.errChan <- trace.Wrap(err, "failed to unmarshal payload")
   272  			return
   273  		}
   274  		if msg.Subsystem != "test" {
   275  			h.errChan <- errors.New("received wrong subsystem")
   276  			return
   277  		}
   278  	default:
   279  		if err := req.Reply(false, nil); err != nil {
   280  			h.errChan <- err
   281  		}
   282  	}
   283  }
   284  
   285  func TestClient(t *testing.T) {
   286  	cases := []struct {
   287  		name             string
   288  		tracingSupported tracingCapability
   289  	}{
   290  		{
   291  			name:             "server supports tracing",
   292  			tracingSupported: tracingSupported,
   293  		},
   294  		{
   295  			name:             "server does not support tracing",
   296  			tracingSupported: tracingSupported,
   297  		},
   298  	}
   299  
   300  	for _, tt := range cases {
   301  		t.Run(tt.name, func(t *testing.T) {
   302  			ctx, cancel := context.WithCancel(context.Background())
   303  			t.Cleanup(cancel)
   304  
   305  			errChan := make(chan error, 5)
   306  
   307  			handler := handler{
   308  				tracingSupported: tt.tracingSupported,
   309  				errChan:          errChan,
   310  				ctx:              ctx,
   311  			}
   312  
   313  			srv := newServer(t, tt.tracingSupported, handler.handle)
   314  			go srv.Run(errChan)
   315  
   316  			tp := sdktrace.NewTracerProvider()
   317  			conn, chans, reqs := srv.GetClient(t)
   318  			client := NewClient(
   319  				conn,
   320  				chans,
   321  				reqs,
   322  				tracing.WithTracerProvider(tp),
   323  				tracing.WithTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})),
   324  			)
   325  			require.Equal(t, tt.tracingSupported, client.capability)
   326  
   327  			ctx, span := tp.Tracer("test").Start(context.Background(), "test")
   328  			ok, resp, err := client.SendRequest(ctx, "test", true, []byte("test"))
   329  			span.End()
   330  			require.True(t, ok)
   331  			require.Empty(t, resp)
   332  			require.NoError(t, err)
   333  
   334  			select {
   335  			case err := <-errChan:
   336  				require.NoError(t, err)
   337  			default:
   338  			}
   339  
   340  			session, err := client.NewSession(ctx)
   341  			require.NoError(t, err)
   342  			require.NotNil(t, session)
   343  
   344  			select {
   345  			case err := <-errChan:
   346  				require.NoError(t, err)
   347  			default:
   348  			}
   349  
   350  			require.NoError(t, session.RequestSubsystem(ctx, "test"))
   351  
   352  			select {
   353  			case err := <-errChan:
   354  				require.NoError(t, err)
   355  			default:
   356  			}
   357  		})
   358  	}
   359  }
   360  
   361  func TestWrapPayload(t *testing.T) {
   362  	testPayload := []byte("test")
   363  
   364  	nonRecordingCtx, nonRecordingSpan := otel.GetTracerProvider().Tracer("non-recording").Start(context.Background(), "test")
   365  	nonRecordingSpan.End()
   366  
   367  	emptyCtx, emptySpan := sdktrace.NewTracerProvider().Tracer("empty-trace-context").Start(context.Background(), "test")
   368  	t.Cleanup(func() { emptySpan.End() })
   369  
   370  	recordingCtx, recordingSpan := sdktrace.NewTracerProvider().Tracer("recording").Start(context.Background(), "test")
   371  	t.Cleanup(func() { recordingSpan.End() })
   372  	cases := []struct {
   373  		name             string
   374  		ctx              context.Context
   375  		supported        tracingCapability
   376  		propagator       propagation.TextMapPropagator
   377  		payloadAssertion require.ComparisonAssertionFunc
   378  	}{
   379  		{
   380  			name:             "unsupported returns provided payload",
   381  			ctx:              recordingCtx,
   382  			supported:        tracingUnsupported,
   383  			payloadAssertion: require.Equal,
   384  		},
   385  		{
   386  
   387  			name:             "non-recording spans aren't propagated",
   388  			supported:        tracingSupported,
   389  			ctx:              nonRecordingCtx,
   390  			payloadAssertion: require.Equal,
   391  		},
   392  		{
   393  			name:             "empty trace context is not propagated",
   394  			supported:        tracingSupported,
   395  			ctx:              emptyCtx,
   396  			payloadAssertion: require.Equal,
   397  		},
   398  		{
   399  			name:       "recording spans are propagated",
   400  			supported:  tracingSupported,
   401  			ctx:        recordingCtx,
   402  			propagator: propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
   403  			payloadAssertion: func(t require.TestingT, i interface{}, i2 interface{}, i3 ...interface{}) {
   404  				payload, ok := i2.([]byte)
   405  				require.True(t, ok)
   406  
   407  				require.NotEqual(t, testPayload, payload)
   408  
   409  				var envelope Envelope
   410  				require.NoError(t, json.Unmarshal(payload, &envelope))
   411  				require.Equal(t, testPayload, envelope.Payload)
   412  				require.NotEmpty(t, envelope.PropagationContext)
   413  			},
   414  		},
   415  	}
   416  
   417  	for _, tt := range cases {
   418  		t.Run(tt.name, func(t *testing.T) {
   419  			if tt.propagator == nil {
   420  				tt.propagator = otel.GetTextMapPropagator()
   421  			}
   422  			payload := wrapPayload(tt.ctx, tt.supported, tt.propagator, testPayload)
   423  			tt.payloadAssertion(t, testPayload, payload)
   424  		})
   425  	}
   426  }