golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/internal/jsonrpc2_v2/frame.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2
     6  
     7  import (
     8  	"bufio"
     9  	"context"
    10  	"encoding/json"
    11  	"fmt"
    12  	"io"
    13  	"strconv"
    14  	"strings"
    15  )
    16  
    17  // Reader abstracts the transport mechanics from the JSON RPC protocol.
    18  // A Conn reads messages from the reader it was provided on construction,
    19  // and assumes that each call to Read fully transfers a single message,
    20  // or returns an error.
    21  // A reader is not safe for concurrent use, it is expected it will be used by
    22  // a single Conn in a safe manner.
    23  type Reader interface {
    24  	// Read gets the next message from the stream.
    25  	Read(context.Context) (Message, int64, error)
    26  }
    27  
    28  // Writer abstracts the transport mechanics from the JSON RPC protocol.
    29  // A Conn writes messages using the writer it was provided on construction,
    30  // and assumes that each call to Write fully transfers a single message,
    31  // or returns an error.
    32  // A writer is not safe for concurrent use, it is expected it will be used by
    33  // a single Conn in a safe manner.
    34  type Writer interface {
    35  	// Write sends a message to the stream.
    36  	Write(context.Context, Message) (int64, error)
    37  }
    38  
    39  // Framer wraps low level byte readers and writers into jsonrpc2 message
    40  // readers and writers.
    41  // It is responsible for the framing and encoding of messages into wire form.
    42  type Framer interface {
    43  	// Reader wraps a byte reader into a message reader.
    44  	Reader(rw io.Reader) Reader
    45  	// Writer wraps a byte writer into a message writer.
    46  	Writer(rw io.Writer) Writer
    47  }
    48  
    49  // RawFramer returns a new Framer.
    50  // The messages are sent with no wrapping, and rely on json decode consistency
    51  // to determine message boundaries.
    52  func RawFramer() Framer { return rawFramer{} }
    53  
    54  type rawFramer struct{}
    55  type rawReader struct{ in *json.Decoder }
    56  type rawWriter struct{ out io.Writer }
    57  
    58  func (rawFramer) Reader(rw io.Reader) Reader {
    59  	return &rawReader{in: json.NewDecoder(rw)}
    60  }
    61  
    62  func (rawFramer) Writer(rw io.Writer) Writer {
    63  	return &rawWriter{out: rw}
    64  }
    65  
    66  func (r *rawReader) Read(ctx context.Context) (Message, int64, error) {
    67  	select {
    68  	case <-ctx.Done():
    69  		return nil, 0, ctx.Err()
    70  	default:
    71  	}
    72  	var raw json.RawMessage
    73  	if err := r.in.Decode(&raw); err != nil {
    74  		return nil, 0, err
    75  	}
    76  	msg, err := DecodeMessage(raw)
    77  	return msg, int64(len(raw)), err
    78  }
    79  
    80  func (w *rawWriter) Write(ctx context.Context, msg Message) (int64, error) {
    81  	select {
    82  	case <-ctx.Done():
    83  		return 0, ctx.Err()
    84  	default:
    85  	}
    86  	data, err := EncodeMessage(msg)
    87  	if err != nil {
    88  		return 0, fmt.Errorf("marshaling message: %v", err)
    89  	}
    90  	n, err := w.out.Write(data)
    91  	return int64(n), err
    92  }
    93  
    94  // HeaderFramer returns a new Framer.
    95  // The messages are sent with HTTP content length and MIME type headers.
    96  // This is the format used by LSP and others.
    97  func HeaderFramer() Framer { return headerFramer{} }
    98  
    99  type headerFramer struct{}
   100  type headerReader struct{ in *bufio.Reader }
   101  type headerWriter struct{ out io.Writer }
   102  
   103  func (headerFramer) Reader(rw io.Reader) Reader {
   104  	return &headerReader{in: bufio.NewReader(rw)}
   105  }
   106  
   107  func (headerFramer) Writer(rw io.Writer) Writer {
   108  	return &headerWriter{out: rw}
   109  }
   110  
   111  func (r *headerReader) Read(ctx context.Context) (Message, int64, error) {
   112  	select {
   113  	case <-ctx.Done():
   114  		return nil, 0, ctx.Err()
   115  	default:
   116  	}
   117  	var total, length int64
   118  	// read the header, stop on the first empty line
   119  	for {
   120  		line, err := r.in.ReadString('\n')
   121  		total += int64(len(line))
   122  		if err != nil {
   123  			if err == io.EOF {
   124  				if total == 0 {
   125  					return nil, 0, io.EOF
   126  				}
   127  				err = io.ErrUnexpectedEOF
   128  			}
   129  			return nil, total, fmt.Errorf("failed reading header line: %w", err)
   130  		}
   131  		line = strings.TrimSpace(line)
   132  		// check we have a header line
   133  		if line == "" {
   134  			break
   135  		}
   136  		colon := strings.IndexRune(line, ':')
   137  		if colon < 0 {
   138  			return nil, total, fmt.Errorf("invalid header line %q", line)
   139  		}
   140  		name, value := line[:colon], strings.TrimSpace(line[colon+1:])
   141  		switch name {
   142  		case "Content-Length":
   143  			if length, err = strconv.ParseInt(value, 10, 32); err != nil {
   144  				return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
   145  			}
   146  			if length <= 0 {
   147  				return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
   148  			}
   149  		default:
   150  			// ignoring unknown headers
   151  		}
   152  	}
   153  	if length == 0 {
   154  		return nil, total, fmt.Errorf("missing Content-Length header")
   155  	}
   156  	data := make([]byte, length)
   157  	n, err := io.ReadFull(r.in, data)
   158  	total += int64(n)
   159  	if err != nil {
   160  		return nil, total, err
   161  	}
   162  	msg, err := DecodeMessage(data)
   163  	return msg, total, err
   164  }
   165  
   166  func (w *headerWriter) Write(ctx context.Context, msg Message) (int64, error) {
   167  	select {
   168  	case <-ctx.Done():
   169  		return 0, ctx.Err()
   170  	default:
   171  	}
   172  	data, err := EncodeMessage(msg)
   173  	if err != nil {
   174  		return 0, fmt.Errorf("marshaling message: %v", err)
   175  	}
   176  	n, err := fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data))
   177  	total := int64(n)
   178  	if err == nil {
   179  		n, err = w.out.Write(data)
   180  		total += int64(n)
   181  	}
   182  	return total, err
   183  }