github.com/april1989/origin-go-tools@v0.0.32/internal/jsonrpc2/stream.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2
     6  
     7  import (
     8  	"bufio"
     9  	"context"
    10  	"encoding/json"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  // Stream abstracts the transport mechanics from the JSON RPC protocol.
    19  // A Conn reads and writes messages using the stream it was provided on
    20  // construction, and assumes that each call to Read or Write fully transfers
    21  // a single message, or returns an error.
    22  // A stream is not safe for concurrent use, it is expected it will be used by
    23  // a single Conn in a safe manner.
    24  type Stream interface {
    25  	// Read gets the next message from the stream.
    26  	Read(context.Context) (Message, int64, error)
    27  	// Write sends a message to the stream.
    28  	Write(context.Context, Message) (int64, error)
    29  	// Close closes the connection.
    30  	// Any blocked Read or Write operations will be unblocked and return errors.
    31  	Close() error
    32  }
    33  
    34  // Framer wraps a network connection up into a Stream.
    35  // It is responsible for the framing and encoding of messages into wire form.
    36  // NewRawStream and NewHeaderStream are implementations of a Framer.
    37  type Framer func(conn net.Conn) Stream
    38  
    39  // NewRawStream returns a Stream built on top of a net.Conn.
    40  // The messages are sent with no wrapping, and rely on json decode consistency
    41  // to determine message boundaries.
    42  func NewRawStream(conn net.Conn) Stream {
    43  	return &rawStream{
    44  		conn: conn,
    45  		in:   json.NewDecoder(conn),
    46  	}
    47  }
    48  
    49  type rawStream struct {
    50  	conn net.Conn
    51  	in   *json.Decoder
    52  }
    53  
    54  func (s *rawStream) Read(ctx context.Context) (Message, int64, error) {
    55  	select {
    56  	case <-ctx.Done():
    57  		return nil, 0, ctx.Err()
    58  	default:
    59  	}
    60  	var raw json.RawMessage
    61  	if err := s.in.Decode(&raw); err != nil {
    62  		return nil, 0, err
    63  	}
    64  	msg, err := DecodeMessage(raw)
    65  	return msg, int64(len(raw)), err
    66  }
    67  
    68  func (s *rawStream) Write(ctx context.Context, msg Message) (int64, error) {
    69  	select {
    70  	case <-ctx.Done():
    71  		return 0, ctx.Err()
    72  	default:
    73  	}
    74  	data, err := json.Marshal(msg)
    75  	if err != nil {
    76  		return 0, fmt.Errorf("marshaling message: %v", err)
    77  	}
    78  	n, err := s.conn.Write(data)
    79  	return int64(n), err
    80  }
    81  
    82  func (s *rawStream) Close() error {
    83  	return s.conn.Close()
    84  }
    85  
    86  // NewHeaderStream returns a Stream built on top of a net.Conn.
    87  // The messages are sent with HTTP content length and MIME type headers.
    88  // This is the format used by LSP and others.
    89  func NewHeaderStream(conn net.Conn) Stream {
    90  	return &headerStream{
    91  		conn: conn,
    92  		in:   bufio.NewReader(conn),
    93  	}
    94  }
    95  
    96  type headerStream struct {
    97  	conn net.Conn
    98  	in   *bufio.Reader
    99  }
   100  
   101  func (s *headerStream) Read(ctx context.Context) (Message, int64, error) {
   102  	select {
   103  	case <-ctx.Done():
   104  		return nil, 0, ctx.Err()
   105  	default:
   106  	}
   107  	var total, length int64
   108  	// read the header, stop on the first empty line
   109  	for {
   110  		line, err := s.in.ReadString('\n')
   111  		total += int64(len(line))
   112  		if err != nil {
   113  			return nil, total, fmt.Errorf("failed reading header line: %w", err)
   114  		}
   115  		line = strings.TrimSpace(line)
   116  		// check we have a header line
   117  		if line == "" {
   118  			break
   119  		}
   120  		colon := strings.IndexRune(line, ':')
   121  		if colon < 0 {
   122  			return nil, total, fmt.Errorf("invalid header line %q", line)
   123  		}
   124  		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
   125  		switch name {
   126  		case "Content-Length":
   127  			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
   128  				return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
   129  			}
   130  			if length <= 0 {
   131  				return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
   132  			}
   133  		default:
   134  			// ignoring unknown headers
   135  		}
   136  	}
   137  	if length == 0 {
   138  		return nil, total, fmt.Errorf("missing Content-Length header")
   139  	}
   140  	data := make([]byte, length)
   141  	if _, err := io.ReadFull(s.in, data); err != nil {
   142  		return nil, total, err
   143  	}
   144  	total += length
   145  	msg, err := DecodeMessage(data)
   146  	return msg, total, err
   147  }
   148  
   149  func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) {
   150  	select {
   151  	case <-ctx.Done():
   152  		return 0, ctx.Err()
   153  	default:
   154  	}
   155  	data, err := json.Marshal(msg)
   156  	if err != nil {
   157  		return 0, fmt.Errorf("marshaling message: %v", err)
   158  	}
   159  	n, err := fmt.Fprintf(s.conn, "Content-Length: %v\r\n\r\n", len(data))
   160  	total := int64(n)
   161  	if err == nil {
   162  		n, err = s.conn.Write(data)
   163  		total += int64(n)
   164  	}
   165  	return total, err
   166  }
   167  
   168  func (s *headerStream) Close() error {
   169  	return s.conn.Close()
   170  }