github.com/kelleygo/clashcore@v1.0.2/transport/vless/conn.go (about)

     1  package vless
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/kelleygo/clashcore/common/buf"
    11  	N "github.com/kelleygo/clashcore/common/net"
    12  	"github.com/kelleygo/clashcore/transport/vless/vision"
    13  
    14  	"github.com/gofrs/uuid/v5"
    15  	"google.golang.org/protobuf/proto"
    16  )
    17  
// Conn is a client-side VLESS connection. The first Write/WriteBuffer prefixes
// the payload with the VLESS request header (see sendRequest), and the first
// Read/ReadBuffer consumes the VLESS response header (see recvResponse).
type Conn struct {
	// Writer/reader wrappers built around the same conn in newConn.
	N.ExtendedWriter
	N.ExtendedReader
	// The conn passed to newConn; returned by Upstream.
	net.Conn
	// Destination encoded into the request header by sendRequest.
	dst      *DstAddr
	// User ID written into the request header.
	id       *uuid.UUID
	// Set by newConn only for the XRV flow; marshalled into the request header.
	addons   *Addons
	// True once recvResponse has succeeded.
	received bool

	// Serializes the one-time request send in Write/WriteBuffer.
	handshakeMutex sync.Mutex
	// True until the request header has been sent.
	// NOTE(review): Write/WriteBuffer read this before taking handshakeMutex;
	// that is a data race under the Go memory model if writes can run
	// concurrently — confirm callers serialize writes.
	needHandshake  bool
	// Last error from sendRequest/recvResponse. Written by both the read and
	// write paths without a common lock.
	err            error
}
    31  
    32  func (vc *Conn) Read(b []byte) (int, error) {
    33  	if vc.received {
    34  		return vc.ExtendedReader.Read(b)
    35  	}
    36  
    37  	if err := vc.recvResponse(); err != nil {
    38  		return 0, err
    39  	}
    40  	vc.received = true
    41  	return vc.ExtendedReader.Read(b)
    42  }
    43  
    44  func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
    45  	if vc.received {
    46  		return vc.ExtendedReader.ReadBuffer(buffer)
    47  	}
    48  
    49  	if err := vc.recvResponse(); err != nil {
    50  		return err
    51  	}
    52  	vc.received = true
    53  	return vc.ExtendedReader.ReadBuffer(buffer)
    54  }
    55  
    56  func (vc *Conn) Write(p []byte) (int, error) {
    57  	if vc.needHandshake {
    58  		vc.handshakeMutex.Lock()
    59  		if vc.needHandshake {
    60  			vc.needHandshake = false
    61  			if vc.sendRequest(p) {
    62  				vc.handshakeMutex.Unlock()
    63  				if vc.err != nil {
    64  					return 0, vc.err
    65  				}
    66  				return len(p), vc.err
    67  			}
    68  			if vc.err != nil {
    69  				vc.handshakeMutex.Unlock()
    70  				return 0, vc.err
    71  			}
    72  		}
    73  		vc.handshakeMutex.Unlock()
    74  	}
    75  
    76  	return vc.ExtendedWriter.Write(p)
    77  }
    78  
    79  func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
    80  	if vc.needHandshake {
    81  		vc.handshakeMutex.Lock()
    82  		if vc.needHandshake {
    83  			vc.needHandshake = false
    84  			if vc.sendRequest(buffer.Bytes()) {
    85  				vc.handshakeMutex.Unlock()
    86  				return vc.err
    87  			}
    88  			if vc.err != nil {
    89  				vc.handshakeMutex.Unlock()
    90  				return vc.err
    91  			}
    92  		}
    93  		vc.handshakeMutex.Unlock()
    94  	}
    95  
    96  	return vc.ExtendedWriter.WriteBuffer(buffer)
    97  }
    98  
    99  func (vc *Conn) sendRequest(p []byte) bool {
   100  	var addonsBytes []byte
   101  	if vc.addons != nil {
   102  		addonsBytes, vc.err = proto.Marshal(vc.addons)
   103  		if vc.err != nil {
   104  			return true
   105  		}
   106  	}
   107  
   108  	var buffer *buf.Buffer
   109  	if vc.IsXTLSVisionEnabled() {
   110  		buffer = buf.New()
   111  		defer buffer.Release()
   112  	} else {
   113  		requestLen := 1  // protocol version
   114  		requestLen += 16 // UUID
   115  		requestLen += 1  // addons length
   116  		requestLen += len(addonsBytes)
   117  		requestLen += 1 // command
   118  		if !vc.dst.Mux {
   119  			requestLen += 2 // port
   120  			requestLen += 1 // addr type
   121  			requestLen += len(vc.dst.Addr)
   122  		}
   123  		requestLen += len(p)
   124  
   125  		buffer = buf.NewSize(requestLen)
   126  		defer buffer.Release()
   127  	}
   128  
   129  	buf.Must(
   130  		buffer.WriteByte(Version),              // protocol version
   131  		buf.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
   132  		buffer.WriteByte(byte(len(addonsBytes))),
   133  		buf.Error(buffer.Write(addonsBytes)),
   134  	)
   135  
   136  	if vc.dst.Mux {
   137  		buf.Must(buffer.WriteByte(CommandMux))
   138  	} else {
   139  		if vc.dst.UDP {
   140  			buf.Must(buffer.WriteByte(CommandUDP))
   141  		} else {
   142  			buf.Must(buffer.WriteByte(CommandTCP))
   143  		}
   144  
   145  		binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
   146  		buf.Must(
   147  			buffer.WriteByte(vc.dst.AddrType),
   148  			buf.Error(buffer.Write(vc.dst.Addr)),
   149  		)
   150  	}
   151  
   152  	buf.Must(buf.Error(buffer.Write(p)))
   153  
   154  	_, vc.err = vc.ExtendedWriter.Write(buffer.Bytes())
   155  	return true
   156  }
   157  
   158  func (vc *Conn) recvResponse() error {
   159  	var buffer [2]byte
   160  	_, vc.err = io.ReadFull(vc.ExtendedReader, buffer[:])
   161  	if vc.err != nil {
   162  		return vc.err
   163  	}
   164  
   165  	if buffer[0] != Version {
   166  		return errors.New("unexpected response version")
   167  	}
   168  
   169  	length := int64(buffer[1])
   170  	if length != 0 { // addon data length > 0
   171  		io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  func (vc *Conn) Upstream() any {
   178  	return vc.Conn
   179  }
   180  
   181  func (vc *Conn) NeedHandshake() bool {
   182  	return vc.needHandshake
   183  }
   184  
   185  func (vc *Conn) IsXTLSVisionEnabled() bool {
   186  	return vc.addons != nil && vc.addons.Flow == XRV
   187  }
   188  
   189  // newConn return a Conn instance
   190  func newConn(conn net.Conn, client *Client, dst *DstAddr) (net.Conn, error) {
   191  	c := &Conn{
   192  		ExtendedReader: N.NewExtendedReader(conn),
   193  		ExtendedWriter: N.NewExtendedWriter(conn),
   194  		Conn:           conn,
   195  		id:             client.uuid,
   196  		dst:            dst,
   197  		needHandshake:  true,
   198  	}
   199  
   200  	if client.Addons != nil {
   201  		switch client.Addons.Flow {
   202  		case XRV:
   203  			visionConn, err := vision.NewConn(c, c.id)
   204  			if err != nil {
   205  				return nil, err
   206  			}
   207  			c.addons = client.Addons
   208  			return visionConn, nil
   209  		}
   210  	}
   211  
   212  	return c, nil
   213  }