github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/internal/codespaces/states.go (about)

     1  package codespaces
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net"
    11  	"time"
    12  
    13  	"github.com/ungtb10d/cli/v2/internal/codespaces/api"
    14  	"github.com/ungtb10d/cli/v2/internal/text"
    15  	"github.com/ungtb10d/cli/v2/pkg/liveshare"
    16  )
    17  
    18  // PostCreateStateStatus is a string value representing the different statuses a state can have.
    19  type PostCreateStateStatus string
    20  
    21  func (p PostCreateStateStatus) String() string {
    22  	return text.Title(string(p))
    23  }
    24  
    25  const (
    26  	PostCreateStateRunning PostCreateStateStatus = "running"
    27  	PostCreateStateSuccess PostCreateStateStatus = "succeeded"
    28  	PostCreateStateFailed  PostCreateStateStatus = "failed"
    29  )
    30  
// PostCreateState is a combination of a state and status value that is captured
// during codespace creation.
//
// Each value corresponds to one entry of the "steps" array decoded from the
// codespace's postCreateOutput.json file.
type PostCreateState struct {
	Name   string                `json:"name"`   // step name, from the "name" JSON field
	Status PostCreateStateStatus `json:"status"` // step status, from the "status" JSON field
}
    37  
    38  // PollPostCreateStates watches for state changes in a codespace,
    39  // and calls the supplied poller for each batch of state changes.
    40  // It runs until it encounters an error, including cancellation of the context.
    41  func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
    42  	noopLogger := log.New(io.Discard, "", 0)
    43  
    44  	session, err := ConnectToLiveshare(ctx, progress, noopLogger, apiClient, codespace)
    45  	if err != nil {
    46  		return fmt.Errorf("connect to codespace: %w", err)
    47  	}
    48  	defer func() {
    49  		if closeErr := session.Close(); err == nil {
    50  			err = closeErr
    51  		}
    52  	}()
    53  
    54  	// Ensure local port is listening before client (getPostCreateOutput) connects.
    55  	listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
    56  	if err != nil {
    57  		return err
    58  	}
    59  	localPort := listen.Addr().(*net.TCPAddr).Port
    60  
    61  	progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
    62  	defer progress.StopProgressIndicator()
    63  	remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
    64  	if err != nil {
    65  		return fmt.Errorf("error getting ssh server details: %w", err)
    66  	}
    67  
    68  	progress.StartProgressIndicatorWithLabel("Fetching status")
    69  	tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
    70  	go func() {
    71  		fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
    72  		tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
    73  	}()
    74  
    75  	t := time.NewTicker(1 * time.Second)
    76  	defer t.Stop()
    77  
    78  	for ticks := 0; ; ticks++ {
    79  		select {
    80  		case <-ctx.Done():
    81  			return ctx.Err()
    82  
    83  		case err := <-tunnelClosed:
    84  			return fmt.Errorf("connection failed: %w", err)
    85  
    86  		case <-t.C:
    87  			states, err := getPostCreateOutput(ctx, localPort, sshUser)
    88  			// There is an active progress indicator before the first tick
    89  			// to show that we are fetching statuses.
    90  			// Once the first tick happens, we stop the indicator and let
    91  			// the subsequent post create states manage their own progress.
    92  			if ticks == 0 {
    93  				progress.StopProgressIndicator()
    94  			}
    95  			if err != nil {
    96  				return fmt.Errorf("get post create output: %w", err)
    97  			}
    98  
    99  			poller(states)
   100  		}
   101  	}
   102  }
   103  
   104  func getPostCreateOutput(ctx context.Context, tunnelPort int, user string) ([]PostCreateState, error) {
   105  	cmd, err := NewRemoteCommand(
   106  		ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
   107  		"cat /workspaces/.codespaces/shared/postCreateOutput.json",
   108  	)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("remote command: %w", err)
   111  	}
   112  
   113  	stdout := new(bytes.Buffer)
   114  	cmd.Stdout = stdout
   115  	if err := cmd.Run(); err != nil {
   116  		return nil, fmt.Errorf("run command: %w", err)
   117  	}
   118  	var output struct {
   119  		Steps []PostCreateState `json:"steps"`
   120  	}
   121  	if err := json.Unmarshal(stdout.Bytes(), &output); err != nil {
   122  		return nil, fmt.Errorf("unmarshal output: %w", err)
   123  	}
   124  
   125  	return output.Steps, nil
   126  }