github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/observability/tracing/ssh/client_test.go (about)

     1  // Copyright 2022 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ssh
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"fmt"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/gravitational/trace"
    25  	"github.com/stretchr/testify/require"
    26  	"golang.org/x/crypto/ssh"
    27  )
    28  
    29  func TestIsTracingSupported(t *testing.T) {
    30  	cases := []struct {
    31  		name               string
    32  		srvVersion         string
    33  		expectedCapability tracingCapability
    34  	}{
    35  		{
    36  			name:               "supported",
    37  			expectedCapability: tracingSupported,
    38  			srvVersion:         "SSH-2.0-Teleport",
    39  		},
    40  		{
    41  			name:               "unsupported",
    42  			expectedCapability: tracingUnsupported,
    43  			srvVersion:         "SSH-2.0-OpenSSH_7.4", // Only Teleport supports tracing
    44  		},
    45  	}
    46  
    47  	for _, tt := range cases {
    48  		t.Run(tt.name, func(t *testing.T) {
    49  			ctx, cancel := context.WithCancel(context.Background())
    50  			t.Cleanup(cancel)
    51  			errChan := make(chan error, 5)
    52  
    53  			srv := newServer(t, tt.expectedCapability, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
    54  				go ssh.DiscardRequests(requests)
    55  
    56  				for {
    57  					select {
    58  					case <-ctx.Done():
    59  						return
    60  
    61  					case ch := <-channels:
    62  						if ch == nil {
    63  							return
    64  						}
    65  
    66  						if err := ch.Reject(ssh.Prohibited, "no channels allowed"); err != nil {
    67  							errChan <- trace.Wrap(err, "rejecting channel")
    68  							return
    69  						}
    70  					}
    71  				}
    72  			})
    73  
    74  			if tt.srvVersion != "" {
    75  				srv.config.ServerVersion = tt.srvVersion
    76  			}
    77  
    78  			go srv.Run(errChan)
    79  
    80  			conn, chans, reqs := srv.GetClient(t)
    81  			client := NewClient(conn, chans, reqs)
    82  
    83  			require.Equal(t, tt.expectedCapability, client.capability)
    84  
    85  			select {
    86  			case err := <-errChan:
    87  				require.NoError(t, err)
    88  			default:
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  // envReqParams are parameters for env request
    95  type envReqParams struct {
    96  	Name  string
    97  	Value string
    98  }
    99  
   100  // TestSetEnvs verifies that client uses EnvsRequest to
   101  // send multiple envs and falls back to sending individual "env"
   102  // requests if the server does not support EnvsRequests.
   103  func TestSetEnvs(t *testing.T) {
   104  	t.Parallel()
   105  	ctx, cancel := context.WithCancel(context.Background())
   106  	t.Cleanup(cancel)
   107  	errChan := make(chan error, 5)
   108  
   109  	expected := map[string]string{"a": "1", "b": "2", "c": "3"}
   110  
   111  	// used to collect individual envs requests
   112  	envReqC := make(chan envReqParams, 3)
   113  
   114  	srv := newServer(t, tracingSupported, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
   115  		for {
   116  			select {
   117  			case <-ctx.Done():
   118  				return
   119  			case ch := <-channels:
   120  				switch {
   121  				case ch == nil:
   122  					return
   123  				case ch.ChannelType() == "session":
   124  					ch, reqs, err := ch.Accept()
   125  					if err != nil {
   126  						errChan <- trace.Wrap(err, "failed to accept session channel")
   127  						return
   128  					}
   129  
   130  					go func() {
   131  						defer ch.Close()
   132  						for i := 0; ; i++ {
   133  							select {
   134  							case <-ctx.Done():
   135  								return
   136  							case req := <-reqs:
   137  								if req == nil {
   138  									return
   139  								}
   140  
   141  								switch {
   142  								case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest
   143  									var envReq EnvsReq
   144  									if err := ssh.Unmarshal(req.Payload, &envReq); err != nil {
   145  										_ = req.Reply(false, []byte(err.Error()))
   146  										return
   147  									}
   148  
   149  									var envs map[string]string
   150  									if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil {
   151  										_ = req.Reply(false, []byte(err.Error()))
   152  										return
   153  									}
   154  
   155  									for k, v := range expected {
   156  										actual, ok := envs[k]
   157  										if !ok {
   158  											_ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k)))
   159  											return
   160  										}
   161  
   162  										if actual != v {
   163  											_ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual)))
   164  											return
   165  										}
   166  									}
   167  
   168  									_ = req.Reply(true, nil)
   169  								case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks
   170  									_ = req.Reply(false, nil)
   171  								case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks.
   172  									var e envReqParams
   173  									if err := ssh.Unmarshal(req.Payload, &e); err != nil {
   174  										_ = req.Reply(false, []byte(err.Error()))
   175  										return
   176  									}
   177  									envReqC <- e
   178  									_ = req.Reply(true, nil)
   179  								default: // out of order or unexpected message
   180  									_ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i)))
   181  									errChan <- err
   182  									return
   183  								}
   184  							}
   185  						}
   186  					}()
   187  				default:
   188  					if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil {
   189  						errChan <- err
   190  						return
   191  					}
   192  				}
   193  			}
   194  		}
   195  	})
   196  
   197  	go srv.Run(errChan)
   198  
   199  	// create a client and open a session
   200  	conn, chans, reqs := srv.GetClient(t)
   201  	client := NewClient(conn, chans, reqs)
   202  	session, err := client.NewSession(ctx)
   203  	require.NoError(t, err)
   204  
   205  	// the first request shouldn't fall back
   206  	t.Run("envs set via envs@goteleport.com", func(t *testing.T) {
   207  		require.NoError(t, session.SetEnvs(ctx, expected))
   208  
   209  		select {
   210  		case <-envReqC:
   211  			t.Fatal("env request received instead of an envs@goteleport.com request")
   212  		default:
   213  		}
   214  	})
   215  
   216  	// subsequent requests should fall back to standard "env" requests
   217  	t.Run("envs set individually", func(t *testing.T) {
   218  		require.NoError(t, session.SetEnvs(ctx, expected))
   219  
   220  		envs := map[string]string{}
   221  		envsTimeout := time.NewTimer(3 * time.Second)
   222  		defer envsTimeout.Stop()
   223  		for i := 0; i < len(expected); i++ {
   224  			select {
   225  			case env := <-envReqC:
   226  				envs[env.Name] = env.Value
   227  			case <-envsTimeout.C:
   228  				t.Fatalf("Time out waiting for env request %d to be processed", i)
   229  			}
   230  		}
   231  
   232  		for k, v := range expected {
   233  			actual, ok := envs[k]
   234  			require.True(t, ok, "expected env %s to be set", k)
   235  			require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual)
   236  		}
   237  	})
   238  
   239  	select {
   240  	case err := <-errChan:
   241  		require.NoError(t, err)
   242  	default:
   243  	}
   244  }