github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/flatrpc/conn_test.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package flatrpc
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"reflect"
    12  	"runtime/debug"
    13  	"sync"
    14  	"syscall"
    15  	"testing"
    16  	"time"
    17  
    18  	flatbuffers "github.com/google/flatbuffers/go"
    19  	"github.com/stretchr/testify/assert"
    20  )
    21  
    22  func TestConn(t *testing.T) {
    23  	connectHello := &ConnectHello{
    24  		Cookie: 1,
    25  	}
    26  	connectReq := &ConnectRequest{
    27  		Cookie:      73856093,
    28  		Id:          1,
    29  		Arch:        "arch",
    30  		GitRevision: "rev1",
    31  		SyzRevision: "rev2",
    32  	}
    33  	connectReply := &ConnectReply{
    34  		LeakFrames: []string{"foo", "bar"},
    35  		RaceFrames: []string{"bar", "baz"},
    36  		Features:   FeatureCoverage | FeatureLeak,
    37  		Files:      []string{"file1"},
    38  	}
    39  	executorMsg := &ExecutorMessage{
    40  		Msg: &ExecutorMessages{
    41  			Type: ExecutorMessagesRawExecuting,
    42  			Value: &ExecutingMessage{
    43  				Id:     1,
    44  				ProcId: 2,
    45  				Try:    3,
    46  			},
    47  		},
    48  	}
    49  
    50  	serv, err := Listen(":0")
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  
    55  	done := make(chan error)
    56  	go func() {
    57  		done <- serv.Serve(context.Background(),
    58  			func(_ context.Context, c *Conn) error {
    59  				if err := Send(c, connectHello); err != nil {
    60  					return err
    61  				}
    62  				connectReqGot, err := Recv[*ConnectRequestRaw](c)
    63  				if err != nil {
    64  					return err
    65  				}
    66  				if !reflect.DeepEqual(connectReq, connectReqGot) {
    67  					return fmt.Errorf("connectReq != connectReqGot")
    68  				}
    69  
    70  				if err := Send(c, connectReply); err != nil {
    71  					return err
    72  				}
    73  
    74  				for i := 0; i < 10; i++ {
    75  					got, err := Recv[*ExecutorMessageRaw](c)
    76  					if err != nil {
    77  						return nil
    78  					}
    79  					if !reflect.DeepEqual(executorMsg, got) {
    80  						return fmt.Errorf("executorMsg !=got")
    81  					}
    82  				}
    83  				return nil
    84  			})
    85  	}()
    86  	c := dial(t, serv.Addr.String())
    87  	defer c.Close()
    88  
    89  	connectHelloGot, err := Recv[*ConnectHelloRaw](c)
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	assert.Equal(t, connectHello, connectHelloGot)
    94  
    95  	if err := Send(c, connectReq); err != nil {
    96  		t.Fatal(err)
    97  	}
    98  
    99  	connectReplyGot, err := Recv[*ConnectReplyRaw](c)
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  	assert.Equal(t, connectReply, connectReplyGot)
   104  
   105  	for i := 0; i < 10; i++ {
   106  		if err := Send(c, executorMsg); err != nil {
   107  			t.Fatal(err)
   108  		}
   109  	}
   110  
   111  	serv.Close()
   112  	if err := <-done; err != nil {
   113  		t.Fatal(err)
   114  	}
   115  }
   116  
   117  func BenchmarkConn(b *testing.B) {
   118  	connectHello := &ConnectHello{
   119  		Cookie: 1,
   120  	}
   121  	connectReq := &ConnectRequest{
   122  		Cookie:      73856093,
   123  		Id:          1,
   124  		Arch:        "arch",
   125  		GitRevision: "rev1",
   126  		SyzRevision: "rev2",
   127  	}
   128  	connectReply := &ConnectReply{
   129  		LeakFrames: []string{"foo", "bar"},
   130  		RaceFrames: []string{"bar", "baz"},
   131  		Features:   FeatureCoverage | FeatureLeak,
   132  		Files:      []string{"file1"},
   133  	}
   134  
   135  	serv, err := Listen(":0")
   136  	if err != nil {
   137  		b.Fatal(err)
   138  	}
   139  	done := make(chan error)
   140  
   141  	go func() {
   142  		done <- serv.Serve(context.Background(),
   143  			func(_ context.Context, c *Conn) error {
   144  				for i := 0; i < b.N; i++ {
   145  					if err := Send(c, connectHello); err != nil {
   146  						return err
   147  					}
   148  
   149  					_, err = Recv[*ConnectRequestRaw](c)
   150  					if err != nil {
   151  						return err
   152  					}
   153  					if err := Send(c, connectReply); err != nil {
   154  						return err
   155  					}
   156  				}
   157  				return nil
   158  			})
   159  	}()
   160  
   161  	c := dial(b, serv.Addr.String())
   162  	defer c.Close()
   163  
   164  	b.ReportAllocs()
   165  	b.ResetTimer()
   166  	for i := 0; i < b.N; i++ {
   167  		_, err := Recv[*ConnectHelloRaw](c)
   168  		if err != nil {
   169  			b.Fatal(err)
   170  		}
   171  		if err := Send(c, connectReq); err != nil {
   172  			b.Fatal(err)
   173  		}
   174  		_, err = Recv[*ConnectReplyRaw](c)
   175  		if err != nil {
   176  			b.Fatal(err)
   177  		}
   178  	}
   179  
   180  	serv.Close()
   181  	if err := <-done; err != nil {
   182  		b.Fatal(err)
   183  	}
   184  }
   185  
   186  func dial(t testing.TB, addr string) *Conn {
   187  	conn, err := net.DialTimeout("tcp", addr, time.Minute)
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  	return NewConn(conn)
   192  }
   193  
   194  var memoryLimitOnce sync.Once
   195  
   196  func FuzzRecv(f *testing.F) {
   197  	msg := &ExecutorMessage{
   198  		Msg: &ExecutorMessages{
   199  			Type: ExecutorMessagesRawExecResult,
   200  			Value: &ExecResult{
   201  				Id:     1,
   202  				Output: []byte("aaa"),
   203  				Error:  "bbb",
   204  				Info: &ProgInfo{
   205  					ExtraRaw: []*CallInfo{
   206  						{
   207  							Signal: []uint64{1, 2},
   208  						},
   209  					},
   210  				},
   211  			},
   212  		},
   213  	}
   214  	builder := flatbuffers.NewBuilder(0)
   215  	builder.FinishSizePrefixed(msg.Pack(builder))
   216  	f.Add(builder.FinishedBytes())
   217  	f.Fuzz(func(t *testing.T, data []byte) {
   218  		memoryLimitOnce.Do(func() {
   219  			debug.SetMemoryLimit(64 << 20)
   220  		})
   221  		if len(data) > 1<<10 {
   222  			t.Skip()
   223  		}
   224  		fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
   225  		if err != nil {
   226  			t.Fatal(err)
   227  		}
   228  		w := os.NewFile(uintptr(fds[0]), "")
   229  		r := os.NewFile(uintptr(fds[1]), "")
   230  		defer w.Close()
   231  		defer r.Close()
   232  		if _, err := w.Write(data); err != nil {
   233  			t.Fatal(err)
   234  		}
   235  		w.Close()
   236  		n, err := net.FileConn(r)
   237  		if err != nil {
   238  			t.Fatal(err)
   239  		}
   240  		c := NewConn(n)
   241  		for {
   242  			_, err := Recv[*ExecutorMessageRaw](c)
   243  			if err != nil {
   244  				break
   245  			}
   246  		}
   247  	})
   248  }