github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/rpctype/rpc.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package rpctype
     5  
     6  import (
     7  	"compress/flate"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/rpc"
    12  	"os"
    13  	"time"
    14  
    15  	"github.com/google/syzkaller/pkg/log"
    16  )
    17  
    18  type RPCServer struct {
    19  	ln net.Listener
    20  	s  *rpc.Server
    21  }
    22  
    23  func NewRPCServer(addr, name string, receiver interface{}) (*RPCServer, error) {
    24  	ln, err := net.Listen("tcp", addr)
    25  	if err != nil {
    26  		return nil, fmt.Errorf("failed to listen on %v: %w", addr, err)
    27  	}
    28  	s := rpc.NewServer()
    29  	if err := s.RegisterName(name, receiver); err != nil {
    30  		return nil, err
    31  	}
    32  	serv := &RPCServer{
    33  		ln: ln,
    34  		s:  s,
    35  	}
    36  	return serv, nil
    37  }
    38  
    39  func (serv *RPCServer) Serve() {
    40  	for {
    41  		conn, err := serv.ln.Accept()
    42  		if err != nil {
    43  			log.Logf(0, "failed to accept an rpc connection: %v", err)
    44  			continue
    45  		}
    46  		setupKeepAlive(conn, time.Minute)
    47  		go serv.s.ServeConn(newFlateConn(conn))
    48  	}
    49  }
    50  
    51  func (serv *RPCServer) Addr() net.Addr {
    52  	return serv.ln.Addr()
    53  }
    54  
    55  type RPCClient struct {
    56  	conn net.Conn
    57  	c    *rpc.Client
    58  }
    59  
    60  func NewRPCClient(addr string) (*RPCClient, error) {
    61  	var conn net.Conn
    62  	var err error
    63  	if addr == "stdin" {
    64  		// This is used by vm/gvisor which passes us a unix socket connection in stdin.
    65  		// TODO: remove this once we switch to flatrpc for target communication.
    66  		conn, err = net.FileConn(os.Stdin)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  	} else {
    71  		conn, err = net.DialTimeout("tcp", addr, 3*time.Minute)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  		setupKeepAlive(conn, time.Minute)
    76  	}
    77  	cli := &RPCClient{
    78  		conn: conn,
    79  		c:    rpc.NewClient(newFlateConn(conn)),
    80  	}
    81  	return cli, nil
    82  }
    83  
    84  func (cli *RPCClient) Call(method string, args, reply interface{}) error {
    85  	// Note: SetDeadline is not implemented on fuchsia, so don't fail on error.
    86  	cli.conn.SetDeadline(time.Now().Add(10 * time.Minute))
    87  	defer cli.conn.SetDeadline(time.Time{})
    88  	return cli.c.Call(method, args, reply)
    89  }
    90  
    91  func (cli *RPCClient) AsyncCall(method string, args interface{}) {
    92  	cli.c.Go(method, args, nil, nil)
    93  }
    94  
    95  func (cli *RPCClient) Close() {
    96  	cli.c.Close()
    97  }
    98  
    99  func setupKeepAlive(conn net.Conn, keepAlive time.Duration) {
   100  	conn.(*net.TCPConn).SetKeepAlive(true)
   101  	conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive)
   102  }
   103  
   104  // flateConn wraps net.Conn in flate.Reader/Writer for compressed traffic.
   105  type flateConn struct {
   106  	r io.ReadCloser
   107  	w *flate.Writer
   108  	c io.Closer
   109  }
   110  
   111  func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser {
   112  	w, err := flate.NewWriter(conn, 9)
   113  	if err != nil {
   114  		panic(err)
   115  	}
   116  	return &flateConn{
   117  		r: flate.NewReader(conn),
   118  		w: w,
   119  		c: conn,
   120  	}
   121  }
   122  
   123  func (fc *flateConn) Read(data []byte) (int, error) {
   124  	return fc.r.Read(data)
   125  }
   126  
   127  func (fc *flateConn) Write(data []byte) (int, error) {
   128  	n, err := fc.w.Write(data)
   129  	if err != nil {
   130  		return n, err
   131  	}
   132  	if err := fc.w.Flush(); err != nil {
   133  		return n, err
   134  	}
   135  	return n, nil
   136  }
   137  
   138  func (fc *flateConn) Close() error {
   139  	var err0 error
   140  	if err := fc.r.Close(); err != nil {
   141  		err0 = err
   142  	}
   143  	if err := fc.w.Close(); err != nil {
   144  		err0 = err
   145  	}
   146  	if err := fc.c.Close(); err != nil {
   147  		err0 = err
   148  	}
   149  	return err0
   150  }