github.com/kelleygo/clashcore@v1.0.2/transport/vless/vision/conn.go (about)

     1  package vision
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/subtle"
     6  	gotls "crypto/tls"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  
    12  	"github.com/kelleygo/clashcore/common/buf"
    13  	N "github.com/kelleygo/clashcore/common/net"
    14  	"github.com/kelleygo/clashcore/log"
    15  
    16  	"github.com/gofrs/uuid/v5"
    17  	utls "github.com/sagernet/utls"
    18  )
    19  
    20  var (
    21  	_ N.ExtendedConn = (*Conn)(nil)
    22  )
    23  
    24  type Conn struct {
    25  	net.Conn
    26  	N.ExtendedReader
    27  	N.ExtendedWriter
    28  	upstream net.Conn
    29  	userUUID *uuid.UUID
    30  
    31  	tlsConn  net.Conn
    32  	input    *bytes.Reader
    33  	rawInput *bytes.Buffer
    34  
    35  	needHandshake              bool
    36  	packetsToFilter            int
    37  	isTLS                      bool
    38  	isTLS12orAbove             bool
    39  	enableXTLS                 bool
    40  	cipher                     uint16
    41  	remainingServerHello       uint16
    42  	readRemainingContent       int
    43  	readRemainingPadding       int
    44  	readProcess                bool
    45  	readFilterUUID             bool
    46  	readLastCommand            byte
    47  	writeFilterApplicationData bool
    48  	writeDirect                bool
    49  }
    50  
    51  func (vc *Conn) Read(b []byte) (int, error) {
    52  	if vc.readProcess {
    53  		buffer := buf.With(b)
    54  		err := vc.ReadBuffer(buffer)
    55  		return buffer.Len(), err
    56  	}
    57  	return vc.ExtendedReader.Read(b)
    58  }
    59  
    60  func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
    61  	toRead := buffer.FreeBytes()
    62  	if vc.readRemainingContent > 0 {
    63  		if vc.readRemainingContent < buffer.FreeLen() {
    64  			toRead = toRead[:vc.readRemainingContent]
    65  		}
    66  		n, err := vc.ExtendedReader.Read(toRead)
    67  		buffer.Truncate(n)
    68  		vc.readRemainingContent -= n
    69  		vc.FilterTLS(toRead)
    70  		return err
    71  	}
    72  	if vc.readRemainingPadding > 0 {
    73  		_, err := io.CopyN(io.Discard, vc.ExtendedReader, int64(vc.readRemainingPadding))
    74  		if err != nil {
    75  			return err
    76  		}
    77  		vc.readRemainingPadding = 0
    78  	}
    79  	if vc.readProcess {
    80  		switch vc.readLastCommand {
    81  		case commandPaddingContinue:
    82  			//if vc.isTLS || vc.packetsToFilter > 0 {
    83  			headerUUIDLen := 0
    84  			if vc.readFilterUUID {
    85  				headerUUIDLen = uuid.Size
    86  			}
    87  			var header []byte
    88  			if need := headerUUIDLen + PaddingHeaderLen - uuid.Size; buffer.FreeLen() < need {
    89  				header = make([]byte, need)
    90  			} else {
    91  				header = buffer.FreeBytes()[:need]
    92  			}
    93  			_, err := io.ReadFull(vc.ExtendedReader, header)
    94  			if err != nil {
    95  				return err
    96  			}
    97  			if vc.readFilterUUID {
    98  				vc.readFilterUUID = false
    99  				if subtle.ConstantTimeCompare(vc.userUUID.Bytes(), header[:uuid.Size]) != 1 {
   100  					err = fmt.Errorf("XTLS Vision server responded unknown UUID: %s",
   101  						uuid.FromBytesOrNil(header[:uuid.Size]).String())
   102  					log.Errorln(err.Error())
   103  					return err
   104  				}
   105  				header = header[uuid.Size:]
   106  			}
   107  			vc.readRemainingPadding = int(binary.BigEndian.Uint16(header[3:]))
   108  			vc.readRemainingContent = int(binary.BigEndian.Uint16(header[1:]))
   109  			vc.readLastCommand = header[0]
   110  			log.Debugln("XTLS Vision read padding: command=%d, payloadLen=%d, paddingLen=%d",
   111  				vc.readLastCommand, vc.readRemainingContent, vc.readRemainingPadding)
   112  			return vc.ReadBuffer(buffer)
   113  			//}
   114  		case commandPaddingEnd:
   115  			vc.readProcess = false
   116  			return vc.ReadBuffer(buffer)
   117  		case commandPaddingDirect:
   118  			needReturn := false
   119  			if vc.input != nil {
   120  				_, err := buffer.ReadFrom(vc.input)
   121  				if err != nil {
   122  					return err
   123  				}
   124  				if vc.input.Len() == 0 {
   125  					needReturn = true
   126  					vc.input = nil
   127  				} else { // buffer is full
   128  					return nil
   129  				}
   130  			}
   131  			if vc.rawInput != nil {
   132  				_, err := buffer.ReadFrom(vc.rawInput)
   133  				if err != nil {
   134  					return err
   135  				}
   136  				needReturn = true
   137  				if vc.rawInput.Len() == 0 {
   138  					vc.rawInput = nil
   139  				}
   140  			}
   141  			if vc.input == nil && vc.rawInput == nil {
   142  				vc.readProcess = false
   143  				vc.ExtendedReader = N.NewExtendedReader(vc.Conn)
   144  				log.Debugln("XTLS Vision direct read start")
   145  			}
   146  			if needReturn {
   147  				return nil
   148  			}
   149  		default:
   150  			err := fmt.Errorf("XTLS Vision read unknown command: %d", vc.readLastCommand)
   151  			log.Debugln(err.Error())
   152  			return err
   153  		}
   154  	}
   155  	return vc.ExtendedReader.ReadBuffer(buffer)
   156  }
   157  
   158  func (vc *Conn) Write(p []byte) (int, error) {
   159  	if vc.writeFilterApplicationData {
   160  		return N.WriteBuffer(vc, buf.As(p))
   161  	}
   162  	return vc.ExtendedWriter.Write(p)
   163  }
   164  
   165  func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
   166  	if vc.needHandshake {
   167  		vc.needHandshake = false
   168  		if buffer.IsEmpty() {
   169  			ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, false)
   170  		} else {
   171  			vc.FilterTLS(buffer.Bytes())
   172  			ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, vc.isTLS)
   173  		}
   174  		err = vc.ExtendedWriter.WriteBuffer(buffer)
   175  		if err != nil {
   176  			buffer.Release()
   177  			return err
   178  		}
   179  		switch underlying := vc.tlsConn.(type) {
   180  		case *gotls.Conn:
   181  			if underlying.ConnectionState().Version != gotls.VersionTLS13 {
   182  				buffer.Release()
   183  				return ErrNotTLS13
   184  			}
   185  		case *utls.UConn:
   186  			if underlying.ConnectionState().Version != utls.VersionTLS13 {
   187  				buffer.Release()
   188  				return ErrNotTLS13
   189  			}
   190  		}
   191  		vc.tlsConn = nil
   192  		return nil
   193  	}
   194  
   195  	if vc.writeFilterApplicationData {
   196  		buffer2 := ReshapeBuffer(buffer)
   197  		defer buffer2.Release()
   198  		vc.FilterTLS(buffer.Bytes())
   199  		command := commandPaddingContinue
   200  		if !vc.isTLS {
   201  			command = commandPaddingEnd
   202  
   203  			// disable XTLS
   204  			//vc.readProcess = false
   205  			vc.writeFilterApplicationData = false
   206  			vc.packetsToFilter = 0
   207  		} else if buffer.Len() > 6 && bytes.Equal(buffer.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 {
   208  			command = commandPaddingEnd
   209  			if vc.enableXTLS {
   210  				command = commandPaddingDirect
   211  				vc.writeDirect = true
   212  			}
   213  			vc.writeFilterApplicationData = false
   214  		}
   215  		ApplyPadding(buffer, command, nil, vc.isTLS)
   216  		err = vc.ExtendedWriter.WriteBuffer(buffer)
   217  		if err != nil {
   218  			return err
   219  		}
   220  		if vc.writeDirect {
   221  			vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
   222  			log.Debugln("XTLS Vision direct write start")
   223  			//time.Sleep(5 * time.Millisecond)
   224  		}
   225  		if buffer2 != nil {
   226  			if vc.writeDirect || !vc.isTLS {
   227  				return vc.ExtendedWriter.WriteBuffer(buffer2)
   228  			}
   229  			vc.FilterTLS(buffer2.Bytes())
   230  			command = commandPaddingContinue
   231  			if buffer2.Len() > 6 && bytes.Equal(buffer2.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 {
   232  				command = commandPaddingEnd
   233  				if vc.enableXTLS {
   234  					command = commandPaddingDirect
   235  					vc.writeDirect = true
   236  				}
   237  				vc.writeFilterApplicationData = false
   238  			}
   239  			ApplyPadding(buffer2, command, nil, vc.isTLS)
   240  			err = vc.ExtendedWriter.WriteBuffer(buffer2)
   241  			if vc.writeDirect {
   242  				vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
   243  				log.Debugln("XTLS Vision direct write start")
   244  				//time.Sleep(10 * time.Millisecond)
   245  			}
   246  		}
   247  		return err
   248  	}
   249  	/*if vc.writeDirect {
   250  		log.Debugln("XTLS Vision Direct write, payloadLen=%d", buffer.Len())
   251  	}*/
   252  	return vc.ExtendedWriter.WriteBuffer(buffer)
   253  }
   254  
   255  func (vc *Conn) FrontHeadroom() int {
   256  	if vc.readFilterUUID {
   257  		return PaddingHeaderLen
   258  	}
   259  	return PaddingHeaderLen - uuid.Size
   260  }
   261  
   262  func (vc *Conn) RearHeadroom() int {
   263  	return 500 + 900
   264  }
   265  
   266  func (vc *Conn) NeedHandshake() bool {
   267  	return vc.needHandshake
   268  }
   269  
   270  func (vc *Conn) Upstream() any {
   271  	if vc.writeDirect ||
   272  		vc.readLastCommand == commandPaddingDirect {
   273  		return vc.Conn
   274  	}
   275  	return vc.upstream
   276  }
   277  
   278  func (vc *Conn) ReaderPossiblyReplaceable() bool {
   279  	return vc.readProcess
   280  }
   281  
   282  func (vc *Conn) ReaderReplaceable() bool {
   283  	if !vc.readProcess &&
   284  		vc.readLastCommand == commandPaddingDirect {
   285  		return true
   286  	}
   287  	return false
   288  }
   289  
   290  func (vc *Conn) WriterPossiblyReplaceable() bool {
   291  	return vc.writeFilterApplicationData
   292  }
   293  
   294  func (vc *Conn) WriterReplaceable() bool {
   295  	if vc.writeDirect {
   296  		return true
   297  	}
   298  	return false
   299  }