github.com/v2fly/tools@v0.100.0/internal/jsonrpc2_v2/frame.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2
     6  
     7  import (
     8  	"bufio"
     9  	"context"
    10  	"encoding/json"
    11  	"fmt"
    12  	"io"
    13  	"strconv"
    14  	"strings"
    15  
    16  	errors "golang.org/x/xerrors"
    17  )
    18  
    19  // Reader abstracts the transport mechanics from the JSON RPC protocol.
    20  // A Conn reads messages from the reader it was provided on construction,
    21  // and assumes that each call to Read fully transfers a single message,
    22  // or returns an error.
    23  // A reader is not safe for concurrent use, it is expected it will be used by
    24  // a single Conn in a safe manner.
    25  type Reader interface {
    26  	// Read gets the next message from the stream.
    27  	Read(context.Context) (Message, int64, error)
    28  }
    29  
    30  // Writer abstracts the transport mechanics from the JSON RPC protocol.
    31  // A Conn writes messages using the writer it was provided on construction,
    32  // and assumes that each call to Write fully transfers a single message,
    33  // or returns an error.
    34  // A writer is not safe for concurrent use, it is expected it will be used by
    35  // a single Conn in a safe manner.
    36  type Writer interface {
    37  	// Write sends a message to the stream.
    38  	Write(context.Context, Message) (int64, error)
    39  }
    40  
    41  // Framer wraps low level byte readers and writers into jsonrpc2 message
    42  // readers and writers.
    43  // It is responsible for the framing and encoding of messages into wire form.
    44  type Framer interface {
    45  	// Reader wraps a byte reader into a message reader.
    46  	Reader(rw io.Reader) Reader
    47  	// Writer wraps a byte writer into a message writer.
    48  	Writer(rw io.Writer) Writer
    49  }
    50  
    51  // RawFramer returns a new Framer.
    52  // The messages are sent with no wrapping, and rely on json decode consistency
    53  // to determine message boundaries.
    54  func RawFramer() Framer { return rawFramer{} }
    55  
    56  type rawFramer struct{}
    57  type rawReader struct{ in *json.Decoder }
    58  type rawWriter struct{ out io.Writer }
    59  
    60  func (rawFramer) Reader(rw io.Reader) Reader {
    61  	return &rawReader{in: json.NewDecoder(rw)}
    62  }
    63  
    64  func (rawFramer) Writer(rw io.Writer) Writer {
    65  	return &rawWriter{out: rw}
    66  }
    67  
    68  func (r *rawReader) Read(ctx context.Context) (Message, int64, error) {
    69  	select {
    70  	case <-ctx.Done():
    71  		return nil, 0, ctx.Err()
    72  	default:
    73  	}
    74  	var raw json.RawMessage
    75  	if err := r.in.Decode(&raw); err != nil {
    76  		return nil, 0, err
    77  	}
    78  	msg, err := DecodeMessage(raw)
    79  	return msg, int64(len(raw)), err
    80  }
    81  
    82  func (w *rawWriter) Write(ctx context.Context, msg Message) (int64, error) {
    83  	select {
    84  	case <-ctx.Done():
    85  		return 0, ctx.Err()
    86  	default:
    87  	}
    88  	data, err := EncodeMessage(msg)
    89  	if err != nil {
    90  		return 0, errors.Errorf("marshaling message: %v", err)
    91  	}
    92  	n, err := w.out.Write(data)
    93  	return int64(n), err
    94  }
    95  
    96  // HeaderFramer returns a new Framer.
    97  // The messages are sent with HTTP content length and MIME type headers.
    98  // This is the format used by LSP and others.
    99  func HeaderFramer() Framer { return headerFramer{} }
   100  
   101  type headerFramer struct{}
   102  type headerReader struct{ in *bufio.Reader }
   103  type headerWriter struct{ out io.Writer }
   104  
   105  func (headerFramer) Reader(rw io.Reader) Reader {
   106  	return &headerReader{in: bufio.NewReader(rw)}
   107  }
   108  
   109  func (headerFramer) Writer(rw io.Writer) Writer {
   110  	return &headerWriter{out: rw}
   111  }
   112  
   113  func (r *headerReader) Read(ctx context.Context) (Message, int64, error) {
   114  	select {
   115  	case <-ctx.Done():
   116  		return nil, 0, ctx.Err()
   117  	default:
   118  	}
   119  	var total, length int64
   120  	// read the header, stop on the first empty line
   121  	for {
   122  		line, err := r.in.ReadString('\n')
   123  		total += int64(len(line))
   124  		if err != nil {
   125  			return nil, total, errors.Errorf("failed reading header line: %w", err)
   126  		}
   127  		line = strings.TrimSpace(line)
   128  		// check we have a header line
   129  		if line == "" {
   130  			break
   131  		}
   132  		colon := strings.IndexRune(line, ':')
   133  		if colon < 0 {
   134  			return nil, total, errors.Errorf("invalid header line %q", line)
   135  		}
   136  		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
   137  		switch name {
   138  		case "Content-Length":
   139  			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
   140  				return nil, total, errors.Errorf("failed parsing Content-Length: %v", value)
   141  			}
   142  			if length <= 0 {
   143  				return nil, total, errors.Errorf("invalid Content-Length: %v", length)
   144  			}
   145  		default:
   146  			// ignoring unknown headers
   147  		}
   148  	}
   149  	if length == 0 {
   150  		return nil, total, errors.Errorf("missing Content-Length header")
   151  	}
   152  	data := make([]byte, length)
   153  	n, err := io.ReadFull(r.in, data)
   154  	total += int64(n)
   155  	if err != nil {
   156  		return nil, total, err
   157  	}
   158  	msg, err := DecodeMessage(data)
   159  	return msg, total, err
   160  }
   161  
   162  func (w *headerWriter) Write(ctx context.Context, msg Message) (int64, error) {
   163  	select {
   164  	case <-ctx.Done():
   165  		return 0, ctx.Err()
   166  	default:
   167  	}
   168  	data, err := EncodeMessage(msg)
   169  	if err != nil {
   170  		return 0, errors.Errorf("marshaling message: %v", err)
   171  	}
   172  	n, err := fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data))
   173  	total := int64(n)
   174  	if err == nil {
   175  		n, err = w.out.Write(data)
   176  		total += int64(n)
   177  	}
   178  	return total, err
   179  }