github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/jsonrpc2/conn.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"sync"
    12  	"sync/atomic"
    13  
    14  	"github.com/powerman/golang-tools/internal/event"
    15  	"github.com/powerman/golang-tools/internal/event/label"
    16  	"github.com/powerman/golang-tools/internal/lsp/debug/tag"
    17  )
    18  
    19  // Conn is the common interface to jsonrpc clients and servers.
    20  // Conn is bidirectional; it does not have a designated server or client end.
    21  // It manages the jsonrpc2 protocol, connecting responses back to their calls.
    22  type Conn interface {
    23  	// Call invokes the target method and waits for a response.
    24  	// The params will be marshaled to JSON before sending over the wire, and will
    25  	// be handed to the method invoked.
    26  	// The response will be unmarshaled from JSON into the result.
    27  	// The id returned will be unique from this connection, and can be used for
    28  	// logging or tracking.
    29  	Call(ctx context.Context, method string, params, result interface{}) (ID, error)
    30  
    31  	// Notify invokes the target method but does not wait for a response.
    32  	// The params will be marshaled to JSON before sending over the wire, and will
    33  	// be handed to the method invoked.
    34  	Notify(ctx context.Context, method string, params interface{}) error
    35  
    36  	// Go starts a goroutine to handle the connection.
    37  	// It must be called exactly once for each Conn.
    38  	// It returns immediately.
    39  	// You must block on Done() to wait for the connection to shut down.
    40  	// This is a temporary measure, this should be started automatically in the
    41  	// future.
    42  	Go(ctx context.Context, handler Handler)
    43  
    44  	// Close closes the connection and it's underlying stream.
    45  	// It does not wait for the close to complete, use the Done() channel for
    46  	// that.
    47  	Close() error
    48  
    49  	// Done returns a channel that will be closed when the processing goroutine
    50  	// has terminated, which will happen if Close() is called or an underlying
    51  	// stream is closed.
    52  	Done() <-chan struct{}
    53  
    54  	// Err returns an error if there was one from within the processing goroutine.
    55  	// If err returns non nil, the connection will be already closed or closing.
    56  	Err() error
    57  }
    58  
// conn is the Conn implementation returned by NewConn.
type conn struct {
	// seq is the most recently issued request ID number; Call increments it
	// to allocate each new ID.
	// NOTE: sync/atomic 64-bit operations need 64-bit alignment on 32-bit
	// platforms, and the first word of an allocated struct is guaranteed to be
	// aligned; keep seq as the first field.
	seq       int64      // must only be accessed using atomic operations
	writeMu   sync.Mutex // protects writes to the stream
	stream    Stream
	pendingMu sync.Mutex // protects the pending map
	// pending maps the ID of each in-flight Call to the channel on which the
	// read loop (run) delivers its response.
	pending   map[ID]chan *Response

	// done is closed when the read loop (run) exits; err holds the error that
	// made it exit, stored by fail and read by Err.
	done chan struct{}
	err  atomic.Value
}
    69  
    70  // NewConn creates a new connection object around the supplied stream.
    71  func NewConn(s Stream) Conn {
    72  	conn := &conn{
    73  		stream:  s,
    74  		pending: make(map[ID]chan *Response),
    75  		done:    make(chan struct{}),
    76  	}
    77  	return conn
    78  }
    79  
    80  func (c *conn) Notify(ctx context.Context, method string, params interface{}) (err error) {
    81  	notify, err := NewNotification(method, params)
    82  	if err != nil {
    83  		return fmt.Errorf("marshaling notify parameters: %v", err)
    84  	}
    85  	ctx, done := event.Start(ctx, method,
    86  		tag.Method.Of(method),
    87  		tag.RPCDirection.Of(tag.Outbound),
    88  	)
    89  	defer func() {
    90  		recordStatus(ctx, err)
    91  		done()
    92  	}()
    93  
    94  	event.Metric(ctx, tag.Started.Of(1))
    95  	n, err := c.write(ctx, notify)
    96  	event.Metric(ctx, tag.SentBytes.Of(n))
    97  	return err
    98  }
    99  
   100  func (c *conn) Call(ctx context.Context, method string, params, result interface{}) (_ ID, err error) {
   101  	// generate a new request identifier
   102  	id := ID{number: atomic.AddInt64(&c.seq, 1)}
   103  	call, err := NewCall(id, method, params)
   104  	if err != nil {
   105  		return id, fmt.Errorf("marshaling call parameters: %v", err)
   106  	}
   107  	ctx, done := event.Start(ctx, method,
   108  		tag.Method.Of(method),
   109  		tag.RPCDirection.Of(tag.Outbound),
   110  		tag.RPCID.Of(fmt.Sprintf("%q", id)),
   111  	)
   112  	defer func() {
   113  		recordStatus(ctx, err)
   114  		done()
   115  	}()
   116  	event.Metric(ctx, tag.Started.Of(1))
   117  	// We have to add ourselves to the pending map before we send, otherwise we
   118  	// are racing the response. Also add a buffer to rchan, so that if we get a
   119  	// wire response between the time this call is cancelled and id is deleted
   120  	// from c.pending, the send to rchan will not block.
   121  	rchan := make(chan *Response, 1)
   122  	c.pendingMu.Lock()
   123  	c.pending[id] = rchan
   124  	c.pendingMu.Unlock()
   125  	defer func() {
   126  		c.pendingMu.Lock()
   127  		delete(c.pending, id)
   128  		c.pendingMu.Unlock()
   129  	}()
   130  	// now we are ready to send
   131  	n, err := c.write(ctx, call)
   132  	event.Metric(ctx, tag.SentBytes.Of(n))
   133  	if err != nil {
   134  		// sending failed, we will never get a response, so don't leave it pending
   135  		return id, err
   136  	}
   137  	// now wait for the response
   138  	select {
   139  	case response := <-rchan:
   140  		// is it an error response?
   141  		if response.err != nil {
   142  			return id, response.err
   143  		}
   144  		if result == nil || len(response.result) == 0 {
   145  			return id, nil
   146  		}
   147  		if err := json.Unmarshal(response.result, result); err != nil {
   148  			return id, fmt.Errorf("unmarshaling result: %v", err)
   149  		}
   150  		return id, nil
   151  	case <-ctx.Done():
   152  		return id, ctx.Err()
   153  	}
   154  }
   155  
   156  func (c *conn) replier(req Request, spanDone func()) Replier {
   157  	return func(ctx context.Context, result interface{}, err error) error {
   158  		defer func() {
   159  			recordStatus(ctx, err)
   160  			spanDone()
   161  		}()
   162  		call, ok := req.(*Call)
   163  		if !ok {
   164  			// request was a notify, no need to respond
   165  			return nil
   166  		}
   167  		response, err := NewResponse(call.id, result, err)
   168  		if err != nil {
   169  			return err
   170  		}
   171  		n, err := c.write(ctx, response)
   172  		event.Metric(ctx, tag.SentBytes.Of(n))
   173  		if err != nil {
   174  			// TODO(iancottrell): if a stream write fails, we really need to shut down
   175  			// the whole stream
   176  			return err
   177  		}
   178  		return nil
   179  	}
   180  }
   181  
   182  func (c *conn) write(ctx context.Context, msg Message) (int64, error) {
   183  	c.writeMu.Lock()
   184  	defer c.writeMu.Unlock()
   185  	return c.stream.Write(ctx, msg)
   186  }
   187  
   188  func (c *conn) Go(ctx context.Context, handler Handler) {
   189  	go c.run(ctx, handler)
   190  }
   191  
   192  func (c *conn) run(ctx context.Context, handler Handler) {
   193  	defer close(c.done)
   194  	for {
   195  		// get the next message
   196  		msg, n, err := c.stream.Read(ctx)
   197  		if err != nil {
   198  			// The stream failed, we cannot continue.
   199  			c.fail(err)
   200  			return
   201  		}
   202  		switch msg := msg.(type) {
   203  		case Request:
   204  			labels := []label.Label{
   205  				tag.Method.Of(msg.Method()),
   206  				tag.RPCDirection.Of(tag.Inbound),
   207  				{}, // reserved for ID if present
   208  			}
   209  			if call, ok := msg.(*Call); ok {
   210  				labels[len(labels)-1] = tag.RPCID.Of(fmt.Sprintf("%q", call.ID()))
   211  			} else {
   212  				labels = labels[:len(labels)-1]
   213  			}
   214  			reqCtx, spanDone := event.Start(ctx, msg.Method(), labels...)
   215  			event.Metric(reqCtx,
   216  				tag.Started.Of(1),
   217  				tag.ReceivedBytes.Of(n))
   218  			if err := handler(reqCtx, c.replier(msg, spanDone), msg); err != nil {
   219  				// delivery failed, not much we can do
   220  				event.Error(reqCtx, "jsonrpc2 message delivery failed", err)
   221  			}
   222  		case *Response:
   223  			// If method is not set, this should be a response, in which case we must
   224  			// have an id to send the response back to the caller.
   225  			c.pendingMu.Lock()
   226  			rchan, ok := c.pending[msg.id]
   227  			c.pendingMu.Unlock()
   228  			if ok {
   229  				rchan <- msg
   230  			}
   231  		}
   232  	}
   233  }
   234  
   235  func (c *conn) Close() error {
   236  	return c.stream.Close()
   237  }
   238  
   239  func (c *conn) Done() <-chan struct{} {
   240  	return c.done
   241  }
   242  
   243  func (c *conn) Err() error {
   244  	if err := c.err.Load(); err != nil {
   245  		return err.(error)
   246  	}
   247  	return nil
   248  }
   249  
   250  // fail sets a failure condition on the stream and closes it.
   251  func (c *conn) fail(err error) {
   252  	c.err.Store(err)
   253  	c.stream.Close()
   254  }
   255  
   256  func recordStatus(ctx context.Context, err error) {
   257  	if err != nil {
   258  		event.Label(ctx, tag.StatusCode.Of("ERROR"))
   259  	} else {
   260  		event.Label(ctx, tag.StatusCode.Of("OK"))
   261  	}
   262  }