github.com/eagleql/xray-core@v1.4.4/transport/internet/grpc/encoding/hunkconn.go (about)

     1  package encoding
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  
     8  	"google.golang.org/grpc/peer"
     9  
    10  	"github.com/eagleql/xray-core/common/buf"
    11  	"github.com/eagleql/xray-core/common/net/cnc"
    12  	"github.com/eagleql/xray-core/common/signal/done"
    13  )
    14  
    15  type HunkConn interface {
    16  	Context() context.Context
    17  	Send(*Hunk) error
    18  	Recv() (*Hunk, error)
    19  	SendMsg(m interface{}) error
    20  	RecvMsg(m interface{}) error
    21  }
    22  
    23  type StreamCloser interface {
    24  	CloseSend() error
    25  }
    26  
// HunkReaderWriter adapts a HunkConn to byte-oriented I/O: Read/ReadMultiBuffer
// consume the data of received hunks, Write sends each call as one hunk, and
// Close tears the adapter down.
//
// NOTE: NewHunkReadWriter initialises this struct with a positional literal,
// so the field order below is significant.
type HunkReaderWriter struct {
	hc     HunkConn           // stream hunks are sent to and received from
	cancel context.CancelFunc // invoked by Close when non-nil
	done   *done.Instance     // closed-state signal checked by Read/ReadMultiBuffer/Write

	// buf holds the data of the most recently received hunk; index is the
	// offset of the first byte in buf not yet handed to a reader.
	buf   []byte
	index int
}
    35  
    36  func NewHunkReadWriter(hc HunkConn, cancel context.CancelFunc) *HunkReaderWriter {
    37  	return &HunkReaderWriter{hc, cancel, done.New(), nil, 0}
    38  }
    39  
    40  func NewHunkConn(hc HunkConn, cancel context.CancelFunc) net.Conn {
    41  	var rAddr net.Addr
    42  	pr, ok := peer.FromContext(hc.Context())
    43  	if ok {
    44  		rAddr = pr.Addr
    45  	} else {
    46  		rAddr = &net.TCPAddr{
    47  			IP:   []byte{0, 0, 0, 0},
    48  			Port: 0,
    49  		}
    50  	}
    51  
    52  	wrc := NewHunkReadWriter(hc, cancel)
    53  	return cnc.NewConnection(
    54  		cnc.ConnectionInput(wrc),
    55  		cnc.ConnectionOutput(wrc),
    56  		cnc.ConnectionOnClose(wrc),
    57  		cnc.ConnectionRemoteAddr(rAddr),
    58  	)
    59  }
    60  
    61  func (h *HunkReaderWriter) forceFetch() error {
    62  	hunk, err := h.hc.Recv()
    63  	if err != nil {
    64  		if err == io.EOF {
    65  			return err
    66  		}
    67  
    68  		return newError("failed to fetch hunk from gRPC tunnel").Base(err)
    69  	}
    70  
    71  	h.buf = hunk.Data
    72  	h.index = 0
    73  
    74  	return nil
    75  }
    76  
    77  func (h *HunkReaderWriter) Read(buf []byte) (int, error) {
    78  	if h.done.Done() {
    79  		return 0, io.EOF
    80  	}
    81  
    82  	if h.index >= len(h.buf) {
    83  		if err := h.forceFetch(); err != nil {
    84  			return 0, err
    85  		}
    86  	}
    87  	n := copy(buf, h.buf[h.index:])
    88  	h.index += n
    89  
    90  	return n, nil
    91  }
    92  
    93  func (h *HunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
    94  	if h.done.Done() {
    95  		return nil, io.EOF
    96  	}
    97  	if h.index >= len(h.buf) {
    98  		if err := h.forceFetch(); err != nil {
    99  			return nil, err
   100  		}
   101  	}
   102  
   103  	if cap(h.buf) >= buf.Size {
   104  		b := h.buf
   105  		h.index = len(h.buf)
   106  		return buf.MultiBuffer{buf.NewExisted(b)}, nil
   107  	}
   108  
   109  	b := buf.New()
   110  	_, err := b.ReadFrom(h)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	return buf.MultiBuffer{b}, nil
   115  }
   116  
   117  func (h *HunkReaderWriter) Write(buf []byte) (int, error) {
   118  	if h.done.Done() {
   119  		return 0, io.ErrClosedPipe
   120  	}
   121  
   122  	err := h.hc.Send(&Hunk{Data: buf[:]})
   123  	if err != nil {
   124  		return 0, newError("failed to send data over gRPC tunnel").Base(err)
   125  	}
   126  	return len(buf), nil
   127  }
   128  
   129  func (h *HunkReaderWriter) Close() error {
   130  	if h.cancel != nil {
   131  		h.cancel()
   132  	}
   133  	if sc, match := h.hc.(StreamCloser); match {
   134  		return sc.CloseSend()
   135  	}
   136  
   137  	return h.done.Close()
   138  }