github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/vm/vmimpl/merger.go (about)

     1  // Copyright 2016 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package vmimpl
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"sync"
    11  )
    12  
    13  type OutputMerger struct {
    14  	Output chan []byte
    15  	Err    chan error
    16  	teeMu  sync.Mutex
    17  	tee    io.Writer
    18  	wg     sync.WaitGroup
    19  }
    20  
    21  type MergerError struct {
    22  	Name string
    23  	R    io.ReadCloser
    24  	Err  error
    25  }
    26  
    27  func (err MergerError) Error() string {
    28  	return fmt.Sprintf("failed to read from %v: %v", err.Name, err.Err)
    29  }
    30  
    31  func NewOutputMerger(tee io.Writer) *OutputMerger {
    32  	return &OutputMerger{
    33  		Output: make(chan []byte, 1000),
    34  		Err:    make(chan error, 1),
    35  		tee:    tee,
    36  	}
    37  }
    38  
    39  func (merger *OutputMerger) Wait() {
    40  	merger.wg.Wait()
    41  	close(merger.Output)
    42  }
    43  
    44  func (merger *OutputMerger) Add(name string, r io.ReadCloser) {
    45  	merger.AddDecoder(name, r, nil)
    46  }
    47  
    48  func (merger *OutputMerger) AddDecoder(name string, r io.ReadCloser,
    49  	decoder func(data []byte) (start, size int, decoded []byte)) {
    50  	merger.wg.Add(1)
    51  	go func() {
    52  		var pending []byte
    53  		var proto []byte
    54  		var buf [4 << 10]byte
    55  		for {
    56  			n, err := r.Read(buf[:])
    57  			if n != 0 {
    58  				if decoder != nil {
    59  					proto = append(proto, buf[:n]...)
    60  					start, size, decoded := decoder(proto)
    61  					proto = proto[start+size:]
    62  					if len(decoded) != 0 {
    63  						merger.Output <- decoded // note: this can block
    64  					}
    65  				}
    66  				// Remove all carriage returns.
    67  				buf := buf[:n]
    68  				if bytes.IndexByte(buf, '\r') != -1 {
    69  					buf = bytes.ReplaceAll(buf, []byte("\r"), nil)
    70  				}
    71  				pending = append(pending, buf...)
    72  				if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 {
    73  					out := pending[:pos+1]
    74  					if merger.tee != nil {
    75  						merger.teeMu.Lock()
    76  						merger.tee.Write(out)
    77  						merger.teeMu.Unlock()
    78  					}
    79  					select {
    80  					case merger.Output <- append([]byte{}, out...):
    81  						r := copy(pending, pending[pos+1:])
    82  						pending = pending[:r]
    83  					default:
    84  					}
    85  				}
    86  			}
    87  			if err != nil {
    88  				if len(pending) != 0 {
    89  					pending = append(pending, '\n')
    90  					if merger.tee != nil {
    91  						merger.teeMu.Lock()
    92  						merger.tee.Write(pending)
    93  						merger.teeMu.Unlock()
    94  					}
    95  					select {
    96  					case merger.Output <- pending:
    97  					default:
    98  					}
    99  				}
   100  				r.Close()
   101  				select {
   102  				case merger.Err <- MergerError{name, r, err}:
   103  				default:
   104  				}
   105  				merger.wg.Done()
   106  				return
   107  			}
   108  		}
   109  	}()
   110  }