github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/flatrpc/conn.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package flatrpc 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "reflect" 13 "slices" 14 "sync" 15 "unsafe" 16 17 flatbuffers "github.com/google/flatbuffers/go" 18 "github.com/google/syzkaller/pkg/log" 19 "github.com/google/syzkaller/pkg/stat" 20 "golang.org/x/sync/errgroup" 21 ) 22 23 var ( 24 statSent = stat.New("rpc sent", "Outbound RPC traffic", 25 stat.Graph("traffic"), stat.Rate{}, stat.FormatMB) 26 statRecv = stat.New("rpc recv", "Inbound RPC traffic", 27 stat.Graph("traffic"), stat.Rate{}, stat.FormatMB) 28 ) 29 30 type Serv struct { 31 Addr *net.TCPAddr 32 ln net.Listener 33 } 34 35 func Listen(addr string) (*Serv, error) { 36 ln, err := net.Listen("tcp", addr) 37 if err != nil { 38 return nil, err 39 } 40 return &Serv{ 41 Addr: ln.Addr().(*net.TCPAddr), 42 ln: ln, 43 }, nil 44 } 45 46 // Serve accepts incoming connections and calls handler for each of them. 47 // An error returned from the handler stops the server and aborts the whole processing. 48 func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error { 49 eg, ctx := errgroup.WithContext(baseCtx) 50 go func() { 51 // If the context is cancelled, stop the server. 52 <-ctx.Done() 53 s.Close() 54 }() 55 for { 56 conn, err := s.ln.Accept() 57 if err != nil && errors.Is(err, net.ErrClosed) { 58 break 59 } 60 if err != nil { 61 var netErr *net.OpError 62 if errors.As(err, &netErr) && !netErr.Temporary() { 63 return fmt.Errorf("flatrpc: failed to accept: %w", err) 64 } 65 log.Logf(0, "flatrpc: failed to accept: %v", err) 66 continue 67 } 68 eg.Go(func() error { 69 connCtx, cancel := context.WithCancel(ctx) 70 defer cancel() 71 72 c := NewConn(conn) 73 // Closing the server does not automatically close all the connections. 74 go func() { 75 <-connCtx.Done() 76 c.Close() 77 }() 78 return handler(connCtx, c) 79 }) 80 } 81 return eg.Wait() 82 } 83 84 func (s *Serv) Close() error { 85 return s.ln.Close() 86 } 87 88 type Conn struct { 89 conn net.Conn 90 91 sendMu sync.Mutex 92 builder *flatbuffers.Builder 93 94 data []byte 95 hasData int 96 lastMsg int 97 } 98 99 func NewConn(conn net.Conn) *Conn { 100 return &Conn{ 101 conn: conn, 102 builder: flatbuffers.NewBuilder(0), 103 } 104 } 105 106 func (c *Conn) Close() error { 107 return c.conn.Close() 108 } 109 110 type sendMsg interface { 111 Pack(*flatbuffers.Builder) flatbuffers.UOffsetT 112 } 113 114 // Send sends an RPC message. 115 // The type T is supposed to be an "object API" type ending with T (e.g. ConnectRequestT). 116 // Sending can be done from multiple goroutines concurrently. 117 func Send[T sendMsg](c *Conn, msg T) error { 118 c.sendMu.Lock() 119 defer c.sendMu.Unlock() 120 off := msg.Pack(c.builder) 121 c.builder.FinishSizePrefixed(off) 122 data := c.builder.FinishedBytes() 123 _, err := c.conn.Write(data) 124 c.builder.Reset() 125 statSent.Add(len(data)) 126 if err != nil { 127 return fmt.Errorf("failed to send %T: %w", msg, err) 128 } 129 return nil 130 } 131 132 type RecvType[T any] interface { 133 UnPack() *T 134 flatbuffers.FlatBuffer 135 } 136 137 // Recv receives an RPC message. 138 // The type T is supposed to be a pointer to a normal flatbuffers type (not ending with T, e.g. *ConnectRequestRaw). 139 // Receiving should be done from a single goroutine, the received message is valid 140 // only until the next Recv call (messages share the same underlying receive buffer). 141 func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) { 142 // First, discard the previous message. 143 // For simplicity we copy any data from the next message to the beginning of the buffer. 144 // Theoretically we could something more efficient, e.g. don't copy if we already 145 // have a full next message. 146 if c.hasData > c.lastMsg { 147 copy(c.data, c.data[c.lastMsg:c.hasData]) 148 } 149 c.hasData -= c.lastMsg 150 c.lastMsg = 0 151 const ( 152 sizePrefixSize = flatbuffers.SizeUint32 153 maxMessageSize = 64 << 20 154 ) 155 // Then, receive at least the size prefix (4 bytes). 156 // And then the full message, if we have not got it yet. 157 if err := c.recv(sizePrefixSize); err != nil { 158 return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err) 159 } 160 size := int(flatbuffers.GetSizePrefix(c.data, 0)) 161 if size > maxMessageSize { 162 return nil, fmt.Errorf("message %T has too large size %v", (*T)(nil), size) 163 } 164 c.lastMsg = sizePrefixSize + size 165 if err := c.recv(c.lastMsg); err != nil { 166 return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err) 167 } 168 return Parse[Raw](c.data[sizePrefixSize:c.lastMsg]) 169 } 170 171 // recv ensures that we have at least 'size' bytes received in c.data. 172 func (c *Conn) recv(size int) error { 173 need := size - c.hasData 174 if need <= 0 { 175 return nil 176 } 177 if grow := size - len(c.data) + c.hasData; grow > 0 { 178 c.data = slices.Grow(c.data, grow)[:len(c.data)+grow] 179 } 180 n, err := io.ReadAtLeast(c.conn, c.data[c.hasData:], need) 181 if err != nil { 182 return err 183 } 184 c.hasData += n 185 return nil 186 } 187 188 func Parse[Raw RecvType[T], T any](data []byte) (res *T, err0 error) { 189 defer func() { 190 if err := recover(); err != nil { 191 err0 = fmt.Errorf("%v", err) 192 } 193 }() 194 statRecv.Add(len(data)) 195 // This probably can be expressed w/o reflect as "new U" where U is *T, 196 // but I failed to express that as generic constraints. 197 var msg Raw 198 msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(Raw) 199 msg.Init(data, flatbuffers.GetUOffsetT(data)) 200 if err := verify(msg, len(data)); err != nil { 201 return nil, err 202 } 203 return msg.UnPack(), nil 204 } 205 206 func verify(raw any, rawSize int) error { 207 switch msg := raw.(type) { 208 case *ExecutorMessageRaw: 209 return verifyExecutorMessage(msg, rawSize) 210 } 211 return nil 212 } 213 214 func verifyExecutorMessage(raw *ExecutorMessageRaw, rawSize int) error { 215 // We receive the message into raw (non object API) type and carefully verify 216 // because the message from the test machine can be corrupted in all possible ways. 217 // Recovering from panics handles most corruptions (since flatbuffers does not use unsafe 218 // and panics on any OOB references). But it's still possible that UnPack may try to allocate 219 // unbounded amount of memory and crash with OOM. To prevent that we check that arrays have 220 // reasonable size. We don't need to check []byte/string b/c for them flatbuffers use 221 // Table.ByteVector which directly references the underlying byte slice and also panics 222 // if size is OOB. But we need to check all other arrays b/c for them flatbuffers will 223 // first do make([]T, size), filling that array later will panic, but it's already too late 224 // since the make will kill the process with OOM. 225 switch typ := raw.MsgType(); typ { 226 case ExecutorMessagesRawExecResult, 227 ExecutorMessagesRawExecuting, 228 ExecutorMessagesRawState: 229 default: 230 return fmt.Errorf("bad executor message type %v", typ) 231 } 232 var tab flatbuffers.Table 233 if !raw.Msg(&tab) { 234 return errors.New("received no message") 235 } 236 // Only ExecResult has arrays. 237 if raw.MsgType() == ExecutorMessagesRawExecResult { 238 var res ExecResultRaw 239 res.Init(tab.Bytes, tab.Pos) 240 return verifyExecResult(&res, rawSize) 241 } 242 return nil 243 } 244 245 func verifyExecResult(res *ExecResultRaw, rawSize int) error { 246 info := res.Info(nil) 247 if info == nil { 248 return nil 249 } 250 var tmp ComparisonRaw 251 // It's hard to impose good limit on each individual signal/cover/comps array, 252 // so instead we count total memory size for all calls and check that it's not 253 // larger than the total message size. 254 callSize := func(call *CallInfoRaw) int { 255 // Cap array size at 1G to prevent overflows during multiplication by size and addition. 256 const maxSize = 1 << 30 257 size := 0 258 if call.SignalLength() != 0 { 259 size += min(maxSize, call.SignalLength()) * int(unsafe.Sizeof(call.Signal(0))) 260 } 261 if call.CoverLength() != 0 { 262 size += min(maxSize, call.CoverLength()) * int(unsafe.Sizeof(call.Cover(0))) 263 } 264 if call.CompsLength() != 0 { 265 size += min(maxSize, call.CompsLength()) * int(unsafe.Sizeof(call.Comps(&tmp, 0))) 266 } 267 return size 268 } 269 size := 0 270 var call CallInfoRaw 271 for i := 0; i < info.CallsLength(); i++ { 272 if info.Calls(&call, i) { 273 size += callSize(&call) 274 } 275 } 276 for i := 0; i < info.ExtraRawLength(); i++ { 277 if info.ExtraRaw(&call, i) { 278 size += callSize(&call) 279 } 280 } 281 if info.Extra(&call) != nil { 282 size += callSize(&call) 283 } 284 if size > rawSize { 285 return fmt.Errorf("corrupted message: total size %v, size of elements %v", 286 rawSize, size) 287 } 288 return nil 289 }