github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/flatrpc/conn.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package flatrpc
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"os"
    12  	"slices"
    13  	"sync"
    14  	"time"
    15  
    16  	flatbuffers "github.com/google/flatbuffers/go"
    17  	"github.com/google/syzkaller/pkg/log"
    18  	"github.com/google/syzkaller/pkg/stats"
    19  )
    20  
    21  var (
    22  	statSent = stats.Create("rpc sent", "Outbound RPC traffic",
    23  		stats.Graph("traffic"), stats.Rate{}, stats.FormatMB)
    24  	statRecv = stats.Create("rpc recv", "Inbound RPC traffic",
    25  		stats.Graph("traffic"), stats.Rate{}, stats.FormatMB)
    26  )
    27  
// Serv is a flatrpc server started by ListenAndServe; Close stops accepting
// new connections.
type Serv struct {
	Addr *net.TCPAddr // address the listener is bound to (e.g. the chosen port when addr asked for port 0)
	ln   net.Listener
}
    32  
    33  func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
    34  	ln, err := net.Listen("tcp", addr)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	go func() {
    39  		for {
    40  			conn, err := ln.Accept()
    41  			if err != nil {
    42  				if errors.Is(err, net.ErrClosed) {
    43  					break
    44  				}
    45  				var netErr *net.OpError
    46  				if errors.As(err, &netErr) && !netErr.Temporary() {
    47  					log.Fatalf("flatrpc: failed to accept: %v", err)
    48  				}
    49  				log.Logf(0, "flatrpc: failed to accept: %v", err)
    50  				continue
    51  			}
    52  			go func() {
    53  				c := newConn(conn)
    54  				defer c.Close()
    55  				handler(c)
    56  			}()
    57  		}
    58  	}()
    59  	return &Serv{
    60  		Addr: ln.Addr().(*net.TCPAddr),
    61  		ln:   ln,
    62  	}, nil
    63  }
    64  
    65  func (s *Serv) Close() error {
    66  	return s.ln.Close()
    67  }
    68  
// Conn is a single flatrpc connection: outgoing messages are written with Send
// and incoming ones are read with Recv. It is created by Dial or, on the
// server side, by ListenAndServe.
type Conn struct {
	conn net.Conn

	// Send state: sendMu serializes Send calls, which all reuse builder to
	// serialize outgoing messages.
	sendMu  sync.Mutex
	builder *flatbuffers.Builder

	// Receive state, used only by the single goroutine calling Recv.
	// data[:hasData] holds bytes already read from conn; the first lastMsg of
	// them are the message last returned by Recv and are discarded on the next call.
	data    []byte
	hasData int
	lastMsg int
}
    79  
    80  func Dial(addr string, timeScale time.Duration) (*Conn, error) {
    81  	var conn net.Conn
    82  	var err error
    83  	if addr == "stdin" {
    84  		// This is used by vm/gvisor which passes us a unix socket connection in stdin.
    85  		conn, err = net.FileConn(os.Stdin)
    86  	} else {
    87  		conn, err = net.DialTimeout("tcp", addr, time.Minute*timeScale)
    88  	}
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	return newConn(conn), nil
    93  }
    94  
    95  func newConn(conn net.Conn) *Conn {
    96  	return &Conn{
    97  		conn:    conn,
    98  		builder: flatbuffers.NewBuilder(0),
    99  	}
   100  }
   101  
   102  func (c *Conn) Close() error {
   103  	return c.conn.Close()
   104  }
   105  
// sendMsg is the constraint for messages accepted by Send. Such a message packs
// itself into the given builder and returns the offset that Send then passes
// to FinishSizePrefixed.
type sendMsg interface {
	Pack(*flatbuffers.Builder) flatbuffers.UOffsetT
}
   109  
   110  // Send sends an RPC message.
   111  // The type T is supposed to be an "object API" type ending with T (e.g. ConnectRequestT).
   112  // Sending can be done from multiple goroutines concurrently.
   113  func Send[T sendMsg](c *Conn, msg T) error {
   114  	c.sendMu.Lock()
   115  	defer c.sendMu.Unlock()
   116  	off := msg.Pack(c.builder)
   117  	c.builder.FinishSizePrefixed(off)
   118  	data := c.builder.FinishedBytes()
   119  	_, err := c.conn.Write(data)
   120  	c.builder.Reset()
   121  	statSent.Add(len(data))
   122  	if err != nil {
   123  		return fmt.Errorf("failed to send %T: %w", msg, err)
   124  	}
   125  	return nil
   126  }
   127  
   128  // Recv received an RPC message.
   129  // The type T is supposed to be a normal flatbuffers type (not ending with T, e.g. ConnectRequest).
   130  // Receiving should be done from a single goroutine, the received message is valid
   131  // only until the next Recv call (messages share the same underlying receive buffer).
   132  func Recv[T any, PT interface {
   133  	*T
   134  	flatbuffers.FlatBuffer
   135  }](c *Conn) (*T, error) {
   136  	// First, discard the previous message.
   137  	// For simplicity we copy any data from the next message to the beginning of the buffer.
   138  	// Theoretically we could something more efficient, e.g. don't copy if we already
   139  	// have a full next message.
   140  	if c.hasData > c.lastMsg {
   141  		copy(c.data, c.data[c.lastMsg:c.hasData])
   142  	}
   143  	c.hasData -= c.lastMsg
   144  	c.lastMsg = 0
   145  	const (
   146  		sizePrefixSize = flatbuffers.SizeUint32
   147  		maxMessageSize = 64 << 20
   148  	)
   149  	msg := PT(new(T))
   150  	// Then, receive at least the size prefix (4 bytes).
   151  	// And then the full message, if we have not got it yet.
   152  	if err := c.recv(sizePrefixSize); err != nil {
   153  		return nil, fmt.Errorf("failed to recv %T: %w", msg, err)
   154  	}
   155  	size := int(flatbuffers.GetSizePrefix(c.data, 0))
   156  	if size > maxMessageSize {
   157  		return nil, fmt.Errorf("message %T has too large size %v", msg, size)
   158  	}
   159  	c.lastMsg = sizePrefixSize + size
   160  	if err := c.recv(c.lastMsg); err != nil {
   161  		return nil, fmt.Errorf("failed to recv %T: %w", msg, err)
   162  	}
   163  	statRecv.Add(c.lastMsg)
   164  	data := c.data[sizePrefixSize:c.lastMsg]
   165  	msg.Init(data, flatbuffers.GetUOffsetT(data))
   166  	return msg, nil
   167  }
   168  
   169  // recv ensures that we have at least 'size' bytes received in c.data.
   170  func (c *Conn) recv(size int) error {
   171  	need := size - c.hasData
   172  	if need <= 0 {
   173  		return nil
   174  	}
   175  	if grow := size - len(c.data) + c.hasData; grow > 0 {
   176  		c.data = slices.Grow(c.data, grow)[:len(c.data)+grow]
   177  	}
   178  	n, err := io.ReadAtLeast(c.conn, c.data[c.hasData:], need)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	c.hasData += n
   183  	return nil
   184  }