github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/flatrpc/conn.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package flatrpc 5 6 import ( 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "os" 12 "slices" 13 "sync" 14 "time" 15 16 flatbuffers "github.com/google/flatbuffers/go" 17 "github.com/google/syzkaller/pkg/log" 18 "github.com/google/syzkaller/pkg/stats" 19 ) 20 21 var ( 22 statSent = stats.Create("rpc sent", "Outbound RPC traffic", 23 stats.Graph("traffic"), stats.Rate{}, stats.FormatMB) 24 statRecv = stats.Create("rpc recv", "Inbound RPC traffic", 25 stats.Graph("traffic"), stats.Rate{}, stats.FormatMB) 26 ) 27 28 type Serv struct { 29 Addr *net.TCPAddr 30 ln net.Listener 31 } 32 33 func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) { 34 ln, err := net.Listen("tcp", addr) 35 if err != nil { 36 return nil, err 37 } 38 go func() { 39 for { 40 conn, err := ln.Accept() 41 if err != nil { 42 if errors.Is(err, net.ErrClosed) { 43 break 44 } 45 var netErr *net.OpError 46 if errors.As(err, &netErr) && !netErr.Temporary() { 47 log.Fatalf("flatrpc: failed to accept: %v", err) 48 } 49 log.Logf(0, "flatrpc: failed to accept: %v", err) 50 continue 51 } 52 go func() { 53 c := newConn(conn) 54 defer c.Close() 55 handler(c) 56 }() 57 } 58 }() 59 return &Serv{ 60 Addr: ln.Addr().(*net.TCPAddr), 61 ln: ln, 62 }, nil 63 } 64 65 func (s *Serv) Close() error { 66 return s.ln.Close() 67 } 68 69 type Conn struct { 70 conn net.Conn 71 72 sendMu sync.Mutex 73 builder *flatbuffers.Builder 74 75 data []byte 76 hasData int 77 lastMsg int 78 } 79 80 func Dial(addr string, timeScale time.Duration) (*Conn, error) { 81 var conn net.Conn 82 var err error 83 if addr == "stdin" { 84 // This is used by vm/gvisor which passes us a unix socket connection in stdin. 85 conn, err = net.FileConn(os.Stdin) 86 } else { 87 conn, err = net.DialTimeout("tcp", addr, time.Minute*timeScale) 88 } 89 if err != nil { 90 return nil, err 91 } 92 return newConn(conn), nil 93 } 94 95 func newConn(conn net.Conn) *Conn { 96 return &Conn{ 97 conn: conn, 98 builder: flatbuffers.NewBuilder(0), 99 } 100 } 101 102 func (c *Conn) Close() error { 103 return c.conn.Close() 104 } 105 106 type sendMsg interface { 107 Pack(*flatbuffers.Builder) flatbuffers.UOffsetT 108 } 109 110 // Send sends an RPC message. 111 // The type T is supposed to be an "object API" type ending with T (e.g. ConnectRequestT). 112 // Sending can be done from multiple goroutines concurrently. 113 func Send[T sendMsg](c *Conn, msg T) error { 114 c.sendMu.Lock() 115 defer c.sendMu.Unlock() 116 off := msg.Pack(c.builder) 117 c.builder.FinishSizePrefixed(off) 118 data := c.builder.FinishedBytes() 119 _, err := c.conn.Write(data) 120 c.builder.Reset() 121 statSent.Add(len(data)) 122 if err != nil { 123 return fmt.Errorf("failed to send %T: %w", msg, err) 124 } 125 return nil 126 } 127 128 // Recv received an RPC message. 129 // The type T is supposed to be a normal flatbuffers type (not ending with T, e.g. ConnectRequest). 130 // Receiving should be done from a single goroutine, the received message is valid 131 // only until the next Recv call (messages share the same underlying receive buffer). 132 func Recv[T any, PT interface { 133 *T 134 flatbuffers.FlatBuffer 135 }](c *Conn) (*T, error) { 136 // First, discard the previous message. 137 // For simplicity we copy any data from the next message to the beginning of the buffer. 138 // Theoretically we could something more efficient, e.g. don't copy if we already 139 // have a full next message. 140 if c.hasData > c.lastMsg { 141 copy(c.data, c.data[c.lastMsg:c.hasData]) 142 } 143 c.hasData -= c.lastMsg 144 c.lastMsg = 0 145 const ( 146 sizePrefixSize = flatbuffers.SizeUint32 147 maxMessageSize = 64 << 20 148 ) 149 msg := PT(new(T)) 150 // Then, receive at least the size prefix (4 bytes). 151 // And then the full message, if we have not got it yet. 152 if err := c.recv(sizePrefixSize); err != nil { 153 return nil, fmt.Errorf("failed to recv %T: %w", msg, err) 154 } 155 size := int(flatbuffers.GetSizePrefix(c.data, 0)) 156 if size > maxMessageSize { 157 return nil, fmt.Errorf("message %T has too large size %v", msg, size) 158 } 159 c.lastMsg = sizePrefixSize + size 160 if err := c.recv(c.lastMsg); err != nil { 161 return nil, fmt.Errorf("failed to recv %T: %w", msg, err) 162 } 163 statRecv.Add(c.lastMsg) 164 data := c.data[sizePrefixSize:c.lastMsg] 165 msg.Init(data, flatbuffers.GetUOffsetT(data)) 166 return msg, nil 167 } 168 169 // recv ensures that we have at least 'size' bytes received in c.data. 170 func (c *Conn) recv(size int) error { 171 need := size - c.hasData 172 if need <= 0 { 173 return nil 174 } 175 if grow := size - len(c.data) + c.hasData; grow > 0 { 176 c.data = slices.Grow(c.data, grow)[:len(c.data)+grow] 177 } 178 n, err := io.ReadAtLeast(c.conn, c.data[c.hasData:], need) 179 if err != nil { 180 return err 181 } 182 c.hasData += n 183 return nil 184 }