github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/session.go (about)

     1  package liveshare
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/opentracing/opentracing-go"
    12  	"golang.org/x/crypto/ssh"
    13  	"golang.org/x/sync/errgroup"
    14  )
    15  
    16  // A ChannelID is an identifier for an exposed port on a remote
    17  // container that may be used to open an SSH channel to it.
    18  type ChannelID struct {
    19  	name, condition string
    20  }
    21  
    22  // A Session represents the session between a connected Live Share client and server.
    23  type Session struct {
    24  	ssh *sshSession
    25  	rpc *rpcClient
    26  
    27  	clientName      string
    28  	keepAliveReason chan string
    29  	logger          logger
    30  }
    31  
    32  type StartSSHServerOptions struct {
    33  	UserPublicKeyFile string
    34  }
    35  
    36  // Close should be called by users to clean up RPC and SSH resources whenever the session
    37  // is no longer active.
    38  func (s *Session) Close() error {
    39  	// Closing the RPC conn closes the underlying stream (SSH)
    40  	// So we only need to close once
    41  	if err := s.rpc.Close(); err != nil {
    42  		s.ssh.Close() // close SSH and ignore error
    43  		return fmt.Errorf("error while closing Live Share session: %w", err)
    44  	}
    45  
    46  	return nil
    47  }
    48  
    49  // registerRequestHandler registers a handler for the given request type with the RPC
    50  // server and returns a callback function to deregister the handler
    51  func (s *Session) registerRequestHandler(requestType string, h handler) func() {
    52  	return s.rpc.register(requestType, h)
    53  }
    54  
    55  // StartSSHServer starts an SSH server in the container, installing sshd if necessary, applies specified
    56  // options, and returns the port on which it listens and the user name clients should provide.
    57  func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
    58  	return s.StartSSHServerWithOptions(ctx, StartSSHServerOptions{})
    59  }
    60  
    61  // StartSSHServerWithOptions starts an SSH server in the container, installing sshd if necessary, applies specified
    62  // options, and returns the port on which it listens and the user name clients should provide.
    63  func (s *Session) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) {
    64  	var params struct {
    65  		UserPublicKey string `json:"userPublicKey"`
    66  	}
    67  
    68  	var response struct {
    69  		Result     bool   `json:"result"`
    70  		ServerPort string `json:"serverPort"`
    71  		User       string `json:"user"`
    72  		Message    string `json:"message"`
    73  	}
    74  
    75  	if options.UserPublicKeyFile != "" {
    76  		publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile)
    77  		if err != nil {
    78  			return 0, "", fmt.Errorf("failed to read public key file: %w", err)
    79  		}
    80  
    81  		params.UserPublicKey = strings.TrimSpace(string(publicKeyBytes))
    82  	}
    83  
    84  	if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", params, &response); err != nil {
    85  		return 0, "", err
    86  	}
    87  
    88  	if !response.Result {
    89  		return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
    90  	}
    91  
    92  	port, err := strconv.Atoi(response.ServerPort)
    93  	if err != nil {
    94  		return 0, "", fmt.Errorf("failed to parse port: %w", err)
    95  	}
    96  
    97  	return port, response.User, nil
    98  }
    99  
   100  // StartJupyterServer starts a Juypyter server in the container and returns
   101  // the port on which it listens and the server URL.
   102  func (s *Session) StartJupyterServer(ctx context.Context) (int, string, error) {
   103  	var response struct {
   104  		Result    bool   `json:"result"`
   105  		Message   string `json:"message"`
   106  		Port      string `json:"port"`
   107  		ServerUrl string `json:"serverUrl"`
   108  	}
   109  
   110  	if err := s.rpc.do(ctx, "IJupyterServerHostService.getRunningServer", []string{}, &response); err != nil {
   111  		return 0, "", fmt.Errorf("failed to invoke JupyterLab RPC: %w", err)
   112  	}
   113  
   114  	if !response.Result {
   115  		return 0, "", fmt.Errorf("failed to start JupyterLab: %s", response.Message)
   116  	}
   117  
   118  	port, err := strconv.Atoi(response.Port)
   119  	if err != nil {
   120  		return 0, "", fmt.Errorf("failed to parse JupyterLab port: %w", err)
   121  	}
   122  
   123  	return port, response.ServerUrl, nil
   124  }
   125  
   126  func (s *Session) RebuildContainer(ctx context.Context, full bool) error {
   127  	rpcMethod := "IEnvironmentConfigurationService.incrementalRebuildContainer"
   128  	if full {
   129  		rpcMethod = "IEnvironmentConfigurationService.rebuildContainer"
   130  	}
   131  
   132  	var rebuildSuccess bool
   133  	err := s.rpc.do(ctx, rpcMethod, nil, &rebuildSuccess)
   134  	if err != nil {
   135  		return fmt.Errorf("invoking rebuild RPC: %w", err)
   136  	}
   137  
   138  	if !rebuildSuccess {
   139  		return fmt.Errorf("couldn't rebuild codespace")
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  // heartbeat runs until context cancellation, periodically checking whether there is a
   146  // reason to keep the connection alive, and if so, notifying the Live Share host to do so.
   147  // Heartbeat ensures it does not send more than one request every "interval" to ratelimit
   148  // how many KeepAlives we send at a time.
   149  func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
   150  	ticker := time.NewTicker(interval)
   151  	defer ticker.Stop()
   152  
   153  	for {
   154  		select {
   155  		case <-ctx.Done():
   156  			return
   157  		case <-ticker.C:
   158  			s.logger.Println("Heartbeat tick")
   159  			reason := <-s.keepAliveReason
   160  			s.logger.Println("Keep alive reason: " + reason)
   161  			if err := s.notifyHostOfActivity(ctx, reason); err != nil {
   162  				s.logger.Printf("Failed to notify host of activity: %s\n", err)
   163  			}
   164  		}
   165  	}
   166  }
   167  
   168  // notifyHostOfActivity notifies the Live Share host of client activity.
   169  func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error {
   170  	activities := []string{activity}
   171  	params := []interface{}{s.clientName, activities}
   172  	return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil)
   173  }
   174  
   175  // KeepAlive accepts a reason that is retained if there is no active reason
   176  // to send to the server.
   177  func (s *Session) KeepAlive(reason string) {
   178  	select {
   179  	case s.keepAliveReason <- reason:
   180  	default:
   181  		// there is already an active keep alive reason
   182  		// so we can ignore this one
   183  	}
   184  }
   185  
   186  // StartSharing tells the Live Share host to start sharing the specified port from the container.
   187  // The sessionName describes the purpose of the remote port or service.
   188  // It returns an identifier that can be used to open an SSH channel to the remote port.
   189  func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
   190  	args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
   191  	g, ctx := errgroup.WithContext(ctx)
   192  
   193  	g.Go(func() error {
   194  		startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
   195  		if err != nil {
   196  			return fmt.Errorf("error while waiting for port notification: %w", err)
   197  
   198  		}
   199  		if !startNotification.Success {
   200  			return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
   201  		}
   202  		return nil // success
   203  	})
   204  
   205  	var response Port
   206  	g.Go(func() error {
   207  		return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
   208  	})
   209  
   210  	if err := g.Wait(); err != nil {
   211  		return ChannelID{}, err
   212  	}
   213  
   214  	return ChannelID{response.StreamName, response.StreamCondition}, nil
   215  }
   216  
   217  func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
   218  	type getStreamArgs struct {
   219  		StreamName string `json:"streamName"`
   220  		Condition  string `json:"condition"`
   221  	}
   222  	args := getStreamArgs{
   223  		StreamName: id.name,
   224  		Condition:  id.condition,
   225  	}
   226  	var streamID string
   227  	if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
   228  		return nil, fmt.Errorf("error getting stream id: %w", err)
   229  	}
   230  
   231  	span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
   232  	defer span.Finish()
   233  	_ = ctx // ctx is not currently used
   234  
   235  	channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
   236  	if err != nil {
   237  		return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
   238  	}
   239  	go ssh.DiscardRequests(reqs)
   240  
   241  	requestType := fmt.Sprintf("stream-transport-%s", streamID)
   242  	if _, err = channel.SendRequest(requestType, true, nil); err != nil {
   243  		return nil, fmt.Errorf("error sending channel request: %w", err)
   244  	}
   245  
   246  	return channel, nil
   247  }