github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/rpctype/rpc.go (about) 1 // Copyright 2017 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package rpctype 5 6 import ( 7 "compress/flate" 8 "fmt" 9 "io" 10 "net" 11 "net/rpc" 12 "os" 13 "time" 14 15 "github.com/google/syzkaller/pkg/log" 16 ) 17 18 type RPCServer struct { 19 ln net.Listener 20 s *rpc.Server 21 } 22 23 func NewRPCServer(addr, name string, receiver interface{}) (*RPCServer, error) { 24 ln, err := net.Listen("tcp", addr) 25 if err != nil { 26 return nil, fmt.Errorf("failed to listen on %v: %w", addr, err) 27 } 28 s := rpc.NewServer() 29 if err := s.RegisterName(name, receiver); err != nil { 30 return nil, err 31 } 32 serv := &RPCServer{ 33 ln: ln, 34 s: s, 35 } 36 return serv, nil 37 } 38 39 func (serv *RPCServer) Serve() { 40 for { 41 conn, err := serv.ln.Accept() 42 if err != nil { 43 log.Logf(0, "failed to accept an rpc connection: %v", err) 44 continue 45 } 46 setupKeepAlive(conn, time.Minute) 47 go serv.s.ServeConn(newFlateConn(conn)) 48 } 49 } 50 51 func (serv *RPCServer) Addr() net.Addr { 52 return serv.ln.Addr() 53 } 54 55 type RPCClient struct { 56 conn net.Conn 57 c *rpc.Client 58 } 59 60 func NewRPCClient(addr string) (*RPCClient, error) { 61 var conn net.Conn 62 var err error 63 if addr == "stdin" { 64 // This is used by vm/gvisor which passes us a unix socket connection in stdin. 65 // TODO: remove this once we switch to flatrpc for target communication. 
66 conn, err = net.FileConn(os.Stdin) 67 if err != nil { 68 return nil, err 69 } 70 } else { 71 conn, err = net.DialTimeout("tcp", addr, 3*time.Minute) 72 if err != nil { 73 return nil, err 74 } 75 setupKeepAlive(conn, time.Minute) 76 } 77 cli := &RPCClient{ 78 conn: conn, 79 c: rpc.NewClient(newFlateConn(conn)), 80 } 81 return cli, nil 82 } 83 84 func (cli *RPCClient) Call(method string, args, reply interface{}) error { 85 // Note: SetDeadline is not implemented on fuchsia, so don't fail on error. 86 cli.conn.SetDeadline(time.Now().Add(10 * time.Minute)) 87 defer cli.conn.SetDeadline(time.Time{}) 88 return cli.c.Call(method, args, reply) 89 } 90 91 func (cli *RPCClient) AsyncCall(method string, args interface{}) { 92 cli.c.Go(method, args, nil, nil) 93 } 94 95 func (cli *RPCClient) Close() { 96 cli.c.Close() 97 } 98 99 func setupKeepAlive(conn net.Conn, keepAlive time.Duration) { 100 conn.(*net.TCPConn).SetKeepAlive(true) 101 conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) 102 } 103 104 // flateConn wraps net.Conn in flate.Reader/Writer for compressed traffic. 
type flateConn struct {
	r io.ReadCloser // decompresses inbound bytes
	w *flate.Writer // compresses outbound bytes
	c io.Closer     // the wrapped connection, closed last
}

// newFlateConn layers DEFLATE compression at flate.BestCompression over conn
// in both directions. flate.NewWriter only fails for an out-of-range level,
// so the panic marks a programming error, not a runtime condition.
func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser {
	writer, err := flate.NewWriter(conn, flate.BestCompression)
	if err != nil {
		panic(err)
	}
	fc := new(flateConn)
	fc.r = flate.NewReader(conn)
	fc.w = writer
	fc.c = conn
	return fc
}

// Read returns decompressed bytes read from the wrapped connection.
func (fc *flateConn) Read(data []byte) (int, error) {
	return fc.r.Read(data)
}

// Write compresses data and flushes it right away, so the peer can decode
// the message without waiting for more output. The count is the number of
// bytes the compressor accepted. If that step fails, no flush is attempted.
func (fc *flateConn) Write(data []byte) (n int, err error) {
	if n, err = fc.w.Write(data); err == nil {
		err = fc.w.Flush()
	}
	return n, err
}

// Close closes the decompressor, then the compressor (which flushes its
// final block), then the wrapped connection. Every step runs even if an
// earlier one fails, and the error from the last failing step is returned.
func (fc *flateConn) Close() error {
	var lastErr error
	for _, closer := range []io.Closer{fc.r, fc.w, fc.c} {
		if err := closer.Close(); err != nil {
			lastErr = err
		}
	}
	return lastErr
}