github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/observability/tracing/ssh/client_test.go (about) 1 // Copyright 2022 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package ssh 16 17 import ( 18 "context" 19 "encoding/json" 20 "fmt" 21 "testing" 22 "time" 23 24 "github.com/gravitational/trace" 25 "github.com/stretchr/testify/require" 26 "golang.org/x/crypto/ssh" 27 ) 28 29 func TestIsTracingSupported(t *testing.T) { 30 cases := []struct { 31 name string 32 srvVersion string 33 expectedCapability tracingCapability 34 }{ 35 { 36 name: "supported", 37 expectedCapability: tracingSupported, 38 srvVersion: "SSH-2.0-Teleport", 39 }, 40 { 41 name: "unsupported", 42 expectedCapability: tracingUnsupported, 43 srvVersion: "SSH-2.0-OpenSSH_7.4", // Only Teleport supports tracing 44 }, 45 } 46 47 for _, tt := range cases { 48 t.Run(tt.name, func(t *testing.T) { 49 ctx, cancel := context.WithCancel(context.Background()) 50 t.Cleanup(cancel) 51 errChan := make(chan error, 5) 52 53 srv := newServer(t, tt.expectedCapability, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { 54 go ssh.DiscardRequests(requests) 55 56 for { 57 select { 58 case <-ctx.Done(): 59 return 60 61 case ch := <-channels: 62 if ch == nil { 63 return 64 } 65 66 if err := ch.Reject(ssh.Prohibited, "no channels allowed"); err != nil { 67 errChan <- trace.Wrap(err, "rejecting channel") 68 return 69 } 70 } 71 } 72 }) 73 74 if tt.srvVersion != "" { 75 srv.config.ServerVersion = tt.srvVersion 76 } 77 78 go srv.Run(errChan) 79 80 conn, chans, reqs := srv.GetClient(t) 81 client := NewClient(conn, chans, reqs) 82 83 require.Equal(t, tt.expectedCapability, client.capability) 84 85 select { 86 case err := <-errChan: 87 require.NoError(t, err) 88 default: 89 } 90 }) 91 } 92 } 93 94 // envReqParams are parameters for env request 95 type envReqParams struct { 96 Name string 97 Value string 98 } 99 100 // TestSetEnvs verifies that client uses EnvsRequest to 101 // send multiple envs and falls back to sending individual "env" 102 // requests if the server does not support EnvsRequests. 103 func TestSetEnvs(t *testing.T) { 104 t.Parallel() 105 ctx, cancel := context.WithCancel(context.Background()) 106 t.Cleanup(cancel) 107 errChan := make(chan error, 5) 108 109 expected := map[string]string{"a": "1", "b": "2", "c": "3"} 110 111 // used to collect individual envs requests 112 envReqC := make(chan envReqParams, 3) 113 114 srv := newServer(t, tracingSupported, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { 115 for { 116 select { 117 case <-ctx.Done(): 118 return 119 case ch := <-channels: 120 switch { 121 case ch == nil: 122 return 123 case ch.ChannelType() == "session": 124 ch, reqs, err := ch.Accept() 125 if err != nil { 126 errChan <- trace.Wrap(err, "failed to accept session channel") 127 return 128 } 129 130 go func() { 131 defer ch.Close() 132 for i := 0; ; i++ { 133 select { 134 case <-ctx.Done(): 135 return 136 case req := <-reqs: 137 if req == nil { 138 return 139 } 140 141 switch { 142 case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest 143 var envReq EnvsReq 144 if err := ssh.Unmarshal(req.Payload, &envReq); err != nil { 145 _ = req.Reply(false, []byte(err.Error())) 146 return 147 } 148 149 var envs map[string]string 150 if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil { 151 _ = req.Reply(false, []byte(err.Error())) 152 return 153 } 154 155 for k, v := range expected { 156 actual, ok := envs[k] 157 if !ok { 158 _ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k))) 159 return 160 } 161 162 if actual != v { 163 _ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual))) 164 return 165 } 166 } 167 168 _ = req.Reply(true, nil) 169 case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks 170 _ = req.Reply(false, nil) 171 case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks. 172 var e envReqParams 173 if err := ssh.Unmarshal(req.Payload, &e); err != nil { 174 _ = req.Reply(false, []byte(err.Error())) 175 return 176 } 177 envReqC <- e 178 _ = req.Reply(true, nil) 179 default: // out of order or unexpected message 180 _ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i))) 181 errChan <- err 182 return 183 } 184 } 185 } 186 }() 187 default: 188 if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil { 189 errChan <- err 190 return 191 } 192 } 193 } 194 } 195 }) 196 197 go srv.Run(errChan) 198 199 // create a client and open a session 200 conn, chans, reqs := srv.GetClient(t) 201 client := NewClient(conn, chans, reqs) 202 session, err := client.NewSession(ctx) 203 require.NoError(t, err) 204 205 // the first request shouldn't fall back 206 t.Run("envs set via envs@goteleport.com", func(t *testing.T) { 207 require.NoError(t, session.SetEnvs(ctx, expected)) 208 209 select { 210 case <-envReqC: 211 t.Fatal("env request received instead of an envs@goteleport.com request") 212 default: 213 } 214 }) 215 216 // subsequent requests should fall back to standard "env" requests 217 t.Run("envs set individually", func(t *testing.T) { 218 require.NoError(t, session.SetEnvs(ctx, expected)) 219 220 envs := map[string]string{} 221 envsTimeout := time.NewTimer(3 * time.Second) 222 defer envsTimeout.Stop() 223 for i := 0; i < len(expected); i++ { 224 select { 225 case env := <-envReqC: 226 envs[env.Name] = env.Value 227 case <-envsTimeout.C: 228 t.Fatalf("Time out waiting for env request %d to be processed", i) 229 } 230 } 231 232 for k, v := range expected { 233 actual, ok := envs[k] 234 require.True(t, ok, "expected env %s to be set", k) 235 require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual) 236 } 237 }) 238 239 select { 240 case err := <-errChan: 241 require.NoError(t, err) 242 default: 243 } 244 }