github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/observability/tracing/ssh/ssh.go (about)

     1  // Copyright 2022 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ssh
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"net"
    21  	"time"
    22  
    23  	"github.com/gravitational/trace"
    24  	"go.opentelemetry.io/otel/attribute"
    25  	"go.opentelemetry.io/otel/propagation"
    26  	semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
    27  	oteltrace "go.opentelemetry.io/otel/trace"
    28  	"golang.org/x/crypto/ssh"
    29  
    30  	"github.com/gravitational/teleport/api/observability/tracing"
    31  )
    32  
const (
	// EnvsRequest sets multiple environment variables that will be applied to any
	// command executed by Shell or Run.
	// See [EnvsReq] for the corresponding payload.
	EnvsRequest = "envs@goteleport.com"

	// instrumentationName is the name of this instrumentation package. It is
	// the name passed to TracerProvider.Tracer when this package starts spans.
	instrumentationName = "otelssh"
)
    42  
// EnvsReq contains json marshaled key:value pairs sent as the
// payload for an [EnvsRequest].
type EnvsReq struct {
	// EnvsJSON is a json marshaled map[string]string containing
	// environment variables. When EnvsReq itself is encoded with
	// encoding/json the field is emitted under the "envs" key.
	EnvsJSON []byte `json:"envs"`
}
    50  
// FileTransferReq contains parameters used to create a file transfer
// request to be stored in the SSH server.
//
// NOTE(review): the fields carry no struct tags; if this type is encoded
// positionally (e.g. with ssh.Marshal) the field order is part of the wire
// format — confirm before reordering or adding fields.
type FileTransferReq struct {
	// Download is true if the file transfer requests a download, false if upload
	Download bool
	// Location is the location of the file to be downloaded, or directory to upload a file
	Location string
	// Filename is the name of the file to be uploaded
	Filename string
}
    61  
// FileTransferDecisionReq contains parameters used to approve or deny an active
// file transfer request on the SSH server.
//
// NOTE(review): like FileTransferReq, this has no struct tags; treat the field
// order as potentially significant to its encoding — confirm before changing.
type FileTransferDecisionReq struct {
	// RequestID is the ID of the file transfer request being responded to
	RequestID string
	// Approved is true if approved, false if denied.
	Approved bool
}
    70  
    71  // ContextFromRequest extracts any tracing data provided via an Envelope
    72  // in the ssh.Request payload. If the payload contains an Envelope, then
    73  // the context returned will have tracing data populated from the remote
    74  // tracing context and the ssh.Request payload will be replaced with the
    75  // original payload from the client.
    76  func ContextFromRequest(req *ssh.Request, opts ...tracing.Option) context.Context {
    77  	ctx := context.Background()
    78  
    79  	var envelope Envelope
    80  	if err := json.Unmarshal(req.Payload, &envelope); err != nil {
    81  		return ctx
    82  	}
    83  
    84  	ctx = tracing.WithPropagationContext(ctx, envelope.PropagationContext, opts...)
    85  	req.Payload = envelope.Payload
    86  
    87  	return ctx
    88  }
    89  
    90  // ContextFromNewChannel extracts any tracing data provided via an Envelope
    91  // in the ssh.NewChannel ExtraData. If the ExtraData contains an Envelope, then
    92  // the context returned will have tracing data populated from the remote
    93  // tracing context and the ssh.NewChannel wrapped in a TraceCh so that the
    94  // original ExtraData from the client is exposed instead of the Envelope
    95  // payload.
    96  func ContextFromNewChannel(nch ssh.NewChannel, opts ...tracing.Option) (context.Context, ssh.NewChannel) {
    97  	ch := NewTraceNewChannel(nch)
    98  	ctx := tracing.WithPropagationContext(context.Background(), ch.Envelope.PropagationContext, opts...)
    99  
   100  	return ctx, ch
   101  }
   102  
   103  // Dial starts a client connection to the given SSH server. It is a
   104  // convenience function that connects to the given network address,
   105  // initiates the SSH handshake, and then sets up a Client.  For access
   106  // to incoming channels and requests, use net.Dial with NewClientConn
   107  // instead.
   108  func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (*Client, error) {
   109  	tracer := tracing.NewConfig(opts).TracerProvider.Tracer(instrumentationName)
   110  	ctx, span := tracer.Start(
   111  		ctx,
   112  		"ssh/Dial",
   113  		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
   114  		oteltrace.WithAttributes(
   115  			attribute.String("network", network),
   116  			attribute.String("address", addr),
   117  			semconv.RPCServiceKey.String("ssh"),
   118  			semconv.RPCMethodKey.String("Dial"),
   119  			semconv.RPCSystemKey.String("ssh"),
   120  		),
   121  	)
   122  	defer span.End()
   123  
   124  	dialer := net.Dialer{Timeout: config.Timeout}
   125  	conn, err := dialer.DialContext(ctx, network, addr)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	c, chans, reqs, err := NewClientConn(ctx, conn, addr, config, opts...)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	return NewClient(c, chans, reqs), nil
   134  }
   135  
   136  // NewClientConn creates a new SSH client connection that is passed tracing context so that spans may be correlated
   137  // properly over the ssh connection.
   138  func NewClientConn(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) {
   139  	tracer := tracing.NewConfig(opts).TracerProvider.Tracer(instrumentationName)
   140  	ctx, span := tracer.Start( //nolint:staticcheck,ineffassign // keeping shadowed ctx to avoid accidental missing in the future
   141  		ctx,
   142  		"ssh/NewClientConn",
   143  		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
   144  		oteltrace.WithAttributes(
   145  			append(
   146  				peerAttr(conn.RemoteAddr()),
   147  				attribute.String("address", addr),
   148  				semconv.RPCServiceKey.String("ssh"),
   149  				semconv.RPCMethodKey.String("NewClientConn"),
   150  				semconv.RPCSystemKey.String("ssh"),
   151  			)...,
   152  		),
   153  	)
   154  	defer span.End()
   155  
   156  	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
   157  	if err != nil {
   158  		return nil, nil, nil, trace.Wrap(err)
   159  	}
   160  
   161  	return c, chans, reqs, nil
   162  }
   163  
   164  // NewClientConnWithDeadline establishes new client connection with specified deadline
   165  func NewClientConnWithDeadline(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (*Client, error) {
   166  	if config.Timeout > 0 {
   167  		if err := conn.SetReadDeadline(time.Now().Add(config.Timeout)); err != nil {
   168  			return nil, trace.Wrap(err)
   169  		}
   170  	}
   171  	c, chans, reqs, err := NewClientConn(ctx, conn, addr, config, opts...)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	if config.Timeout > 0 {
   176  		if err := conn.SetReadDeadline(time.Time{}); err != nil {
   177  			return nil, trace.Wrap(err)
   178  		}
   179  	}
   180  	return NewClient(c, chans, reqs, opts...), nil
   181  }
   182  
   183  // peerAttr returns attributes about the peer address.
   184  func peerAttr(addr net.Addr) []attribute.KeyValue {
   185  	host, port, err := net.SplitHostPort(addr.String())
   186  	if err != nil {
   187  		return nil
   188  	}
   189  
   190  	if host == "" {
   191  		host = "127.0.0.1"
   192  	}
   193  
   194  	return []attribute.KeyValue{
   195  		semconv.NetPeerIPKey.String(host),
   196  		semconv.NetPeerPortKey.String(port),
   197  	}
   198  }
   199  
// Envelope wraps the payload of all ssh messages with
// tracing context. Any servers that support tracing propagation
// will attempt to parse the Envelope for all received requests and
// ensure that the original payload is provided to the handlers.
//
// Envelopes are encoded and decoded with encoding/json. The fields have no
// json tags, so their Go names are used as the JSON keys, and Payload ([]byte)
// is carried as a base64 string.
type Envelope struct {
	// PropagationContext holds the serialized trace context of the sender.
	// It is left empty when there was nothing to propagate.
	PropagationContext tracing.PropagationContext
	// Payload is the original, unwrapped message payload.
	Payload []byte
}
   208  
   209  // createEnvelope wraps the provided payload with a tracing envelope
   210  // that is used to propagate trace context .
   211  func createEnvelope(ctx context.Context, propagator propagation.TextMapPropagator, payload []byte) Envelope {
   212  	envelope := Envelope{
   213  		Payload: payload,
   214  	}
   215  
   216  	span := oteltrace.SpanFromContext(ctx)
   217  	if !span.IsRecording() {
   218  		return envelope
   219  	}
   220  
   221  	traceCtx := tracing.PropagationContextFromContext(ctx, tracing.WithTextMapPropagator(propagator))
   222  	if len(traceCtx) == 0 {
   223  		return envelope
   224  	}
   225  
   226  	envelope.PropagationContext = traceCtx
   227  
   228  	return envelope
   229  }
   230  
   231  // wrapPayload wraps the provided payload within an envelope if tracing is
   232  // enabled and there is any tracing information to propagate. Otherwise, the
   233  // original payload is returned
   234  func wrapPayload(ctx context.Context, supported tracingCapability, propagator propagation.TextMapPropagator, payload []byte) []byte {
   235  	if supported != tracingSupported {
   236  		return payload
   237  	}
   238  
   239  	envelope := createEnvelope(ctx, propagator, payload)
   240  	if len(envelope.PropagationContext) == 0 {
   241  		return payload
   242  	}
   243  
   244  	wrappedPayload, err := json.Marshal(envelope)
   245  	if err == nil {
   246  		return wrappedPayload
   247  	}
   248  
   249  	return payload
   250  }