github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/flatrpc/conn.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package flatrpc
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"reflect"
    13  	"slices"
    14  	"sync"
    15  	"unsafe"
    16  
    17  	flatbuffers "github.com/google/flatbuffers/go"
    18  	"github.com/google/syzkaller/pkg/log"
    19  	"github.com/google/syzkaller/pkg/stat"
    20  	"golang.org/x/sync/errgroup"
    21  )
    22  
    23  var (
    24  	statSent = stat.New("rpc sent", "Outbound RPC traffic",
    25  		stat.Graph("traffic"), stat.Rate{}, stat.FormatMB)
    26  	statRecv = stat.New("rpc recv", "Inbound RPC traffic",
    27  		stat.Graph("traffic"), stat.Rate{}, stat.FormatMB)
    28  )
    29  
// Serv is a TCP RPC server returned by Listen; Serve runs its accept loop.
type Serv struct {
	// Addr is the address the listener is actually bound to (taken from the
	// listener, so it reflects the real port even if addr asked for port 0).
	Addr *net.TCPAddr
	ln   net.Listener
}
    34  
    35  func Listen(addr string) (*Serv, error) {
    36  	ln, err := net.Listen("tcp", addr)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	return &Serv{
    41  		Addr: ln.Addr().(*net.TCPAddr),
    42  		ln:   ln,
    43  	}, nil
    44  }
    45  
    46  // Serve accepts incoming connections and calls handler for each of them.
    47  // An error returned from the handler stops the server and aborts the whole processing.
    48  func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
    49  	eg, ctx := errgroup.WithContext(baseCtx)
    50  	go func() {
    51  		// If the context is cancelled, stop the server.
    52  		<-ctx.Done()
    53  		s.Close()
    54  	}()
    55  	for {
    56  		conn, err := s.ln.Accept()
    57  		if err != nil && errors.Is(err, net.ErrClosed) {
    58  			break
    59  		}
    60  		if err != nil {
    61  			var netErr *net.OpError
    62  			if errors.As(err, &netErr) && !netErr.Temporary() {
    63  				return fmt.Errorf("flatrpc: failed to accept: %w", err)
    64  			}
    65  			log.Logf(0, "flatrpc: failed to accept: %v", err)
    66  			continue
    67  		}
    68  		eg.Go(func() error {
    69  			connCtx, cancel := context.WithCancel(ctx)
    70  			defer cancel()
    71  
    72  			c := NewConn(conn)
    73  			// Closing the server does not automatically close all the connections.
    74  			go func() {
    75  				<-connCtx.Done()
    76  				c.Close()
    77  			}()
    78  			return handler(connCtx, c)
    79  		})
    80  	}
    81  	return eg.Wait()
    82  }
    83  
    84  func (s *Serv) Close() error {
    85  	return s.ln.Close()
    86  }
    87  
// Conn is a size-prefixed flatbuffers RPC connection on top of a net.Conn.
// Send may be called from multiple goroutines concurrently; Recv should be
// called from a single goroutine.
type Conn struct {
	conn net.Conn

	// sendMu serializes Send calls, which all share builder.
	sendMu  sync.Mutex
	builder *flatbuffers.Builder

	// Receive state, used only by Recv/recv:
	// data is the receive buffer; data[:hasData] holds the bytes read from conn so far.
	// data[:lastMsg] is the message (size prefix + body) most recently returned by
	// Recv; it is discarded (and any following bytes moved to the front) at the
	// start of the next Recv call.
	data    []byte
	hasData int
	lastMsg int
}
    98  
    99  func NewConn(conn net.Conn) *Conn {
   100  	return &Conn{
   101  		conn:    conn,
   102  		builder: flatbuffers.NewBuilder(0),
   103  	}
   104  }
   105  
   106  func (c *Conn) Close() error {
   107  	return c.conn.Close()
   108  }
   109  
// sendMsg is the constraint for messages accepted by Send: a type that can
// serialize itself into a flatbuffers builder and return the offset that Send
// then finishes as the message root.
type sendMsg interface {
	Pack(*flatbuffers.Builder) flatbuffers.UOffsetT
}
   113  
   114  // Send sends an RPC message.
   115  // The type T is supposed to be an "object API" type ending with T (e.g. ConnectRequestT).
   116  // Sending can be done from multiple goroutines concurrently.
   117  func Send[T sendMsg](c *Conn, msg T) error {
   118  	c.sendMu.Lock()
   119  	defer c.sendMu.Unlock()
   120  	off := msg.Pack(c.builder)
   121  	c.builder.FinishSizePrefixed(off)
   122  	data := c.builder.FinishedBytes()
   123  	_, err := c.conn.Write(data)
   124  	c.builder.Reset()
   125  	statSent.Add(len(data))
   126  	if err != nil {
   127  		return fmt.Errorf("failed to send %T: %w", msg, err)
   128  	}
   129  	return nil
   130  }
   131  
// RecvType is the constraint for the Raw type parameter of Recv and Parse:
// a flatbuffers table type (in practice a pointer, e.g. *ConnectRequestRaw)
// that can be initialized over a byte buffer and unpacked into the
// "object API" type T.
type RecvType[T any] interface {
	UnPack() *T
	flatbuffers.FlatBuffer
}
   136  
   137  // Recv receives an RPC message.
   138  // The type T is supposed to be a pointer to a normal flatbuffers type (not ending with T, e.g. *ConnectRequestRaw).
   139  // Receiving should be done from a single goroutine, the received message is valid
   140  // only until the next Recv call (messages share the same underlying receive buffer).
   141  func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) {
   142  	// First, discard the previous message.
   143  	// For simplicity we copy any data from the next message to the beginning of the buffer.
   144  	// Theoretically we could something more efficient, e.g. don't copy if we already
   145  	// have a full next message.
   146  	if c.hasData > c.lastMsg {
   147  		copy(c.data, c.data[c.lastMsg:c.hasData])
   148  	}
   149  	c.hasData -= c.lastMsg
   150  	c.lastMsg = 0
   151  	const (
   152  		sizePrefixSize = flatbuffers.SizeUint32
   153  		maxMessageSize = 64 << 20
   154  	)
   155  	// Then, receive at least the size prefix (4 bytes).
   156  	// And then the full message, if we have not got it yet.
   157  	if err := c.recv(sizePrefixSize); err != nil {
   158  		return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err)
   159  	}
   160  	size := int(flatbuffers.GetSizePrefix(c.data, 0))
   161  	if size > maxMessageSize {
   162  		return nil, fmt.Errorf("message %T has too large size %v", (*T)(nil), size)
   163  	}
   164  	c.lastMsg = sizePrefixSize + size
   165  	if err := c.recv(c.lastMsg); err != nil {
   166  		return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err)
   167  	}
   168  	return Parse[Raw](c.data[sizePrefixSize:c.lastMsg])
   169  }
   170  
   171  // recv ensures that we have at least 'size' bytes received in c.data.
   172  func (c *Conn) recv(size int) error {
   173  	need := size - c.hasData
   174  	if need <= 0 {
   175  		return nil
   176  	}
   177  	if grow := size - len(c.data) + c.hasData; grow > 0 {
   178  		c.data = slices.Grow(c.data, grow)[:len(c.data)+grow]
   179  	}
   180  	n, err := io.ReadAtLeast(c.conn, c.data[c.hasData:], need)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	c.hasData += n
   185  	return nil
   186  }
   187  
   188  func Parse[Raw RecvType[T], T any](data []byte) (res *T, err0 error) {
   189  	defer func() {
   190  		if err := recover(); err != nil {
   191  			err0 = fmt.Errorf("%v", err)
   192  		}
   193  	}()
   194  	statRecv.Add(len(data))
   195  	// This probably can be expressed w/o reflect as "new U" where U is *T,
   196  	// but I failed to express that as generic constraints.
   197  	var msg Raw
   198  	msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(Raw)
   199  	msg.Init(data, flatbuffers.GetUOffsetT(data))
   200  	if err := verify(msg, len(data)); err != nil {
   201  		return nil, err
   202  	}
   203  	return msg.UnPack(), nil
   204  }
   205  
   206  func verify(raw any, rawSize int) error {
   207  	switch msg := raw.(type) {
   208  	case *ExecutorMessageRaw:
   209  		return verifyExecutorMessage(msg, rawSize)
   210  	}
   211  	return nil
   212  }
   213  
   214  func verifyExecutorMessage(raw *ExecutorMessageRaw, rawSize int) error {
   215  	// We receive the message into raw (non object API) type and carefully verify
   216  	// because the message from the test machine can be corrupted in all possible ways.
   217  	// Recovering from panics handles most corruptions (since flatbuffers does not use unsafe
   218  	// and panics on any OOB references). But it's still possible that UnPack may try to allocate
   219  	// unbounded amount of memory and crash with OOM. To prevent that we check that arrays have
   220  	// reasonable size. We don't need to check []byte/string b/c for them flatbuffers use
   221  	// Table.ByteVector which directly references the underlying byte slice and also panics
   222  	// if size is OOB. But we need to check all other arrays b/c for them flatbuffers will
   223  	// first do make([]T, size), filling that array later will panic, but it's already too late
   224  	// since the make will kill the process with OOM.
   225  	switch typ := raw.MsgType(); typ {
   226  	case ExecutorMessagesRawExecResult,
   227  		ExecutorMessagesRawExecuting,
   228  		ExecutorMessagesRawState:
   229  	default:
   230  		return fmt.Errorf("bad executor message type %v", typ)
   231  	}
   232  	var tab flatbuffers.Table
   233  	if !raw.Msg(&tab) {
   234  		return errors.New("received no message")
   235  	}
   236  	// Only ExecResult has arrays.
   237  	if raw.MsgType() == ExecutorMessagesRawExecResult {
   238  		var res ExecResultRaw
   239  		res.Init(tab.Bytes, tab.Pos)
   240  		return verifyExecResult(&res, rawSize)
   241  	}
   242  	return nil
   243  }
   244  
   245  func verifyExecResult(res *ExecResultRaw, rawSize int) error {
   246  	info := res.Info(nil)
   247  	if info == nil {
   248  		return nil
   249  	}
   250  	var tmp ComparisonRaw
   251  	// It's hard to impose good limit on each individual signal/cover/comps array,
   252  	// so instead we count total memory size for all calls and check that it's not
   253  	// larger than the total message size.
   254  	callSize := func(call *CallInfoRaw) int {
   255  		// Cap array size at 1G to prevent overflows during multiplication by size and addition.
   256  		const maxSize = 1 << 30
   257  		size := 0
   258  		if call.SignalLength() != 0 {
   259  			size += min(maxSize, call.SignalLength()) * int(unsafe.Sizeof(call.Signal(0)))
   260  		}
   261  		if call.CoverLength() != 0 {
   262  			size += min(maxSize, call.CoverLength()) * int(unsafe.Sizeof(call.Cover(0)))
   263  		}
   264  		if call.CompsLength() != 0 {
   265  			size += min(maxSize, call.CompsLength()) * int(unsafe.Sizeof(call.Comps(&tmp, 0)))
   266  		}
   267  		return size
   268  	}
   269  	size := 0
   270  	var call CallInfoRaw
   271  	for i := 0; i < info.CallsLength(); i++ {
   272  		if info.Calls(&call, i) {
   273  			size += callSize(&call)
   274  		}
   275  	}
   276  	for i := 0; i < info.ExtraRawLength(); i++ {
   277  		if info.ExtraRaw(&call, i) {
   278  			size += callSize(&call)
   279  		}
   280  	}
   281  	if info.Extra(&call) != nil {
   282  		size += callSize(&call)
   283  	}
   284  	if size > rawSize {
   285  		return fmt.Errorf("corrupted message: total size %v, size of elements %v",
   286  			rawSize, size)
   287  	}
   288  	return nil
   289  }