github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/session.go (about) 1 package liveshare 2 3 import ( 4 "context" 5 "fmt" 6 "os" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/opentracing/opentracing-go" 12 "golang.org/x/crypto/ssh" 13 "golang.org/x/sync/errgroup" 14 ) 15 16 // A ChannelID is an identifier for an exposed port on a remote 17 // container that may be used to open an SSH channel to it. 18 type ChannelID struct { 19 name, condition string 20 } 21 22 // A Session represents the session between a connected Live Share client and server. 23 type Session struct { 24 ssh *sshSession 25 rpc *rpcClient 26 27 clientName string 28 keepAliveReason chan string 29 logger logger 30 } 31 32 type StartSSHServerOptions struct { 33 UserPublicKeyFile string 34 } 35 36 // Close should be called by users to clean up RPC and SSH resources whenever the session 37 // is no longer active. 38 func (s *Session) Close() error { 39 // Closing the RPC conn closes the underlying stream (SSH) 40 // So we only need to close once 41 if err := s.rpc.Close(); err != nil { 42 s.ssh.Close() // close SSH and ignore error 43 return fmt.Errorf("error while closing Live Share session: %w", err) 44 } 45 46 return nil 47 } 48 49 // registerRequestHandler registers a handler for the given request type with the RPC 50 // server and returns a callback function to deregister the handler 51 func (s *Session) registerRequestHandler(requestType string, h handler) func() { 52 return s.rpc.register(requestType, h) 53 } 54 55 // StartSSHServer starts an SSH server in the container, installing sshd if necessary, applies specified 56 // options, and returns the port on which it listens and the user name clients should provide. 57 func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { 58 return s.StartSSHServerWithOptions(ctx, StartSSHServerOptions{}) 59 } 60 61 // StartSSHServerWithOptions starts an SSH server in the container, installing sshd if necessary, applies specified 62 // options, and returns the port on which it listens and the user name clients should provide. 63 func (s *Session) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) { 64 var params struct { 65 UserPublicKey string `json:"userPublicKey"` 66 } 67 68 var response struct { 69 Result bool `json:"result"` 70 ServerPort string `json:"serverPort"` 71 User string `json:"user"` 72 Message string `json:"message"` 73 } 74 75 if options.UserPublicKeyFile != "" { 76 publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile) 77 if err != nil { 78 return 0, "", fmt.Errorf("failed to read public key file: %w", err) 79 } 80 81 params.UserPublicKey = strings.TrimSpace(string(publicKeyBytes)) 82 } 83 84 if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", params, &response); err != nil { 85 return 0, "", err 86 } 87 88 if !response.Result { 89 return 0, "", fmt.Errorf("failed to start server: %s", response.Message) 90 } 91 92 port, err := strconv.Atoi(response.ServerPort) 93 if err != nil { 94 return 0, "", fmt.Errorf("failed to parse port: %w", err) 95 } 96 97 return port, response.User, nil 98 } 99 100 // StartJupyterServer starts a Juypyter server in the container and returns 101 // the port on which it listens and the server URL. 102 func (s *Session) StartJupyterServer(ctx context.Context) (int, string, error) { 103 var response struct { 104 Result bool `json:"result"` 105 Message string `json:"message"` 106 Port string `json:"port"` 107 ServerUrl string `json:"serverUrl"` 108 } 109 110 if err := s.rpc.do(ctx, "IJupyterServerHostService.getRunningServer", []string{}, &response); err != nil { 111 return 0, "", fmt.Errorf("failed to invoke JupyterLab RPC: %w", err) 112 } 113 114 if !response.Result { 115 return 0, "", fmt.Errorf("failed to start JupyterLab: %s", response.Message) 116 } 117 118 port, err := strconv.Atoi(response.Port) 119 if err != nil { 120 return 0, "", fmt.Errorf("failed to parse JupyterLab port: %w", err) 121 } 122 123 return port, response.ServerUrl, nil 124 } 125 126 func (s *Session) RebuildContainer(ctx context.Context, full bool) error { 127 rpcMethod := "IEnvironmentConfigurationService.incrementalRebuildContainer" 128 if full { 129 rpcMethod = "IEnvironmentConfigurationService.rebuildContainer" 130 } 131 132 var rebuildSuccess bool 133 err := s.rpc.do(ctx, rpcMethod, nil, &rebuildSuccess) 134 if err != nil { 135 return fmt.Errorf("invoking rebuild RPC: %w", err) 136 } 137 138 if !rebuildSuccess { 139 return fmt.Errorf("couldn't rebuild codespace") 140 } 141 142 return nil 143 } 144 145 // heartbeat runs until context cancellation, periodically checking whether there is a 146 // reason to keep the connection alive, and if so, notifying the Live Share host to do so. 147 // Heartbeat ensures it does not send more than one request every "interval" to ratelimit 148 // how many KeepAlives we send at a time. 149 func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { 150 ticker := time.NewTicker(interval) 151 defer ticker.Stop() 152 153 for { 154 select { 155 case <-ctx.Done(): 156 return 157 case <-ticker.C: 158 s.logger.Println("Heartbeat tick") 159 reason := <-s.keepAliveReason 160 s.logger.Println("Keep alive reason: " + reason) 161 if err := s.notifyHostOfActivity(ctx, reason); err != nil { 162 s.logger.Printf("Failed to notify host of activity: %s\n", err) 163 } 164 } 165 } 166 } 167 168 // notifyHostOfActivity notifies the Live Share host of client activity. 169 func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error { 170 activities := []string{activity} 171 params := []interface{}{s.clientName, activities} 172 return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil) 173 } 174 175 // KeepAlive accepts a reason that is retained if there is no active reason 176 // to send to the server. 177 func (s *Session) KeepAlive(reason string) { 178 select { 179 case s.keepAliveReason <- reason: 180 default: 181 // there is already an active keep alive reason 182 // so we can ignore this one 183 } 184 } 185 186 // StartSharing tells the Live Share host to start sharing the specified port from the container. 187 // The sessionName describes the purpose of the remote port or service. 188 // It returns an identifier that can be used to open an SSH channel to the remote port. 189 func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) { 190 args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} 191 g, ctx := errgroup.WithContext(ctx) 192 193 g.Go(func() error { 194 startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart) 195 if err != nil { 196 return fmt.Errorf("error while waiting for port notification: %w", err) 197 198 } 199 if !startNotification.Success { 200 return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail) 201 } 202 return nil // success 203 }) 204 205 var response Port 206 g.Go(func() error { 207 return s.rpc.do(ctx, "serverSharing.startSharing", args, &response) 208 }) 209 210 if err := g.Wait(); err != nil { 211 return ChannelID{}, err 212 } 213 214 return ChannelID{response.StreamName, response.StreamCondition}, nil 215 } 216 217 func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) { 218 type getStreamArgs struct { 219 StreamName string `json:"streamName"` 220 Condition string `json:"condition"` 221 } 222 args := getStreamArgs{ 223 StreamName: id.name, 224 Condition: id.condition, 225 } 226 var streamID string 227 if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { 228 return nil, fmt.Errorf("error getting stream id: %w", err) 229 } 230 231 span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") 232 defer span.Finish() 233 _ = ctx // ctx is not currently used 234 235 channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) 236 if err != nil { 237 return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) 238 } 239 go ssh.DiscardRequests(reqs) 240 241 requestType := fmt.Sprintf("stream-transport-%s", streamID) 242 if _, err = channel.SendRequest(requestType, true, nil); err != nil { 243 return nil, fmt.Errorf("error sending channel request: %w", err) 244 } 245 246 return channel, nil 247 }