// Source: github.com/moqsien/xraycore@v1.8.5/proxy/vless/outbound/outbound.go

     1  package outbound
     2  
     3  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	gotls "crypto/tls"
     9  	"reflect"
    10  	"syscall"
    11  	"time"
    12  	"unsafe"
    13  
    14  	utls "github.com/refraction-networking/utls"
    15  	"github.com/moqsien/xraycore/common"
    16  	"github.com/moqsien/xraycore/common/buf"
    17  	"github.com/moqsien/xraycore/common/net"
    18  	"github.com/moqsien/xraycore/common/protocol"
    19  	"github.com/moqsien/xraycore/common/retry"
    20  	"github.com/moqsien/xraycore/common/session"
    21  	"github.com/moqsien/xraycore/common/signal"
    22  	"github.com/moqsien/xraycore/common/task"
    23  	"github.com/moqsien/xraycore/common/xudp"
    24  	"github.com/moqsien/xraycore/core"
    25  	"github.com/moqsien/xraycore/features/policy"
    26  	"github.com/moqsien/xraycore/features/stats"
    27  	"github.com/moqsien/xraycore/proxy/vless"
    28  	"github.com/moqsien/xraycore/proxy/vless/encoding"
    29  	"github.com/moqsien/xraycore/transport"
    30  	"github.com/moqsien/xraycore/transport/internet"
    31  	"github.com/moqsien/xraycore/transport/internet/reality"
    32  	"github.com/moqsien/xraycore/transport/internet/stat"
    33  	"github.com/moqsien/xraycore/transport/internet/tls"
    34  )
    35  
    36  func init() {
    37  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    38  		return New(ctx, config.(*Config))
    39  	}))
    40  }
    41  
    42  // Handler is an outbound connection handler for VLess protocol.
    43  type Handler struct {
    44  	serverList    *protocol.ServerList
    45  	serverPicker  protocol.ServerPicker
    46  	policyManager policy.Manager
    47  	cone          bool
    48  }
    49  
    50  // New creates a new VLess outbound handler.
    51  func New(ctx context.Context, config *Config) (*Handler, error) {
    52  	serverList := protocol.NewServerList()
    53  	for _, rec := range config.Vnext {
    54  		s, err := protocol.NewServerSpecFromPB(rec)
    55  		if err != nil {
    56  			return nil, newError("failed to parse server spec").Base(err).AtError()
    57  		}
    58  		serverList.AddServer(s)
    59  	}
    60  
    61  	v := core.MustFromContext(ctx)
    62  	handler := &Handler{
    63  		serverList:    serverList,
    64  		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
    65  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    66  		cone:          ctx.Value("cone").(bool),
    67  	}
    68  
    69  	return handler, nil
    70  }
    71  
    72  // Process implements proxy.Outbound.Process().
    73  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    74  	var rec *protocol.ServerSpec
    75  	var conn stat.Connection
    76  
    77  	if err := retry.ExponentialBackoff(5, 200).On(func() error {
    78  		rec = h.serverPicker.PickServer()
    79  		var err error
    80  		conn, err = dialer.Dial(ctx, rec.Destination())
    81  		if err != nil {
    82  			return err
    83  		}
    84  		return nil
    85  	}); err != nil {
    86  		return newError("failed to find an available destination").Base(err).AtWarning()
    87  	}
    88  	defer conn.Close()
    89  
    90  	iConn := conn
    91  	statConn, ok := iConn.(*stat.CounterConnection)
    92  	if ok {
    93  		iConn = statConn.Connection
    94  	}
    95  
    96  	outbound := session.OutboundFromContext(ctx)
    97  	if outbound == nil || !outbound.Target.IsValid() {
    98  		return newError("target not specified").AtError()
    99  	}
   100  
   101  	target := outbound.Target
   102  	newError("tunneling request to ", target, " via ", rec.Destination().NetAddr()).AtInfo().WriteToLog(session.ExportIDToError(ctx))
   103  
   104  	command := protocol.RequestCommandTCP
   105  	if target.Network == net.Network_UDP {
   106  		command = protocol.RequestCommandUDP
   107  	}
   108  	if target.Address.Family().IsDomain() && target.Address.Domain() == "v1.mux.cool" {
   109  		command = protocol.RequestCommandMux
   110  	}
   111  
   112  	request := &protocol.RequestHeader{
   113  		Version: encoding.Version,
   114  		User:    rec.PickUser(),
   115  		Command: command,
   116  		Address: target.Address,
   117  		Port:    target.Port,
   118  	}
   119  
   120  	account := request.User.Account.(*vless.MemoryAccount)
   121  
   122  	requestAddons := &encoding.Addons{
   123  		Flow: account.Flow,
   124  	}
   125  
   126  	var netConn net.Conn
   127  	var rawConn syscall.RawConn
   128  	var input *bytes.Reader
   129  	var rawInput *bytes.Buffer
   130  	allowUDP443 := false
   131  	switch requestAddons.Flow {
   132  	case vless.XRV + "-udp443":
   133  		allowUDP443 = true
   134  		requestAddons.Flow = requestAddons.Flow[:16]
   135  		fallthrough
   136  	case vless.XRV:
   137  		switch request.Command {
   138  		case protocol.RequestCommandUDP:
   139  			if !allowUDP443 && request.Port == 443 {
   140  				return newError("XTLS rejected UDP/443 traffic").AtInfo()
   141  			}
   142  			requestAddons.Flow = ""
   143  		case protocol.RequestCommandMux:
   144  			fallthrough // let server break Mux connections that contain TCP requests
   145  		case protocol.RequestCommandTCP:
   146  			var t reflect.Type
   147  			var p uintptr
   148  			if tlsConn, ok := iConn.(*tls.Conn); ok {
   149  				netConn = tlsConn.NetConn()
   150  				t = reflect.TypeOf(tlsConn.Conn).Elem()
   151  				p = uintptr(unsafe.Pointer(tlsConn.Conn))
   152  			} else if utlsConn, ok := iConn.(*tls.UConn); ok {
   153  				netConn = utlsConn.NetConn()
   154  				t = reflect.TypeOf(utlsConn.Conn).Elem()
   155  				p = uintptr(unsafe.Pointer(utlsConn.Conn))
   156  			} else if realityConn, ok := iConn.(*reality.UConn); ok {
   157  				netConn = realityConn.NetConn()
   158  				t = reflect.TypeOf(realityConn.Conn).Elem()
   159  				p = uintptr(unsafe.Pointer(realityConn.Conn))
   160  			} else {
   161  				return newError("XTLS only supports TLS and REALITY directly for now.").AtWarning()
   162  			}
   163  			if sc, ok := netConn.(syscall.Conn); ok {
   164  				rawConn, _ = sc.SyscallConn()
   165  			}
   166  			i, _ := t.FieldByName("input")
   167  			r, _ := t.FieldByName("rawInput")
   168  			input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset))
   169  			rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset))
   170  		}
   171  	}
   172  
   173  	var newCtx context.Context
   174  	var newCancel context.CancelFunc
   175  	if session.TimeoutOnlyFromContext(ctx) {
   176  		newCtx, newCancel = context.WithCancel(context.Background())
   177  	}
   178  
   179  	sessionPolicy := h.policyManager.ForLevel(request.User.Level)
   180  	ctx, cancel := context.WithCancel(ctx)
   181  	timer := signal.CancelAfterInactivity(ctx, func() {
   182  		cancel()
   183  		if newCancel != nil {
   184  			newCancel()
   185  		}
   186  	}, sessionPolicy.Timeouts.ConnectionIdle)
   187  
   188  	clientReader := link.Reader // .(*pipe.Reader)
   189  	clientWriter := link.Writer // .(*pipe.Writer)
   190  	enableXtls := false
   191  	isTLS12orAbove := false
   192  	isTLS := false
   193  	var cipher uint16 = 0
   194  	var remainingServerHello int32 = -1
   195  	numberOfPacketToFilter := 8
   196  
   197  	if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 {
   198  		request.Command = protocol.RequestCommandMux
   199  		request.Address = net.DomainAddress("v1.mux.cool")
   200  		request.Port = net.Port(666)
   201  	}
   202  
   203  	postRequest := func() error {
   204  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   205  
   206  		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   207  		if err := encoding.EncodeRequestHeader(bufferWriter, request, requestAddons); err != nil {
   208  			return newError("failed to encode request header").Base(err).AtWarning()
   209  		}
   210  
   211  		// default: serverWriter := bufferWriter
   212  		serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons)
   213  		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
   214  			serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx))
   215  		}
   216  		userUUID := account.ID.Bytes()
   217  		timeoutReader, ok := clientReader.(buf.TimeoutReader)
   218  		if ok {
   219  			multiBuffer, err1 := timeoutReader.ReadMultiBufferTimeout(time.Millisecond * 500)
   220  			if err1 == nil {
   221  				if requestAddons.Flow == vless.XRV {
   222  					encoding.XtlsFilterTls(multiBuffer, &numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello, ctx)
   223  					multiBuffer = encoding.ReshapeMultiBuffer(ctx, multiBuffer)
   224  					for i, b := range multiBuffer {
   225  						multiBuffer[i] = encoding.XtlsPadding(b, encoding.CommandPaddingContinue, &userUUID, isTLS, ctx)
   226  					}
   227  				}
   228  				if err := serverWriter.WriteMultiBuffer(multiBuffer); err != nil {
   229  					return err // ...
   230  				}
   231  			} else if err1 != buf.ErrReadTimeout {
   232  				return err1
   233  			} else if requestAddons.Flow == vless.XRV {
   234  				mb := make(buf.MultiBuffer, 1)
   235  				mb[0] = encoding.XtlsPadding(nil, encoding.CommandPaddingContinue, &userUUID, true, ctx) // we do a long padding to hide vless header
   236  				newError("Insert padding with empty content to camouflage VLESS header ", mb.Len()).WriteToLog(session.ExportIDToError(ctx))
   237  				if err := serverWriter.WriteMultiBuffer(mb); err != nil {
   238  					return err
   239  				}
   240  			}
   241  		} else {
   242  			newError("Reader is not timeout reader, will send out vless header separately from first payload").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   243  		}
   244  		// Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer
   245  		if err := bufferWriter.SetBuffered(false); err != nil {
   246  			return newError("failed to write A request payload").Base(err).AtWarning()
   247  		}
   248  
   249  		var err error
   250  		if requestAddons.Flow == vless.XRV {
   251  			if tlsConn, ok := iConn.(*tls.Conn); ok {
   252  				if tlsConn.ConnectionState().Version != gotls.VersionTLS13 {
   253  					return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, tlsConn.ConnectionState().Version).AtWarning()
   254  				}
   255  			} else if utlsConn, ok := iConn.(*tls.UConn); ok {
   256  				if utlsConn.ConnectionState().Version != utls.VersionTLS13 {
   257  					return newError(`failed to use `+requestAddons.Flow+`, found outer tls version `, utlsConn.ConnectionState().Version).AtWarning()
   258  				}
   259  			}
   260  			var counter stats.Counter
   261  			if statConn != nil {
   262  				counter = statConn.WriteCounter
   263  			}
   264  			err = encoding.XtlsWrite(clientReader, serverWriter, timer, netConn, counter, ctx, &numberOfPacketToFilter,
   265  				&enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
   266  		} else {
   267  			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
   268  			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
   269  		}
   270  		if err != nil {
   271  			return newError("failed to transfer request payload").Base(err).AtInfo()
   272  		}
   273  
   274  		// Indicates the end of request payload.
   275  		switch requestAddons.Flow {
   276  		default:
   277  		}
   278  		return nil
   279  	}
   280  
   281  	getResponse := func() error {
   282  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   283  
   284  		responseAddons, err := encoding.DecodeResponseHeader(conn, request)
   285  		if err != nil {
   286  			return newError("failed to decode response header").Base(err).AtInfo()
   287  		}
   288  
   289  		// default: serverReader := buf.NewReader(conn)
   290  		serverReader := encoding.DecodeBodyAddons(conn, request, responseAddons)
   291  		if request.Command == protocol.RequestCommandMux && request.Port == 666 {
   292  			serverReader = xudp.NewPacketReader(conn)
   293  		}
   294  
   295  		if requestAddons.Flow == vless.XRV {
   296  			var counter stats.Counter
   297  			if statConn != nil {
   298  				counter = statConn.ReadCounter
   299  			}
   300  			err = encoding.XtlsRead(serverReader, clientWriter, timer, netConn, rawConn, input, rawInput, counter, ctx, account.ID.Bytes(),
   301  				&numberOfPacketToFilter, &enableXtls, &isTLS12orAbove, &isTLS, &cipher, &remainingServerHello)
   302  		} else {
   303  			// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
   304  			err = buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer))
   305  		}
   306  
   307  		if err != nil {
   308  			return newError("failed to transfer response payload").Base(err).AtInfo()
   309  		}
   310  
   311  		return nil
   312  	}
   313  
   314  	if newCtx != nil {
   315  		ctx = newCtx
   316  	}
   317  
   318  	if err := task.Run(ctx, postRequest, task.OnSuccess(getResponse, task.Close(clientWriter))); err != nil {
   319  		return newError("connection ends").Base(err).AtInfo()
   320  	}
   321  
   322  	return nil
   323  }