github.com/pdmccormick/importable-docker-buildx@v0.0.0-20240426161518-e47091289030/util/ioset/mux.go (about)

     1  package ioset
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/pkg/errors"
    10  	"github.com/sirupsen/logrus"
    11  )
    12  
    13  type MuxOut struct {
    14  	Out
    15  	EnableHook  func()
    16  	DisableHook func()
    17  }
    18  
    19  // NewMuxIO forwards IO stream to/from "in" and "outs".
    20  // It toggles IO when it detects "C-a-c" key.
    21  // "outs" are closed automatically when "in" reaches EOF.
    22  // "in" doesn't closed automatically so the caller needs to explicitly close it.
    23  func NewMuxIO(in In, outs []MuxOut, initIdx int, toggleMessage func(prev int, res int) string) *MuxIO {
    24  	m := &MuxIO{
    25  		enabled:       make(map[int]struct{}),
    26  		in:            in,
    27  		outs:          outs,
    28  		closedCh:      make(chan struct{}),
    29  		toggleMessage: toggleMessage,
    30  	}
    31  	for i := range outs {
    32  		m.enabled[i] = struct{}{}
    33  	}
    34  	m.maxCur = len(outs)
    35  	m.cur = initIdx
    36  	var wg sync.WaitGroup
    37  	var mu sync.Mutex
    38  	for i, o := range outs {
    39  		i, o := i, o
    40  		wg.Add(1)
    41  		go func() {
    42  			defer wg.Done()
    43  			if err := copyToFunc(o.Stdout, func() (io.Writer, error) {
    44  				if m.cur == i {
    45  					return in.Stdout, nil
    46  				}
    47  				return nil, nil
    48  			}); err != nil {
    49  				logrus.WithField("output index", i).WithError(err).Warnf("failed to write stdout")
    50  			}
    51  			if err := o.Stdout.Close(); err != nil {
    52  				logrus.WithField("output index", i).WithError(err).Warnf("failed to close stdout")
    53  			}
    54  		}()
    55  		wg.Add(1)
    56  		go func() {
    57  			defer wg.Done()
    58  			if err := copyToFunc(o.Stderr, func() (io.Writer, error) {
    59  				if m.cur == i {
    60  					return in.Stderr, nil
    61  				}
    62  				return nil, nil
    63  			}); err != nil {
    64  				logrus.WithField("output index", i).WithError(err).Warnf("failed to write stderr")
    65  			}
    66  			if err := o.Stderr.Close(); err != nil {
    67  				logrus.WithField("output index", i).WithError(err).Warnf("failed to close stderr")
    68  			}
    69  		}()
    70  	}
    71  	go func() {
    72  		errToggle := errors.Errorf("toggle IO")
    73  		for {
    74  			prevIsControlSequence := false
    75  			if err := copyToFunc(traceReader(in.Stdin, func(r rune) (bool, error) {
    76  				// Toggle IO when it detects C-a-c
    77  				// TODO: make it configurable if needed
    78  				if int(r) == 1 {
    79  					prevIsControlSequence = true
    80  					return false, nil
    81  				}
    82  				defer func() { prevIsControlSequence = false }()
    83  				if prevIsControlSequence {
    84  					if string(r) == "c" {
    85  						return false, errToggle
    86  					}
    87  				}
    88  				return true, nil
    89  			}), func() (io.Writer, error) {
    90  				mu.Lock()
    91  				o := outs[m.cur]
    92  				mu.Unlock()
    93  				return o.Stdin, nil
    94  			}); !errors.Is(err, errToggle) {
    95  				if err != nil {
    96  					logrus.WithError(err).Warnf("failed to read stdin")
    97  				}
    98  				break
    99  			}
   100  			m.toggleIO()
   101  		}
   102  
   103  		// propagate stdin EOF
   104  		for i, o := range outs {
   105  			if err := o.Stdin.Close(); err != nil {
   106  				logrus.WithError(err).Warnf("failed to close stdin of %d", i)
   107  			}
   108  		}
   109  		wg.Wait()
   110  		close(m.closedCh)
   111  	}()
   112  	return m
   113  }
   114  
   115  type MuxIO struct {
   116  	cur           int
   117  	maxCur        int
   118  	enabled       map[int]struct{}
   119  	mu            sync.Mutex
   120  	in            In
   121  	outs          []MuxOut
   122  	closedCh      chan struct{}
   123  	toggleMessage func(prev int, res int) string
   124  }
   125  
   126  func (m *MuxIO) waitClosed() {
   127  	<-m.closedCh
   128  }
   129  
   130  func (m *MuxIO) Enable(i int) {
   131  	m.mu.Lock()
   132  	defer m.mu.Unlock()
   133  	m.enabled[i] = struct{}{}
   134  }
   135  
   136  func (m *MuxIO) SwitchTo(i int) error {
   137  	m.mu.Lock()
   138  	defer m.mu.Unlock()
   139  	if m.cur == i {
   140  		return nil
   141  	}
   142  	if _, ok := m.enabled[i]; !ok {
   143  		return errors.Errorf("IO index %d isn't active", i)
   144  	}
   145  	if m.outs[m.cur].DisableHook != nil {
   146  		m.outs[m.cur].DisableHook()
   147  	}
   148  	prev := m.cur
   149  	m.cur = i
   150  	if m.outs[m.cur].EnableHook != nil {
   151  		m.outs[m.cur].EnableHook()
   152  	}
   153  	fmt.Fprint(m.in.Stdout, m.toggleMessage(prev, i))
   154  	return nil
   155  }
   156  
   157  func (m *MuxIO) Disable(i int) error {
   158  	m.mu.Lock()
   159  	defer m.mu.Unlock()
   160  	if i == 0 {
   161  		return errors.Errorf("disabling 0th io is prohibited")
   162  	}
   163  	delete(m.enabled, i)
   164  	if m.cur == i {
   165  		m.toggleIO()
   166  	}
   167  	return nil
   168  }
   169  
   170  func (m *MuxIO) toggleIO() {
   171  	if m.outs[m.cur].DisableHook != nil {
   172  		m.outs[m.cur].DisableHook()
   173  	}
   174  	prev := m.cur
   175  	for {
   176  		if m.cur+1 >= m.maxCur {
   177  			m.cur = 0
   178  		} else {
   179  			m.cur++
   180  		}
   181  		if _, ok := m.enabled[m.cur]; !ok {
   182  			continue
   183  		}
   184  		break
   185  	}
   186  	res := m.cur
   187  	if m.outs[m.cur].EnableHook != nil {
   188  		m.outs[m.cur].EnableHook()
   189  	}
   190  	fmt.Fprint(m.in.Stdout, m.toggleMessage(prev, res))
   191  }
   192  
   193  func traceReader(r io.ReadCloser, f func(rune) (bool, error)) io.ReadCloser {
   194  	pr, pw := io.Pipe()
   195  	go func() {
   196  		br := bufio.NewReader(r)
   197  		for {
   198  			rn, _, err := br.ReadRune()
   199  			if err != nil {
   200  				if err == io.EOF {
   201  					pw.Close()
   202  					return
   203  				}
   204  				pw.CloseWithError(err)
   205  				return
   206  			}
   207  			if isWrite, err := f(rn); err != nil {
   208  				pw.CloseWithError(err)
   209  				return
   210  			} else if !isWrite {
   211  				continue
   212  			}
   213  			if _, err := pw.Write([]byte(string(rn))); err != nil {
   214  				pw.CloseWithError(err)
   215  				return
   216  			}
   217  		}
   218  	}()
   219  	return &readerWithClose{
   220  		Reader: pr,
   221  		closeFunc: func() error {
   222  			pr.Close()
   223  			return r.Close()
   224  		},
   225  	}
   226  }
   227  
   228  func copyToFunc(r io.Reader, wFunc func() (io.Writer, error)) error {
   229  	buf := make([]byte, 4096)
   230  	for {
   231  		n, readErr := r.Read(buf)
   232  		if readErr != nil && readErr != io.EOF {
   233  			return readErr
   234  		}
   235  		w, err := wFunc()
   236  		if err != nil {
   237  			return err
   238  		}
   239  		if w != nil {
   240  			if _, err := w.Write(buf[:n]); err != nil {
   241  				logrus.WithError(err).Debugf("failed to copy")
   242  			}
   243  		}
   244  		if readErr == io.EOF {
   245  			return nil
   246  		}
   247  	}
   248  }
   249  
   250  type readerWithClose struct {
   251  	io.Reader
   252  	closeFunc func() error
   253  }
   254  
   255  func (r *readerWithClose) Close() error {
   256  	return r.closeFunc()
   257  }