github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/grpc/encoding/multiconn.go (about)

     1  package encoding
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  
     8  	"github.com/xmplusdev/xmcore/common/buf"
     9  	xnet "github.com/xmplusdev/xmcore/common/net"
    10  	"github.com/xmplusdev/xmcore/common/net/cnc"
    11  	"github.com/xmplusdev/xmcore/common/signal/done"
    12  	"google.golang.org/grpc/metadata"
    13  	"google.golang.org/grpc/peer"
    14  )
    15  
    16  type MultiHunkConn interface {
    17  	Context() context.Context
    18  	Send(*MultiHunk) error
    19  	Recv() (*MultiHunk, error)
    20  	SendMsg(m interface{}) error
    21  	RecvMsg(m interface{}) error
    22  }
    23  
    24  type MultiHunkReaderWriter struct {
    25  	hc     MultiHunkConn
    26  	cancel context.CancelFunc
    27  	done   *done.Instance
    28  
    29  	buf [][]byte
    30  }
    31  
    32  func NewMultiHunkReadWriter(hc MultiHunkConn, cancel context.CancelFunc) *MultiHunkReaderWriter {
    33  	return &MultiHunkReaderWriter{hc, cancel, done.New(), nil}
    34  }
    35  
    36  func NewMultiHunkConn(hc MultiHunkConn, cancel context.CancelFunc) net.Conn {
    37  	var rAddr net.Addr
    38  	pr, ok := peer.FromContext(hc.Context())
    39  	if ok {
    40  		rAddr = pr.Addr
    41  	} else {
    42  		rAddr = &net.TCPAddr{
    43  			IP:   []byte{0, 0, 0, 0},
    44  			Port: 0,
    45  		}
    46  	}
    47  
    48  	md, ok := metadata.FromIncomingContext(hc.Context())
    49  	if ok {
    50  		header := md.Get("x-real-ip")
    51  		if len(header) > 0 {
    52  			realip := xnet.ParseAddress(header[0])
    53  			if realip.Family().IsIP() {
    54  				rAddr = &net.TCPAddr{
    55  					IP:   realip.IP(),
    56  					Port: 0,
    57  				}
    58  			}
    59  		}
    60  	}
    61  	wrc := NewMultiHunkReadWriter(hc, cancel)
    62  	return cnc.NewConnection(
    63  		cnc.ConnectionInputMulti(wrc),
    64  		cnc.ConnectionOutputMulti(wrc),
    65  		cnc.ConnectionOnClose(wrc),
    66  		cnc.ConnectionRemoteAddr(rAddr),
    67  	)
    68  }
    69  
    70  func (h *MultiHunkReaderWriter) forceFetch() error {
    71  	hunk, err := h.hc.Recv()
    72  	if err != nil {
    73  		if err == io.EOF {
    74  			return err
    75  		}
    76  
    77  		return newError("failed to fetch hunk from gRPC tunnel").Base(err)
    78  	}
    79  
    80  	h.buf = hunk.Data
    81  
    82  	return nil
    83  }
    84  
    85  func (h *MultiHunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
    86  	if h.done.Done() {
    87  		return nil, io.EOF
    88  	}
    89  
    90  	if err := h.forceFetch(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	mb := make(buf.MultiBuffer, 0, len(h.buf))
    95  	for _, b := range h.buf {
    96  		if len(b) == 0 {
    97  			continue
    98  		}
    99  
   100  		if cap(b) >= buf.Size {
   101  			mb = append(mb, buf.NewExisted(b))
   102  		} else {
   103  			nb := buf.New()
   104  			nb.Extend(int32(len(b)))
   105  			copy(nb.Bytes(), b)
   106  
   107  			mb = append(mb, nb)
   108  		}
   109  
   110  	}
   111  	return mb, nil
   112  }
   113  
   114  func (h *MultiHunkReaderWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   115  	defer buf.ReleaseMulti(mb)
   116  	if h.done.Done() {
   117  		return io.ErrClosedPipe
   118  	}
   119  
   120  	hunks := make([][]byte, 0, len(mb))
   121  
   122  	for _, b := range mb {
   123  		if b.Len() > 0 {
   124  			hunks = append(hunks, b.Bytes())
   125  		}
   126  	}
   127  
   128  	err := h.hc.Send(&MultiHunk{Data: hunks})
   129  	if err != nil {
   130  		return err
   131  	}
   132  	return nil
   133  }
   134  
   135  func (h *MultiHunkReaderWriter) Close() error {
   136  	if h.cancel != nil {
   137  		h.cancel()
   138  	}
   139  	if sc, match := h.hc.(StreamCloser); match {
   140  		return sc.CloseSend()
   141  	}
   142  
   143  	return h.done.Close()
   144  }