github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/websocket/connection.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/gorilla/websocket"
    10  
    11  	"github.com/v2fly/v2ray-core/v5/common/buf"
    12  	"github.com/v2fly/v2ray-core/v5/common/errors"
    13  	"github.com/v2fly/v2ray-core/v5/common/serial"
    14  )
    15  
// Compile-time check that *connection implements buf.Writer.
var _ buf.Writer = (*connection)(nil)

// connection is a wrapper for net.Conn over WebSocket connection.
//
// It is constructed in one of two modes:
//   - immediate (newConnection, newConnectionWithEarlyData): conn and
//     remoteAddr are set up front;
//   - delayed dial (newConnectionWithDelayedDial): conn is nil until the
//     first Write, which passes the written bytes to dialer.Dial.
//
// Fields:
//   - reader is the reader Read is currently draining. It is either the
//     early-data reader given at construction or the current WebSocket
//     message reader. nil means the next Read fetches a new message.
//   - shouldWait is true while in delayed-dial mode and until a dial has
//     succeeded. Methods that need conn block on delayedDialFinish while
//     it is set.
//   - delayedDialFinish is cancelled through finishedDial once a dial
//     attempt in Write returns, successfully or not.
type connection struct {
	conn       *websocket.Conn
	reader     io.Reader
	remoteAddr net.Addr

	shouldWait        bool
	delayedDialFinish context.Context
	finishedDial      context.CancelFunc
	dialer            DelayedDialer
}
    29  
// DelayedDialer establishes a WebSocket connection on demand.
//
// connection.Write calls Dial with the first bytes written to a
// delayed-dial connection as earlyData.
type DelayedDialer interface {
	Dial(earlyData []byte) (*websocket.Conn, error)
}
    33  
    34  func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
    35  	return &connection{
    36  		conn:       conn,
    37  		remoteAddr: remoteAddr,
    38  	}
    39  }
    40  
    41  func newConnectionWithEarlyData(conn *websocket.Conn, remoteAddr net.Addr, earlyData io.Reader) *connection {
    42  	return &connection{
    43  		conn:       conn,
    44  		remoteAddr: remoteAddr,
    45  		reader:     earlyData,
    46  	}
    47  }
    48  
    49  func newConnectionWithDelayedDial(dialer DelayedDialer) *connection {
    50  	delayedDialContext, cancelFunc := context.WithCancel(context.Background())
    51  	return &connection{
    52  		shouldWait:        true,
    53  		delayedDialFinish: delayedDialContext,
    54  		finishedDial:      cancelFunc,
    55  		dialer:            dialer,
    56  	}
    57  }
    58  
    59  func newRelayedConnectionWithDelayedDial(dialer DelayedDialerForwarded) *connectionForwarder {
    60  	delayedDialContext, cancelFunc := context.WithCancel(context.Background())
    61  	return &connectionForwarder{
    62  		shouldWait:        true,
    63  		delayedDialFinish: delayedDialContext,
    64  		finishedDial:      cancelFunc,
    65  		dialer:            dialer,
    66  	}
    67  }
    68  
    69  func newRelayedConnection(conn io.ReadWriteCloser) *connectionForwarder {
    70  	return &connectionForwarder{
    71  		ReadWriteCloser: conn,
    72  		shouldWait:      false,
    73  	}
    74  }
    75  
    76  // Read implements net.Conn.Read()
    77  func (c *connection) Read(b []byte) (int, error) {
    78  	for {
    79  		reader, err := c.getReader()
    80  		if err != nil {
    81  			return 0, err
    82  		}
    83  
    84  		nBytes, err := reader.Read(b)
    85  		if errors.Cause(err) == io.EOF {
    86  			c.reader = nil
    87  			continue
    88  		}
    89  		return nBytes, err
    90  	}
    91  }
    92  
    93  func (c *connection) getReader() (io.Reader, error) {
    94  	if c.shouldWait {
    95  		<-c.delayedDialFinish.Done()
    96  		if c.conn == nil {
    97  			return nil, newError("unable to read delayed dial websocket connection as it do not exist")
    98  		}
    99  	}
   100  	if c.reader != nil {
   101  		return c.reader, nil
   102  	}
   103  
   104  	_, reader, err := c.conn.NextReader()
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	c.reader = reader
   109  	return reader, nil
   110  }
   111  
   112  // Write implements io.Writer.
   113  func (c *connection) Write(b []byte) (int, error) {
   114  	if c.shouldWait {
   115  		var err error
   116  		c.conn, err = c.dialer.Dial(b)
   117  		c.finishedDial()
   118  		if err != nil {
   119  			return 0, newError("Unable to proceed with delayed write").Base(err)
   120  		}
   121  		c.remoteAddr = c.conn.RemoteAddr()
   122  		c.shouldWait = false
   123  		return len(b), nil
   124  	}
   125  	if err := c.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
   126  		return 0, err
   127  	}
   128  	return len(b), nil
   129  }
   130  
   131  func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
   132  	mb = buf.Compact(mb)
   133  	mb, err := buf.WriteMultiBuffer(c, mb)
   134  	buf.ReleaseMulti(mb)
   135  	return err
   136  }
   137  
   138  func (c *connection) Close() error {
   139  	if c.shouldWait {
   140  		<-c.delayedDialFinish.Done()
   141  		if c.conn == nil {
   142  			return newError("unable to close delayed dial websocket connection as it do not exist")
   143  		}
   144  	}
   145  	var errors []interface{}
   146  	if err := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
   147  		errors = append(errors, err)
   148  	}
   149  	if err := c.conn.Close(); err != nil {
   150  		errors = append(errors, err)
   151  	}
   152  	if len(errors) > 0 {
   153  		return newError("failed to close connection").Base(newError(serial.Concat(errors...)))
   154  	}
   155  	return nil
   156  }
   157  
   158  func (c *connection) LocalAddr() net.Addr {
   159  	if c.shouldWait {
   160  		<-c.delayedDialFinish.Done()
   161  		if c.conn == nil {
   162  			newError("websocket transport is not materialized when LocalAddr() is called").AtWarning().WriteToLog()
   163  			return &net.UnixAddr{
   164  				Name: "@placeholder",
   165  				Net:  "unix",
   166  			}
   167  		}
   168  	}
   169  	return c.conn.LocalAddr()
   170  }
   171  
// RemoteAddr implements net.Conn.RemoteAddr. It does not wait for a delayed
// dial. On a delayed-dial connection the value stays unset until Write has
// dialed successfully (presumably nil before then; constructors outside this
// file are not visible).
func (c *connection) RemoteAddr() net.Addr {
	return c.remoteAddr
}
   175  
   176  func (c *connection) SetDeadline(t time.Time) error {
   177  	if err := c.SetReadDeadline(t); err != nil {
   178  		return err
   179  	}
   180  	return c.SetWriteDeadline(t)
   181  }
   182  
   183  func (c *connection) SetReadDeadline(t time.Time) error {
   184  	if c.shouldWait {
   185  		<-c.delayedDialFinish.Done()
   186  		if c.conn == nil {
   187  			newError("websocket transport is not materialized when SetReadDeadline() is called").AtWarning().WriteToLog()
   188  			return nil
   189  		}
   190  	}
   191  	return c.conn.SetReadDeadline(t)
   192  }
   193  
   194  func (c *connection) SetWriteDeadline(t time.Time) error {
   195  	if c.shouldWait {
   196  		<-c.delayedDialFinish.Done()
   197  		if c.conn == nil {
   198  			newError("websocket transport is not materialized when SetWriteDeadline() is called").AtWarning().WriteToLog()
   199  			return nil
   200  		}
   201  	}
   202  	return c.conn.SetWriteDeadline(t)
   203  }