github.com/pdmccormick/importable-docker-buildx@v0.0.0-20240426161518-e47091289030/controller/remote/io.go (about)

     1  package remote
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"syscall"
     7  	"time"
     8  
     9  	"github.com/docker/buildx/controller/pb"
    10  	"github.com/moby/sys/signal"
    11  	"github.com/pkg/errors"
    12  	"github.com/sirupsen/logrus"
    13  	"golang.org/x/sync/errgroup"
    14  )
    15  
    16  type msgStream interface {
    17  	Send(*pb.Message) error
    18  	Recv() (*pb.Message, error)
    19  }
    20  
    21  type ioServerConfig struct {
    22  	stdin          io.WriteCloser
    23  	stdout, stderr io.ReadCloser
    24  
    25  	// signalFn is a callback function called when a signal is reached to the client.
    26  	signalFn func(context.Context, syscall.Signal) error
    27  
    28  	// resizeFn is a callback function called when a resize event is reached to the client.
    29  	resizeFn func(context.Context, winSize) error
    30  }
    31  
    32  func serveIO(attachCtx context.Context, srv msgStream, initFn func(*pb.InitMessage) error, ioConfig *ioServerConfig) (err error) {
    33  	stdin, stdout, stderr := ioConfig.stdin, ioConfig.stdout, ioConfig.stderr
    34  	stream := &debugStream{srv, "server=" + time.Now().String()}
    35  	eg, ctx := errgroup.WithContext(attachCtx)
    36  	done := make(chan struct{})
    37  
    38  	msg, err := receive(ctx, stream)
    39  	if err != nil {
    40  		return err
    41  	}
    42  	init := msg.GetInit()
    43  	if init == nil {
    44  		return errors.Errorf("unexpected message: %T; wanted init", msg.GetInput())
    45  	}
    46  	ref := init.Ref
    47  	if ref == "" {
    48  		return errors.New("no ref is provided")
    49  	}
    50  	if err := initFn(init); err != nil {
    51  		return errors.Wrap(err, "failed to initialize IO server")
    52  	}
    53  
    54  	if stdout != nil {
    55  		stdoutReader, stdoutWriter := io.Pipe()
    56  		eg.Go(func() error {
    57  			<-done
    58  			return stdoutWriter.Close()
    59  		})
    60  
    61  		go func() {
    62  			// do not wait for read completion but return here and let the caller send EOF
    63  			// this allows us to return on ctx.Done() without being blocked by this reader.
    64  			io.Copy(stdoutWriter, stdout)
    65  			stdoutWriter.Close()
    66  		}()
    67  
    68  		eg.Go(func() error {
    69  			defer stdoutReader.Close()
    70  			return copyToStream(1, stream, stdoutReader)
    71  		})
    72  	}
    73  
    74  	if stderr != nil {
    75  		stderrReader, stderrWriter := io.Pipe()
    76  		eg.Go(func() error {
    77  			<-done
    78  			return stderrWriter.Close()
    79  		})
    80  
    81  		go func() {
    82  			// do not wait for read completion but return here and let the caller send EOF
    83  			// this allows us to return on ctx.Done() without being blocked by this reader.
    84  			io.Copy(stderrWriter, stderr)
    85  			stderrWriter.Close()
    86  		}()
    87  
    88  		eg.Go(func() error {
    89  			defer stderrReader.Close()
    90  			return copyToStream(2, stream, stderrReader)
    91  		})
    92  	}
    93  
    94  	msgCh := make(chan *pb.Message)
    95  	eg.Go(func() error {
    96  		defer close(msgCh)
    97  		for {
    98  			msg, err := receive(ctx, stream)
    99  			if err != nil {
   100  				return err
   101  			}
   102  			select {
   103  			case msgCh <- msg:
   104  			case <-done:
   105  				return nil
   106  			case <-ctx.Done():
   107  				return nil
   108  			}
   109  		}
   110  	})
   111  
   112  	eg.Go(func() error {
   113  		defer close(done)
   114  		for {
   115  			var msg *pb.Message
   116  			select {
   117  			case msg = <-msgCh:
   118  			case <-ctx.Done():
   119  				return nil
   120  			}
   121  			if msg == nil {
   122  				return nil
   123  			}
   124  			if file := msg.GetFile(); file != nil {
   125  				if file.Fd != 0 {
   126  					return errors.Errorf("unexpected fd: %v", file.Fd)
   127  				}
   128  				if stdin == nil {
   129  					continue // no stdin destination is specified so ignore the data
   130  				}
   131  				if len(file.Data) > 0 {
   132  					_, err := stdin.Write(file.Data)
   133  					if err != nil {
   134  						return err
   135  					}
   136  				}
   137  				if file.EOF {
   138  					stdin.Close()
   139  				}
   140  			} else if resize := msg.GetResize(); resize != nil {
   141  				if ioConfig.resizeFn != nil {
   142  					ioConfig.resizeFn(ctx, winSize{
   143  						cols: resize.Cols,
   144  						rows: resize.Rows,
   145  					})
   146  				}
   147  			} else if sig := msg.GetSignal(); sig != nil {
   148  				if ioConfig.signalFn != nil {
   149  					syscallSignal, ok := signal.SignalMap[sig.Name]
   150  					if !ok {
   151  						continue
   152  					}
   153  					ioConfig.signalFn(ctx, syscallSignal)
   154  				}
   155  			} else {
   156  				return errors.Errorf("unexpected message: %T", msg.GetInput())
   157  			}
   158  		}
   159  	})
   160  
   161  	return eg.Wait()
   162  }
   163  
   164  type ioAttachConfig struct {
   165  	stdin          io.ReadCloser
   166  	stdout, stderr io.WriteCloser
   167  	signal         <-chan syscall.Signal
   168  	resize         <-chan winSize
   169  }
   170  
   171  type winSize struct {
   172  	rows uint32
   173  	cols uint32
   174  }
   175  
   176  func attachIO(ctx context.Context, stream msgStream, initMessage *pb.InitMessage, cfg ioAttachConfig) (retErr error) {
   177  	eg, ctx := errgroup.WithContext(ctx)
   178  	done := make(chan struct{})
   179  
   180  	if err := stream.Send(&pb.Message{
   181  		Input: &pb.Message_Init{
   182  			Init: initMessage,
   183  		},
   184  	}); err != nil {
   185  		return errors.Wrap(err, "failed to init")
   186  	}
   187  
   188  	if cfg.stdin != nil {
   189  		stdinReader, stdinWriter := io.Pipe()
   190  		eg.Go(func() error {
   191  			<-done
   192  			return stdinWriter.Close()
   193  		})
   194  
   195  		go func() {
   196  			// do not wait for read completion but return here and let the caller send EOF
   197  			// this allows us to return on ctx.Done() without being blocked by this reader.
   198  			io.Copy(stdinWriter, cfg.stdin)
   199  			stdinWriter.Close()
   200  		}()
   201  
   202  		eg.Go(func() error {
   203  			defer stdinReader.Close()
   204  			return copyToStream(0, stream, stdinReader)
   205  		})
   206  	}
   207  
   208  	if cfg.signal != nil {
   209  		eg.Go(func() error {
   210  			for {
   211  				var sig syscall.Signal
   212  				select {
   213  				case sig = <-cfg.signal:
   214  				case <-done:
   215  					return nil
   216  				case <-ctx.Done():
   217  					return nil
   218  				}
   219  				name := sigToName[sig]
   220  				if name == "" {
   221  					continue
   222  				}
   223  				if err := stream.Send(&pb.Message{
   224  					Input: &pb.Message_Signal{
   225  						Signal: &pb.SignalMessage{
   226  							Name: name,
   227  						},
   228  					},
   229  				}); err != nil {
   230  					return errors.Wrap(err, "failed to send signal")
   231  				}
   232  			}
   233  		})
   234  	}
   235  
   236  	if cfg.resize != nil {
   237  		eg.Go(func() error {
   238  			for {
   239  				var win winSize
   240  				select {
   241  				case win = <-cfg.resize:
   242  				case <-done:
   243  					return nil
   244  				case <-ctx.Done():
   245  					return nil
   246  				}
   247  				if err := stream.Send(&pb.Message{
   248  					Input: &pb.Message_Resize{
   249  						Resize: &pb.ResizeMessage{
   250  							Rows: win.rows,
   251  							Cols: win.cols,
   252  						},
   253  					},
   254  				}); err != nil {
   255  					return errors.Wrap(err, "failed to send resize")
   256  				}
   257  			}
   258  		})
   259  	}
   260  
   261  	msgCh := make(chan *pb.Message)
   262  	eg.Go(func() error {
   263  		defer close(msgCh)
   264  		for {
   265  			msg, err := receive(ctx, stream)
   266  			if err != nil {
   267  				return err
   268  			}
   269  			select {
   270  			case msgCh <- msg:
   271  			case <-done:
   272  				return nil
   273  			case <-ctx.Done():
   274  				return nil
   275  			}
   276  		}
   277  	})
   278  
   279  	eg.Go(func() error {
   280  		eofs := make(map[uint32]struct{})
   281  		defer close(done)
   282  		for {
   283  			var msg *pb.Message
   284  			select {
   285  			case msg = <-msgCh:
   286  			case <-ctx.Done():
   287  				return nil
   288  			}
   289  			if msg == nil {
   290  				return nil
   291  			}
   292  			if file := msg.GetFile(); file != nil {
   293  				if _, ok := eofs[file.Fd]; ok {
   294  					continue
   295  				}
   296  				var out io.WriteCloser
   297  				switch file.Fd {
   298  				case 1:
   299  					out = cfg.stdout
   300  				case 2:
   301  					out = cfg.stderr
   302  				default:
   303  					return errors.Errorf("unsupported fd %d", file.Fd)
   304  
   305  				}
   306  				if out == nil {
   307  					logrus.Warnf("attachIO: no writer for fd %d", file.Fd)
   308  					continue
   309  				}
   310  				if len(file.Data) > 0 {
   311  					if _, err := out.Write(file.Data); err != nil {
   312  						return err
   313  					}
   314  				}
   315  				if file.EOF {
   316  					eofs[file.Fd] = struct{}{}
   317  				}
   318  			} else {
   319  				return errors.Errorf("unexpected message: %T", msg.GetInput())
   320  			}
   321  		}
   322  	})
   323  
   324  	return eg.Wait()
   325  }
   326  
   327  func receive(ctx context.Context, stream msgStream) (*pb.Message, error) {
   328  	msgCh := make(chan *pb.Message)
   329  	errCh := make(chan error)
   330  	go func() {
   331  		msg, err := stream.Recv()
   332  		if err != nil {
   333  			if errors.Is(err, io.EOF) {
   334  				return
   335  			}
   336  			errCh <- err
   337  			return
   338  		}
   339  		msgCh <- msg
   340  	}()
   341  	select {
   342  	case msg := <-msgCh:
   343  		return msg, nil
   344  	case err := <-errCh:
   345  		return nil, err
   346  	case <-ctx.Done():
   347  		return nil, ctx.Err()
   348  	}
   349  }
   350  
   351  func copyToStream(fd uint32, snd msgStream, r io.Reader) error {
   352  	for {
   353  		buf := make([]byte, 32*1024)
   354  		n, err := r.Read(buf)
   355  		if err != nil {
   356  			if err == io.EOF {
   357  				break // break loop and send EOF
   358  			}
   359  			return err
   360  		} else if n > 0 {
   361  			if err := snd.Send(&pb.Message{
   362  				Input: &pb.Message_File{
   363  					File: &pb.FdMessage{
   364  						Fd:   fd,
   365  						Data: buf[:n],
   366  					},
   367  				},
   368  			}); err != nil {
   369  				return err
   370  			}
   371  		}
   372  	}
   373  	return snd.Send(&pb.Message{
   374  		Input: &pb.Message_File{
   375  			File: &pb.FdMessage{
   376  				Fd:  fd,
   377  				EOF: true,
   378  			},
   379  		},
   380  	})
   381  }
   382  
   383  var sigToName = map[syscall.Signal]string{}
   384  
   385  func init() {
   386  	for name, value := range signal.SignalMap {
   387  		sigToName[value] = name
   388  	}
   389  }
   390  
   391  type debugStream struct {
   392  	msgStream
   393  	prefix string
   394  }
   395  
   396  func (s *debugStream) Send(msg *pb.Message) error {
   397  	switch m := msg.GetInput().(type) {
   398  	case *pb.Message_File:
   399  		if m.File.EOF {
   400  			logrus.Debugf("|---> File Message (sender:%v) fd=%d, EOF", s.prefix, m.File.Fd)
   401  		} else {
   402  			logrus.Debugf("|---> File Message (sender:%v) fd=%d, %d bytes", s.prefix, m.File.Fd, len(m.File.Data))
   403  		}
   404  	case *pb.Message_Resize:
   405  		logrus.Debugf("|---> Resize Message (sender:%v): %+v", s.prefix, m.Resize)
   406  	case *pb.Message_Signal:
   407  		logrus.Debugf("|---> Signal Message (sender:%v): %s", s.prefix, m.Signal.Name)
   408  	}
   409  	return s.msgStream.Send(msg)
   410  }
   411  
   412  func (s *debugStream) Recv() (*pb.Message, error) {
   413  	msg, err := s.msgStream.Recv()
   414  	if err != nil {
   415  		return nil, err
   416  	}
   417  	switch m := msg.GetInput().(type) {
   418  	case *pb.Message_File:
   419  		if m.File.EOF {
   420  			logrus.Debugf("|<--- File Message (receiver:%v) fd=%d, EOF", s.prefix, m.File.Fd)
   421  		} else {
   422  			logrus.Debugf("|<--- File Message (receiver:%v) fd=%d, %d bytes", s.prefix, m.File.Fd, len(m.File.Data))
   423  		}
   424  	case *pb.Message_Resize:
   425  		logrus.Debugf("|<--- Resize Message (receiver:%v): %+v", s.prefix, m.Resize)
   426  	case *pb.Message_Signal:
   427  		logrus.Debugf("|<--- Signal Message (receiver:%v): %s", s.prefix, m.Signal.Name)
   428  	}
   429  	return msg, nil
   430  }