github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/test/server.go (about)

     1  package livesharetest
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/gorilla/websocket"
    14  	"github.com/sourcegraph/jsonrpc2"
    15  	"golang.org/x/crypto/ssh"
    16  )
    17  
    18  const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
    19  MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV
    20  rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY
    21  lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB
    22  AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0
    23  5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8
    24  IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC
    25  46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n
    26  IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC
    27  t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz
    28  J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA
    29  hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV
    30  933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP
    31  FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw==
    32  -----END RSA PRIVATE KEY-----`
    33  
    34  const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==`
    35  
    36  // Server represents a LiveShare relay host server.
    37  type Server struct {
    38  	password       string
    39  	services       map[string]RPCHandleFunc
    40  	relaySAS       string
    41  	streams        map[string]io.ReadWriter
    42  	sshConfig      *ssh.ServerConfig
    43  	httptestServer *httptest.Server
    44  	errCh          chan error
    45  	nonSecure      bool
    46  }
    47  
    48  // NewServer creates a new Server. ServerOptions can be passed to configure
    49  // the SSH password, backing service, secrets and more.
    50  func NewServer(opts ...ServerOption) (*Server, error) {
    51  	server := new(Server)
    52  
    53  	for _, o := range opts {
    54  		if err := o(server); err != nil {
    55  			return nil, err
    56  		}
    57  	}
    58  
    59  	server.sshConfig = &ssh.ServerConfig{
    60  		PasswordCallback: sshPasswordCallback(server.password),
    61  	}
    62  	privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
    63  	if err != nil {
    64  		return nil, fmt.Errorf("error parsing key: %w", err)
    65  	}
    66  	server.sshConfig.AddHostKey(privateKey)
    67  
    68  	server.errCh = make(chan error, 1)
    69  
    70  	if server.nonSecure {
    71  		server.httptestServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
    72  	} else {
    73  		server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
    74  	}
    75  	return server, nil
    76  }
    77  
    78  // ServerOption is used to configure the Server.
    79  type ServerOption func(*Server) error
    80  
    81  // WithPassword configures the Server password for SSH.
    82  func WithPassword(password string) ServerOption {
    83  	return func(s *Server) error {
    84  		s.password = password
    85  		return nil
    86  	}
    87  }
    88  
    89  // WithNonSecure configures the Server as non-secure.
    90  func WithNonSecure() ServerOption {
    91  	return func(s *Server) error {
    92  		s.nonSecure = true
    93  		return nil
    94  	}
    95  }
    96  
    97  // WithService accepts a mock RPC service for the Server to invoke.
    98  func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
    99  	return func(s *Server) error {
   100  		if s.services == nil {
   101  			s.services = make(map[string]RPCHandleFunc)
   102  		}
   103  
   104  		s.services[serviceName] = handler
   105  		return nil
   106  	}
   107  }
   108  
   109  // WithRelaySAS configures the relay SAS configuration key.
   110  func WithRelaySAS(sas string) ServerOption {
   111  	return func(s *Server) error {
   112  		s.relaySAS = sas
   113  		return nil
   114  	}
   115  }
   116  
   117  // WithStream allows you to specify a mock data stream for the server.
   118  func WithStream(name string, stream io.ReadWriter) ServerOption {
   119  	return func(s *Server) error {
   120  		if s.streams == nil {
   121  			s.streams = make(map[string]io.ReadWriter)
   122  		}
   123  		s.streams[name] = stream
   124  		return nil
   125  	}
   126  }
   127  
   128  func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
   129  	return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
   130  		if string(password) == serverPassword {
   131  			return nil, nil
   132  		}
   133  		return nil, errors.New("password rejected")
   134  	}
   135  }
   136  
   137  // Close closes the underlying httptest Server.
   138  func (s *Server) Close() {
   139  	s.httptestServer.Close()
   140  }
   141  
   142  // URL returns the httptest Server url.
   143  func (s *Server) URL() string {
   144  	return s.httptestServer.URL
   145  }
   146  
   147  func (s *Server) Err() <-chan error {
   148  	return s.errCh
   149  }
   150  
   151  var upgrader = websocket.Upgrader{}
   152  
   153  func makeConnection(server *Server) http.HandlerFunc {
   154  	return func(w http.ResponseWriter, req *http.Request) {
   155  		ctx, cancel := context.WithCancel(context.Background())
   156  		defer cancel()
   157  
   158  		if server.relaySAS != "" {
   159  			// validate the sas key
   160  			sasParam := req.URL.Query().Get("sb-hc-token")
   161  			if sasParam != server.relaySAS {
   162  				sendError(server.errCh, errors.New("error validating sas"))
   163  				return
   164  			}
   165  		}
   166  		c, err := upgrader.Upgrade(w, req, nil)
   167  		if err != nil {
   168  			sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
   169  			return
   170  		}
   171  		defer func() {
   172  			if err := c.Close(); err != nil {
   173  				sendError(server.errCh, err)
   174  			}
   175  		}()
   176  
   177  		socketConn := newSocketConn(c)
   178  		_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
   179  		if err != nil {
   180  			sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
   181  			return
   182  		}
   183  		go ssh.DiscardRequests(reqs)
   184  
   185  		if err := handleChannels(ctx, server, chans); err != nil {
   186  			sendError(server.errCh, err)
   187  		}
   188  	}
   189  }
   190  
   191  // sendError does a non-blocking send of the error to the err channel.
   192  func sendError(errc chan<- error, err error) {
   193  	select {
   194  	case errc <- err:
   195  	default:
   196  		// channel is blocked with a previous error, so we ignore
   197  		// this current error
   198  	}
   199  }
   200  
   201  // awaitError waits for the context to finish and returns its error (if any).
   202  // It also waits for an err to come through the err channel.
   203  func awaitError(ctx context.Context, errc <-chan error) error {
   204  	select {
   205  	case <-ctx.Done():
   206  		return ctx.Err()
   207  	case err := <-errc:
   208  		return err
   209  	}
   210  }
   211  
   212  // handleChannels services the sshChannels channel. For each SSH channel received
   213  // it creates a go routine to service the channel's requests. It returns on the first
   214  // error encountered.
   215  func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
   216  	errc := make(chan error, 1)
   217  	go func() {
   218  		for sshCh := range sshChannels {
   219  			ch, reqs, err := sshCh.Accept()
   220  			if err != nil {
   221  				sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
   222  				return
   223  			}
   224  
   225  			go func() {
   226  				if err := handleRequests(ctx, server, ch, reqs); err != nil {
   227  					sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
   228  				}
   229  			}()
   230  
   231  			handleChannel(server, ch)
   232  		}
   233  	}()
   234  	return awaitError(ctx, errc)
   235  }
   236  
   237  // handleRequests services the SSH channel requests channel. It replies to requests and
   238  // when stream transport requests are encountered, creates a go routine to create a
   239  // bi-directional data stream between the channel and server stream. It returns on the first error
   240  // encountered.
   241  func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
   242  	errc := make(chan error, 1)
   243  	go func() {
   244  		for req := range reqs {
   245  			if req.WantReply {
   246  				if err := req.Reply(true, nil); err != nil {
   247  					sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
   248  					return
   249  				}
   250  			}
   251  
   252  			if strings.HasPrefix(req.Type, "stream-transport") {
   253  				go func() {
   254  					if err := forwardStream(ctx, server, req.Type, channel); err != nil {
   255  						sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
   256  					}
   257  				}()
   258  			}
   259  		}
   260  	}()
   261  
   262  	return awaitError(ctx, errc)
   263  }
   264  
   265  // concurrentStream is a concurrency safe io.ReadWriter.
   266  type concurrentStream struct {
   267  	sync.RWMutex
   268  	stream io.ReadWriter
   269  }
   270  
   271  func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
   272  	return &concurrentStream{stream: rw}
   273  }
   274  
   275  func (cs *concurrentStream) Read(b []byte) (int, error) {
   276  	cs.RLock()
   277  	defer cs.RUnlock()
   278  	return cs.stream.Read(b)
   279  }
   280  
   281  func (cs *concurrentStream) Write(b []byte) (int, error) {
   282  	cs.Lock()
   283  	defer cs.Unlock()
   284  	return cs.stream.Write(b)
   285  }
   286  
   287  // forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
   288  // runs until an error is encountered.
   289  func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
   290  	simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
   291  	stream, found := server.streams[simpleStreamName]
   292  	if !found {
   293  		return fmt.Errorf("stream '%s' not found", simpleStreamName)
   294  	}
   295  	defer func() {
   296  		if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
   297  			err = closeErr
   298  		}
   299  	}()
   300  
   301  	errc := make(chan error, 2)
   302  	copy := func(dst io.Writer, src io.Reader) {
   303  		if _, err := io.Copy(dst, src); err != nil {
   304  			errc <- err
   305  		}
   306  	}
   307  
   308  	csStream := newConcurrentStream(stream)
   309  	go copy(csStream, channel)
   310  	go copy(channel, csStream)
   311  
   312  	return awaitError(ctx, errc)
   313  }
   314  
   315  func handleChannel(server *Server, channel ssh.Channel) {
   316  	stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
   317  	jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
   318  }
   319  
   320  type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error)
   321  
   322  type rpcHandler struct {
   323  	server *Server
   324  }
   325  
   326  func newRPCHandler(server *Server) *rpcHandler {
   327  	return &rpcHandler{server}
   328  }
   329  
   330  // Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
   331  // RPC service method and if found, it invokes the handler and replies to the request.
   332  func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
   333  	handler, found := r.server.services[req.Method]
   334  	if !found {
   335  		sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
   336  		return
   337  	}
   338  
   339  	result, err := handler(conn, req)
   340  	if err != nil {
   341  		sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
   342  		return
   343  	}
   344  
   345  	if err := conn.Reply(ctx, req.ID, result); err != nil {
   346  		sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
   347  	}
   348  }