github.com/aitjcize/Overlord@v0.0.0-20240314041920-104a804cf5e8/overlord/rpc.go (about)

     1  // Copyright 2015 The Chromium OS Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  package overlord
     6  
     7  import (
     8  	"encoding/json"
     9  	"errors"
    10  	"log"
    11  	"net"
    12  	"strings"
    13  	"time"
    14  
    15  	uuid "github.com/satori/go.uuid"
    16  )
    17  
    18  const (
    19  	debugRPC              = false
    20  	messageSeparator      = "\r\n"
    21  	bufferSize            = 8192
    22  	requestTimeoutSeconds = 60              // Number of seconds before request timeouts
    23  	timeoutCheckInterval  = 3 * time.Second // The time between checking for timeout
    24  )
    25  
    26  // Message is the interface which defines a sendable message.
    27  type Message interface {
    28  	Marshal() ([]byte, error)
    29  }
    30  
    31  // Request Object.
    32  // Implements the Message interface.
    33  // If Timeout < 0, then the response can be omitted.
    34  type Request struct {
    35  	Rid     string          `json:"rid"`
    36  	Timeout int64           `json:"timeout"`
    37  	Name    string          `json:"name"`
    38  	Params  json.RawMessage `json:"params"`
    39  }
    40  
    41  // NewRequest creats a new Request object.
    42  // name is the name of the request.
    43  // params is map between string and any other JSON-serializable data structure.
    44  func NewRequest(name string, params map[string]interface{}) *Request {
    45  	req := &Request{
    46  		Rid:     uuid.NewV4().String(),
    47  		Timeout: requestTimeoutSeconds,
    48  		Name:    name,
    49  	}
    50  	if targs, err := json.Marshal(params); err != nil {
    51  		panic(err)
    52  	} else {
    53  		req.Params = json.RawMessage(targs)
    54  	}
    55  	return req
    56  }
    57  
    58  // SetTimeout sets the timeout of request.
    59  // The default timeout is is defined in requestTimeoutSeconds.
    60  func (r *Request) SetTimeout(timeout int64) {
    61  	r.Timeout = timeout
    62  }
    63  
    64  // Marshal mashels the Request.
    65  func (r *Request) Marshal() ([]byte, error) {
    66  	return json.Marshal(r)
    67  }
    68  
    69  // Response Object.
    70  // Implements the Message interface.
    71  type Response struct {
    72  	Rid      string          `json:"rid"`
    73  	Response string          `json:"response"`
    74  	Params   json.RawMessage `json:"params"`
    75  }
    76  
    77  // NewResponse creates a new Response object.
    78  // rid is the request ID of the request this response is intended for.
    79  // response is the response status text.
    80  // params is map between string and any other JSON-serializable data structure.
    81  func NewResponse(rid, response string, params map[string]interface{}) *Response {
    82  	res := &Response{
    83  		Rid:      rid,
    84  		Response: response,
    85  	}
    86  	if targs, err := json.Marshal(params); err != nil {
    87  		panic(err)
    88  	} else {
    89  		res.Params = json.RawMessage(targs)
    90  	}
    91  	return res
    92  }
    93  
    94  // Marshal marshals the Response.
    95  func (r *Response) Marshal() ([]byte, error) {
    96  	return json.Marshal(r)
    97  }
    98  
    99  // ResponseHandler is the function type of the response handler.
   100  // if res is nil, means that the response timeout.
   101  type ResponseHandler func(res *Response) error
   102  
   103  // Responder is The structure that stores the response handler information.
   104  type Responder struct {
   105  	RequestTime int64           // Time of request
   106  	Timeout     int64           // Timeout in seconds
   107  	Handler     ResponseHandler // The corresponding request handler
   108  }
   109  
   110  // RPCCore is the core implementation of the TCP-based 2-way RPC protocol.
   111  type RPCCore struct {
   112  	Conn       net.Conn             // handle to the TCP connection
   113  	ReadBuffer string               // internal read buffer
   114  	responders map[string]Responder // response handlers
   115  
   116  	readChan    chan []byte
   117  	readErrChan chan error
   118  }
   119  
   120  // NewRPCCore creates the RPCCore object.
   121  func NewRPCCore(conn net.Conn) *RPCCore {
   122  	return &RPCCore{
   123  		Conn:        conn,
   124  		responders:  make(map[string]Responder),
   125  		readChan:    make(chan []byte),
   126  		readErrChan: make(chan error),
   127  	}
   128  }
   129  
   130  // SendMessage sends a message.
   131  func (rpc *RPCCore) SendMessage(msg Message) error {
   132  	if rpc.Conn == nil {
   133  		return errors.New("SendMessage failed, connection not established")
   134  	}
   135  	var err error
   136  	var msgBytes []byte
   137  
   138  	if msgBytes, err = msg.Marshal(); err == nil {
   139  		if debugRPC {
   140  			log.Printf("-----> %s\n", string(msgBytes))
   141  		}
   142  		_, err = rpc.Conn.Write(append(msgBytes, []byte(messageSeparator)...))
   143  	}
   144  	return err
   145  }
   146  
   147  // SendRequest sends a Request.
   148  func (rpc *RPCCore) SendRequest(req *Request, handler ResponseHandler) error {
   149  	err := rpc.SendMessage(req)
   150  	if err == nil && req.Timeout >= 0 {
   151  		res := Responder{time.Now().Unix(), req.Timeout, handler}
   152  		rpc.responders[req.Rid] = res
   153  	}
   154  	return err
   155  }
   156  
   157  // SendResponse sends a Response.
   158  func (rpc *RPCCore) SendResponse(res *Response) error {
   159  	return rpc.SendMessage(res)
   160  }
   161  
   162  func (rpc *RPCCore) handleResponse(res *Response) error {
   163  	defer delete(rpc.responders, res.Rid)
   164  
   165  	if responder, ok := rpc.responders[res.Rid]; ok {
   166  		if responder.Handler != nil {
   167  			if err := responder.Handler(res); err != nil {
   168  				return err
   169  			}
   170  		}
   171  	} else {
   172  		return errors.New("Received unsolicited response, ignored")
   173  	}
   174  	return nil
   175  }
   176  
   177  // SpawnReaderRoutine spawnes a goroutine that actively read from the socket.
   178  // This function returns two channels. The first one is the channel that
   179  // send the content from the socket, and the second channel send an error
   180  // object if there is one.
   181  func (rpc *RPCCore) SpawnReaderRoutine() (chan []byte, chan error) {
   182  
   183  	go func() {
   184  		for {
   185  			buf := make([]byte, bufferSize)
   186  			n, err := rpc.Conn.Read(buf)
   187  			if err != nil {
   188  				rpc.readErrChan <- err
   189  				return
   190  			}
   191  			rpc.readChan <- buf[:n]
   192  		}
   193  	}()
   194  
   195  	return rpc.readChan, rpc.readErrChan
   196  }
   197  
   198  // StopConn stops the connection and terminates the reader goroutine.
   199  func (rpc *RPCCore) StopConn() {
   200  	rpc.Conn.Close()
   201  
   202  	time.Sleep(200 * time.Millisecond)
   203  
   204  	// Drain rpc.readChan and rpc.readErrChan so that the reader goroutine can
   205  	// exit.
   206  	for {
   207  		select {
   208  		case <-rpc.readChan:
   209  		case <-rpc.readErrChan:
   210  		default:
   211  			return
   212  		}
   213  	}
   214  }
   215  
   216  // ParseMessage parses a single JSON string into a Message object.
   217  func (rpc *RPCCore) ParseMessage(msgJSON string) (Message, error) {
   218  	var req Request
   219  	var res Response
   220  
   221  	err := json.Unmarshal([]byte(msgJSON), &req)
   222  	if err != nil || len(req.Name) == 0 {
   223  		err = json.Unmarshal([]byte(msgJSON), &res)
   224  		if err != nil {
   225  			err = errors.New("mal-formed JSON request, ignored")
   226  		} else {
   227  			return &res, nil
   228  		}
   229  	} else {
   230  		return &req, nil
   231  	}
   232  
   233  	return nil, err
   234  }
   235  
   236  // ParseRequests parses a buffer from SpawnReaderRoutine into Request objects.
   237  // The response message is automatically handled by the RPCCore itrpc by
   238  // invoking the corresponding response handler.
   239  func (rpc *RPCCore) ParseRequests(buffer string, single bool) []*Request {
   240  	var reqs []*Request
   241  	var msgsJSON []string
   242  
   243  	rpc.ReadBuffer += buffer
   244  	if single {
   245  		idx := strings.Index(rpc.ReadBuffer, messageSeparator)
   246  		if idx == -1 {
   247  			return nil
   248  		}
   249  		msgsJSON = []string{rpc.ReadBuffer[:idx]}
   250  		rpc.ReadBuffer = rpc.ReadBuffer[idx+2:]
   251  	} else {
   252  		msgs := strings.Split(rpc.ReadBuffer, messageSeparator)
   253  		if len(msgs) == 1 {
   254  			return nil
   255  		}
   256  		rpc.ReadBuffer = msgs[len(msgs)-1]
   257  		msgsJSON = msgs[:len(msgs)-1]
   258  	}
   259  
   260  	for _, msgJSON := range msgsJSON {
   261  		if debugRPC {
   262  			log.Printf("<----- " + msgJSON)
   263  		}
   264  		if msg, err := rpc.ParseMessage(msgJSON); err != nil {
   265  			log.Printf("Message parse failed: %s\n", err)
   266  			continue
   267  		} else {
   268  			switch m := msg.(type) {
   269  			case *Request:
   270  				reqs = append(reqs, m)
   271  			case *Response:
   272  				err := rpc.handleResponse(m)
   273  				if err != nil {
   274  					log.Printf("Response error: %s\n", err)
   275  				}
   276  			}
   277  		}
   278  	}
   279  	return reqs
   280  }
   281  
   282  // ScanForTimeoutRequests scans for timeout requests.
   283  func (rpc *RPCCore) ScanForTimeoutRequests() error {
   284  	for rid, res := range rpc.responders {
   285  		if time.Now().Unix()-res.RequestTime > res.Timeout {
   286  			if res.Handler != nil {
   287  				if err := res.Handler(nil); err != nil {
   288  					delete(rpc.responders, rid)
   289  					return err
   290  				}
   291  			} else {
   292  				log.Printf("Request %s timeout\n", rid)
   293  			}
   294  			delete(rpc.responders, rid)
   295  		}
   296  	}
   297  	return nil
   298  }
   299  
   300  // ClearRequests clear all the requests.
   301  func (rpc *RPCCore) ClearRequests() {
   302  	rpc.responders = make(map[string]Responder)
   303  }