github.com/supabase/cli@v1.168.1/internal/utils/tea.go (about)

     1  package utils
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"os"
     8  	"strings"
     9  
    10  	"github.com/charmbracelet/bubbles/progress"
    11  	"github.com/charmbracelet/bubbles/spinner"
    12  	tea "github.com/charmbracelet/bubbletea"
    13  	"github.com/charmbracelet/lipgloss"
    14  	"github.com/muesli/reflow/wrap"
    15  	"golang.org/x/term"
    16  )
    17  
    18  func NewProgram(model tea.Model, opts ...tea.ProgramOption) Program {
    19  	var p Program
    20  	if term.IsTerminal(int(os.Stdin.Fd())) {
    21  		p = tea.NewProgram(model, opts...)
    22  	} else {
    23  		p = newFakeProgram(model)
    24  	}
    25  	return p
    26  }
    27  
    28  // An interface describing the parts of BubbleTea's Program that we actually use.
    29  type Program interface {
    30  	Start() error
    31  	Send(msg tea.Msg)
    32  	Quit()
    33  }
    34  
    35  func newFakeProgram(model tea.Model) *fakeProgram {
    36  	p := &fakeProgram{
    37  		model: model,
    38  	}
    39  	return p
    40  }
    41  
    42  // A dumb text implementation of BubbleTea's Program that allows
    43  // for output to be piped to another program.
    44  type fakeProgram struct {
    45  	model tea.Model
    46  }
    47  
    48  func (p *fakeProgram) Start() error {
    49  	initCmd := p.model.Init()
    50  	message := initCmd()
    51  	if message != nil {
    52  		p.model.Update(message)
    53  	}
    54  	return nil
    55  }
    56  
    57  func (p *fakeProgram) Send(msg tea.Msg) {
    58  	switch msg := msg.(type) {
    59  	case StatusMsg:
    60  		fmt.Println(msg)
    61  	case PsqlMsg:
    62  		if msg != nil {
    63  			fmt.Println(*msg)
    64  		}
    65  	}
    66  
    67  	_, cmd := p.model.Update(msg)
    68  	if cmd != nil {
    69  		cmd()
    70  	}
    71  }
    72  
    73  func (p *fakeProgram) Quit() {
    74  	p.Send(tea.Quit())
    75  }
    76  
    // StatusMsg replaces the status line shown after the spinner.
    type StatusMsg string

    // ProgressMsg sets the value passed to the progress bar's SetPercent.
    // A nil ProgressMsg removes the bar.
    type ProgressMsg *float64

    // PsqlMsg appends one line to the displayed psql output.
    // A nil PsqlMsg clears the output.
    type PsqlMsg *string
    82  
    83  type StatusWriter struct {
    84  	Program
    85  }
    86  
    87  func (t StatusWriter) Write(p []byte) (int, error) {
    88  	trimmed := bytes.TrimRight(p, "\n")
    89  	t.Send(StatusMsg(trimmed))
    90  	return len(p), nil
    91  }
    92  
    93  func RunProgram(ctx context.Context, f func(p Program, ctx context.Context) error) error {
    94  	ctx, cancel := context.WithCancel(ctx)
    95  	p := NewProgram(logModel{
    96  		cancel: cancel,
    97  		spinner: spinner.New(
    98  			spinner.WithSpinner(spinner.Dot),
    99  			spinner.WithStyle(lipgloss.NewStyle().Foreground(lipgloss.Color("205"))),
   100  		),
   101  	})
   102  
   103  	errCh := make(chan error, 1)
   104  	go func() {
   105  		errCh <- f(p, ctx)
   106  		p.Quit()
   107  	}()
   108  
   109  	if err := p.Start(); err != nil {
   110  		return err
   111  	}
   112  	return <-errCh
   113  }
   114  
   115  type logModel struct {
   116  	cancel context.CancelFunc
   117  
   118  	spinner     spinner.Model
   119  	status      string
   120  	progress    *progress.Model
   121  	psqlOutputs []string
   122  
   123  	width int
   124  }
   125  
   126  func (m logModel) Init() tea.Cmd {
   127  	return m.spinner.Tick
   128  }
   129  
   130  func (m logModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
   131  	switch msg := msg.(type) {
   132  	case tea.KeyMsg:
   133  		switch msg.Type {
   134  		case tea.KeyCtrlC:
   135  			if m.cancel != nil {
   136  				m.cancel()
   137  			}
   138  			return m, tea.Quit
   139  		default:
   140  			return m, nil
   141  		}
   142  	case tea.WindowSizeMsg:
   143  		m.width = msg.Width
   144  		return m, nil
   145  	case spinner.TickMsg:
   146  		spinnerModel, cmd := m.spinner.Update(msg)
   147  		m.spinner = spinnerModel
   148  		return m, cmd
   149  	case progress.FrameMsg:
   150  		if m.progress == nil {
   151  			return m, nil
   152  		}
   153  
   154  		tmp, cmd := m.progress.Update(msg)
   155  		progressModel := tmp.(progress.Model)
   156  		m.progress = &progressModel
   157  		return m, cmd
   158  	case StatusMsg:
   159  		m.status = string(msg)
   160  		return m, nil
   161  	case ProgressMsg:
   162  		if msg == nil {
   163  			m.progress = nil
   164  			return m, nil
   165  		}
   166  
   167  		if m.progress == nil {
   168  			progressModel := progress.New(progress.WithGradient("#1c1c1c", "#34b27b"))
   169  			m.progress = &progressModel
   170  		}
   171  
   172  		return m, m.progress.SetPercent(*msg)
   173  	case PsqlMsg:
   174  		if msg == nil {
   175  			m.psqlOutputs = []string{}
   176  			return m, nil
   177  		}
   178  
   179  		m.psqlOutputs = append(m.psqlOutputs, *msg)
   180  		if len(m.psqlOutputs) > 5 {
   181  			m.psqlOutputs = m.psqlOutputs[1:]
   182  		}
   183  		return m, nil
   184  	default:
   185  		return m, nil
   186  	}
   187  }
   188  
   189  func (m logModel) View() string {
   190  	var progress string
   191  	if m.progress != nil {
   192  		progress = "\n\n" + m.progress.View()
   193  	}
   194  
   195  	var psqlOutputs string
   196  	if len(m.psqlOutputs) > 0 {
   197  		psqlOutputs = "\n\n" + strings.Join(m.psqlOutputs, "\n")
   198  	}
   199  
   200  	return wrap.String(m.spinner.View()+m.status+progress+psqlOutputs, m.width)
   201  }