github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/flatrpc/conn_test.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package flatrpc 5 6 import ( 7 "context" 8 "fmt" 9 "net" 10 "os" 11 "reflect" 12 "runtime/debug" 13 "sync" 14 "syscall" 15 "testing" 16 "time" 17 18 flatbuffers "github.com/google/flatbuffers/go" 19 "github.com/stretchr/testify/assert" 20 ) 21 22 func TestConn(t *testing.T) { 23 connectHello := &ConnectHello{ 24 Cookie: 1, 25 } 26 connectReq := &ConnectRequest{ 27 Cookie: 73856093, 28 Id: 1, 29 Arch: "arch", 30 GitRevision: "rev1", 31 SyzRevision: "rev2", 32 } 33 connectReply := &ConnectReply{ 34 LeakFrames: []string{"foo", "bar"}, 35 RaceFrames: []string{"bar", "baz"}, 36 Features: FeatureCoverage | FeatureLeak, 37 Files: []string{"file1"}, 38 } 39 executorMsg := &ExecutorMessage{ 40 Msg: &ExecutorMessages{ 41 Type: ExecutorMessagesRawExecuting, 42 Value: &ExecutingMessage{ 43 Id: 1, 44 ProcId: 2, 45 Try: 3, 46 }, 47 }, 48 } 49 50 serv, err := Listen(":0") 51 if err != nil { 52 t.Fatal(err) 53 } 54 55 done := make(chan error) 56 go func() { 57 done <- serv.Serve(context.Background(), 58 func(_ context.Context, c *Conn) error { 59 if err := Send(c, connectHello); err != nil { 60 return err 61 } 62 connectReqGot, err := Recv[*ConnectRequestRaw](c) 63 if err != nil { 64 return err 65 } 66 if !reflect.DeepEqual(connectReq, connectReqGot) { 67 return fmt.Errorf("connectReq != connectReqGot") 68 } 69 70 if err := Send(c, connectReply); err != nil { 71 return err 72 } 73 74 for i := 0; i < 10; i++ { 75 got, err := Recv[*ExecutorMessageRaw](c) 76 if err != nil { 77 return nil 78 } 79 if !reflect.DeepEqual(executorMsg, got) { 80 return fmt.Errorf("executorMsg !=got") 81 } 82 } 83 return nil 84 }) 85 }() 86 c := dial(t, serv.Addr.String()) 87 defer c.Close() 88 89 connectHelloGot, err := Recv[*ConnectHelloRaw](c) 90 if err != nil { 91 t.Fatal(err) 92 } 93 assert.Equal(t, connectHello, connectHelloGot) 94 95 if err := Send(c, connectReq); err != nil { 96 t.Fatal(err) 97 } 98 99 connectReplyGot, err := Recv[*ConnectReplyRaw](c) 100 if err != nil { 101 t.Fatal(err) 102 } 103 assert.Equal(t, connectReply, connectReplyGot) 104 105 for i := 0; i < 10; i++ { 106 if err := Send(c, executorMsg); err != nil { 107 t.Fatal(err) 108 } 109 } 110 111 serv.Close() 112 if err := <-done; err != nil { 113 t.Fatal(err) 114 } 115 } 116 117 func BenchmarkConn(b *testing.B) { 118 connectHello := &ConnectHello{ 119 Cookie: 1, 120 } 121 connectReq := &ConnectRequest{ 122 Cookie: 73856093, 123 Id: 1, 124 Arch: "arch", 125 GitRevision: "rev1", 126 SyzRevision: "rev2", 127 } 128 connectReply := &ConnectReply{ 129 LeakFrames: []string{"foo", "bar"}, 130 RaceFrames: []string{"bar", "baz"}, 131 Features: FeatureCoverage | FeatureLeak, 132 Files: []string{"file1"}, 133 } 134 135 serv, err := Listen(":0") 136 if err != nil { 137 b.Fatal(err) 138 } 139 done := make(chan error) 140 141 go func() { 142 done <- serv.Serve(context.Background(), 143 func(_ context.Context, c *Conn) error { 144 for i := 0; i < b.N; i++ { 145 if err := Send(c, connectHello); err != nil { 146 return err 147 } 148 149 _, err = Recv[*ConnectRequestRaw](c) 150 if err != nil { 151 return err 152 } 153 if err := Send(c, connectReply); err != nil { 154 return err 155 } 156 } 157 return nil 158 }) 159 }() 160 161 c := dial(b, serv.Addr.String()) 162 defer c.Close() 163 164 b.ReportAllocs() 165 b.ResetTimer() 166 for i := 0; i < b.N; i++ { 167 _, err := Recv[*ConnectHelloRaw](c) 168 if err != nil { 169 b.Fatal(err) 170 } 171 if err := Send(c, connectReq); err != nil { 172 b.Fatal(err) 173 } 174 _, err = Recv[*ConnectReplyRaw](c) 175 if err != nil { 176 b.Fatal(err) 177 } 178 } 179 180 serv.Close() 181 if err := <-done; err != nil { 182 b.Fatal(err) 183 } 184 } 185 186 func dial(t testing.TB, addr string) *Conn { 187 conn, err := net.DialTimeout("tcp", addr, time.Minute) 188 if err != nil { 189 t.Fatal(err) 190 } 191 return NewConn(conn) 192 } 193 194 var memoryLimitOnce sync.Once 195 196 func FuzzRecv(f *testing.F) { 197 msg := &ExecutorMessage{ 198 Msg: &ExecutorMessages{ 199 Type: ExecutorMessagesRawExecResult, 200 Value: &ExecResult{ 201 Id: 1, 202 Output: []byte("aaa"), 203 Error: "bbb", 204 Info: &ProgInfo{ 205 ExtraRaw: []*CallInfo{ 206 { 207 Signal: []uint64{1, 2}, 208 }, 209 }, 210 }, 211 }, 212 }, 213 } 214 builder := flatbuffers.NewBuilder(0) 215 builder.FinishSizePrefixed(msg.Pack(builder)) 216 f.Add(builder.FinishedBytes()) 217 f.Fuzz(func(t *testing.T, data []byte) { 218 memoryLimitOnce.Do(func() { 219 debug.SetMemoryLimit(64 << 20) 220 }) 221 if len(data) > 1<<10 { 222 t.Skip() 223 } 224 fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) 225 if err != nil { 226 t.Fatal(err) 227 } 228 w := os.NewFile(uintptr(fds[0]), "") 229 r := os.NewFile(uintptr(fds[1]), "") 230 defer w.Close() 231 defer r.Close() 232 if _, err := w.Write(data); err != nil { 233 t.Fatal(err) 234 } 235 w.Close() 236 n, err := net.FileConn(r) 237 if err != nil { 238 t.Fatal(err) 239 } 240 c := NewConn(n) 241 for { 242 _, err := Recv[*ExecutorMessageRaw](c) 243 if err != nil { 244 break 245 } 246 } 247 }) 248 }