github.com/xmplusdev/xray-core@v1.8.10/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"crypto/rand"
     8  	"io"
     9  	"math/big"
    10  	"time"
    11  
    12  	"github.com/pires/go-proxyproto"
    13  	"github.com/xmplusdev/xray-core/common"
    14  	"github.com/xmplusdev/xray-core/common/buf"
    15  	"github.com/xmplusdev/xray-core/common/dice"
    16  	"github.com/xmplusdev/xray-core/common/net"
    17  	"github.com/xmplusdev/xray-core/common/platform"
    18  	"github.com/xmplusdev/xray-core/common/retry"
    19  	"github.com/xmplusdev/xray-core/common/session"
    20  	"github.com/xmplusdev/xray-core/common/signal"
    21  	"github.com/xmplusdev/xray-core/common/task"
    22  	"github.com/xmplusdev/xray-core/core"
    23  	"github.com/xmplusdev/xray-core/features/dns"
    24  	"github.com/xmplusdev/xray-core/features/policy"
    25  	"github.com/xmplusdev/xray-core/features/stats"
    26  	"github.com/xmplusdev/xray-core/proxy"
    27  	"github.com/xmplusdev/xray-core/transport"
    28  	"github.com/xmplusdev/xray-core/transport/internet"
    29  	"github.com/xmplusdev/xray-core/transport/internet/stat"
    30  )
    31  
    32  var useSplice bool
    33  
    34  func init() {
    35  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    36  		h := new(Handler)
    37  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    38  			return h.Init(config.(*Config), pm, d)
    39  		}); err != nil {
    40  			return nil, err
    41  		}
    42  		return h, nil
    43  	}))
    44  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
    45  	value := platform.NewEnvFlag(platform.UseFreedomSplice).GetValue(func() string { return defaultFlagValue })
    46  	switch value {
    47  	case defaultFlagValue, "auto", "enable":
    48  		useSplice = true
    49  	}
    50  }
    51  
    52  // Handler handles Freedom connections.
    53  type Handler struct {
    54  	policyManager policy.Manager
    55  	dns           dns.Client
    56  	config        *Config
    57  }
    58  
    59  // Init initializes the Handler with necessary parameters.
    60  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    61  	h.config = config
    62  	h.policyManager = pm
    63  	h.dns = d
    64  
    65  	return nil
    66  }
    67  
    68  func (h *Handler) policy() policy.Session {
    69  	p := h.policyManager.ForLevel(h.config.UserLevel)
    70  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    71  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    72  	}
    73  	return p
    74  }
    75  
    76  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    77  	ips, err := h.dns.LookupIP(domain, dns.IPOption{
    78  		IPv4Enable: (localAddr == nil || localAddr.Family().IsIPv4()) && h.config.preferIP4(),
    79  		IPv6Enable: (localAddr == nil || localAddr.Family().IsIPv6()) && h.config.preferIP6(),
    80  	})
    81  	{ // Resolve fallback
    82  		if (len(ips) == 0 || err != nil) && h.config.hasFallback() && localAddr == nil {
    83  			ips, err = h.dns.LookupIP(domain, dns.IPOption{
    84  				IPv4Enable: h.config.fallbackIP4(),
    85  				IPv6Enable: h.config.fallbackIP6(),
    86  			})
    87  		}
    88  	}
    89  	if err != nil {
    90  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    91  	}
    92  	if len(ips) == 0 {
    93  		return nil
    94  	}
    95  	return net.IPAddress(ips[dice.Roll(len(ips))])
    96  }
    97  
    98  func isValidAddress(addr *net.IPOrDomain) bool {
    99  	if addr == nil {
   100  		return false
   101  	}
   102  
   103  	a := addr.AsAddress()
   104  	return a != net.AnyIP
   105  }
   106  
   107  // Process implements proxy.Outbound.
   108  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
   109  	outbound := session.OutboundFromContext(ctx)
   110  	if outbound == nil || !outbound.Target.IsValid() {
   111  		return newError("target not specified.")
   112  	}
   113  	outbound.Name = "freedom"
   114  	inbound := session.InboundFromContext(ctx)
   115  	if inbound != nil {
   116  		inbound.SetCanSpliceCopy(1)
   117  	}
   118  	destination := outbound.Target
   119  	UDPOverride := net.UDPDestination(nil, 0)
   120  	if h.config.DestinationOverride != nil {
   121  		server := h.config.DestinationOverride.Server
   122  		if isValidAddress(server.Address) {
   123  			destination.Address = server.Address.AsAddress()
   124  			UDPOverride.Address = destination.Address
   125  		}
   126  		if server.Port != 0 {
   127  			destination.Port = net.Port(server.Port)
   128  			UDPOverride.Port = destination.Port
   129  		}
   130  	}
   131  
   132  	input := link.Reader
   133  	output := link.Writer
   134  
   135  	var conn stat.Connection
   136  	err := retry.ExponentialBackoff(5, 100).On(func() error {
   137  		dialDest := destination
   138  		if h.config.hasStrategy() && dialDest.Address.Family().IsDomain() {
   139  			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
   140  			if ip != nil {
   141  				dialDest = net.Destination{
   142  					Network: dialDest.Network,
   143  					Address: ip,
   144  					Port:    dialDest.Port,
   145  				}
   146  				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
   147  			} else if h.config.forceIP() {
   148  				return dns.ErrEmptyResponse
   149  			}
   150  		}
   151  
   152  		rawConn, err := dialer.Dial(ctx, dialDest)
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		if h.config.ProxyProtocol > 0 && h.config.ProxyProtocol <= 2 {
   158  			version := byte(h.config.ProxyProtocol)
   159  			srcAddr := inbound.Source.RawNetAddr()
   160  			dstAddr := rawConn.RemoteAddr()
   161  			header := proxyproto.HeaderProxyFromAddrs(version, srcAddr, dstAddr)
   162  			if _, err = header.WriteTo(rawConn); err != nil {
   163  				rawConn.Close()
   164  				return err
   165  			}
   166  		}
   167  
   168  		conn = rawConn
   169  		return nil
   170  	})
   171  	if err != nil {
   172  		return newError("failed to open connection to ", destination).Base(err)
   173  	}
   174  	defer conn.Close()
   175  	newError("connection opened to ", destination, ", local endpoint ", conn.LocalAddr(), ", remote endpoint ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx))
   176  
   177  	var newCtx context.Context
   178  	var newCancel context.CancelFunc
   179  	if session.TimeoutOnlyFromContext(ctx) {
   180  		newCtx, newCancel = context.WithCancel(context.Background())
   181  	}
   182  
   183  	plcy := h.policy()
   184  	ctx, cancel := context.WithCancel(ctx)
   185  	timer := signal.CancelAfterInactivity(ctx, func() {
   186  		cancel()
   187  		if newCancel != nil {
   188  			newCancel()
   189  		}
   190  	}, plcy.Timeouts.ConnectionIdle)
   191  
   192  	requestDone := func() error {
   193  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   194  
   195  		var writer buf.Writer
   196  		if destination.Network == net.Network_TCP {
   197  			if h.config.Fragment != nil {
   198  				newError("FRAGMENT", h.config.Fragment.PacketsFrom, h.config.Fragment.PacketsTo, h.config.Fragment.LengthMin, h.config.Fragment.LengthMax,
   199  					h.config.Fragment.IntervalMin, h.config.Fragment.IntervalMax).AtDebug().WriteToLog(session.ExportIDToError(ctx))
   200  				writer = buf.NewWriter(&FragmentWriter{
   201  					fragment: h.config.Fragment,
   202  					writer:   conn,
   203  				})
   204  			} else {
   205  				writer = buf.NewWriter(conn)
   206  			}
   207  		} else {
   208  			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
   209  		}
   210  
   211  		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
   212  			return newError("failed to process request").Base(err)
   213  		}
   214  
   215  		return nil
   216  	}
   217  
   218  	responseDone := func() error {
   219  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   220  		if destination.Network == net.Network_TCP {
   221  			var writeConn net.Conn
   222  			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && useSplice {
   223  				writeConn = inbound.Conn
   224  			}
   225  			return proxy.CopyRawConnIfExist(ctx, conn, writeConn, link.Writer, timer)
   226  		}
   227  		reader := NewPacketReader(conn, UDPOverride)
   228  		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
   229  			return newError("failed to process response").Base(err)
   230  		}
   231  		return nil
   232  	}
   233  
   234  	if newCtx != nil {
   235  		ctx = newCtx
   236  	}
   237  
   238  	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
   239  		return newError("connection ends").Base(err)
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   246  	iConn := conn
   247  	statConn, ok := iConn.(*stat.CounterConnection)
   248  	if ok {
   249  		iConn = statConn.Connection
   250  	}
   251  	var counter stats.Counter
   252  	if statConn != nil {
   253  		counter = statConn.ReadCounter
   254  	}
   255  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   256  		return &PacketReader{
   257  			PacketConnWrapper: c,
   258  			Counter:           counter,
   259  		}
   260  	}
   261  	return &buf.PacketReader{Reader: conn}
   262  }
   263  
   264  type PacketReader struct {
   265  	*internet.PacketConnWrapper
   266  	stats.Counter
   267  }
   268  
   269  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   270  	b := buf.New()
   271  	b.Resize(0, buf.Size)
   272  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   273  	if err != nil {
   274  		b.Release()
   275  		return nil, err
   276  	}
   277  	b.Resize(0, int32(n))
   278  	b.UDP = &net.Destination{
   279  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   280  		Port:    net.Port(d.(*net.UDPAddr).Port),
   281  		Network: net.Network_UDP,
   282  	}
   283  	if r.Counter != nil {
   284  		r.Counter.Add(int64(n))
   285  	}
   286  	return buf.MultiBuffer{b}, nil
   287  }
   288  
   289  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   290  	iConn := conn
   291  	statConn, ok := iConn.(*stat.CounterConnection)
   292  	if ok {
   293  		iConn = statConn.Connection
   294  	}
   295  	var counter stats.Counter
   296  	if statConn != nil {
   297  		counter = statConn.WriteCounter
   298  	}
   299  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   300  		return &PacketWriter{
   301  			PacketConnWrapper: c,
   302  			Counter:           counter,
   303  			Handler:           h,
   304  			Context:           ctx,
   305  			UDPOverride:       UDPOverride,
   306  		}
   307  	}
   308  	return &buf.SequentialWriter{Writer: conn}
   309  }
   310  
   311  type PacketWriter struct {
   312  	*internet.PacketConnWrapper
   313  	stats.Counter
   314  	*Handler
   315  	context.Context
   316  	UDPOverride net.Destination
   317  }
   318  
   319  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   320  	for {
   321  		mb2, b := buf.SplitFirst(mb)
   322  		mb = mb2
   323  		if b == nil {
   324  			break
   325  		}
   326  		var n int
   327  		var err error
   328  		if b.UDP != nil {
   329  			if w.UDPOverride.Address != nil {
   330  				b.UDP.Address = w.UDPOverride.Address
   331  			}
   332  			if w.UDPOverride.Port != 0 {
   333  				b.UDP.Port = w.UDPOverride.Port
   334  			}
   335  			if w.Handler.config.hasStrategy() && b.UDP.Address.Family().IsDomain() {
   336  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   337  				if ip != nil {
   338  					b.UDP.Address = ip
   339  				}
   340  			}
   341  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   342  			if destAddr == nil {
   343  				b.Release()
   344  				continue
   345  			}
   346  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   347  		} else {
   348  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   349  		}
   350  		b.Release()
   351  		if err != nil {
   352  			buf.ReleaseMulti(mb)
   353  			return err
   354  		}
   355  		if w.Counter != nil {
   356  			w.Counter.Add(int64(n))
   357  		}
   358  	}
   359  	return nil
   360  }
   361  
   362  type FragmentWriter struct {
   363  	fragment *Fragment
   364  	writer   io.Writer
   365  	count    uint64
   366  }
   367  
   368  func (f *FragmentWriter) Write(b []byte) (int, error) {
   369  	f.count++
   370  
   371  	if f.fragment.PacketsFrom == 0 && f.fragment.PacketsTo == 1 {
   372  		if f.count != 1 || len(b) <= 5 || b[0] != 22 {
   373  			return f.writer.Write(b)
   374  		}
   375  		recordLen := 5 + ((int(b[3]) << 8) | int(b[4]))
   376  		if len(b) < recordLen { // maybe already fragmented somehow
   377  			return f.writer.Write(b)
   378  		}
   379  		data := b[5:recordLen]
   380  		buf := make([]byte, 1024)
   381  		for from := 0; ; {
   382  			to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax)))
   383  			if to > len(data) {
   384  				to = len(data)
   385  			}
   386  			copy(buf[:3], b)
   387  			copy(buf[5:], data[from:to])
   388  			l := to - from
   389  			from = to
   390  			buf[3] = byte(l >> 8)
   391  			buf[4] = byte(l)
   392  			_, err := f.writer.Write(buf[:5+l])
   393  			time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond)
   394  			if err != nil {
   395  				return 0, err
   396  			}
   397  			if from == len(data) {
   398  				if len(b) > recordLen {
   399  					n, err := f.writer.Write(b[recordLen:])
   400  					if err != nil {
   401  						return recordLen + n, err
   402  					}
   403  				}
   404  				return len(b), nil
   405  			}
   406  		}
   407  	}
   408  
   409  	if f.fragment.PacketsFrom != 0 && (f.count < f.fragment.PacketsFrom || f.count > f.fragment.PacketsTo) {
   410  		return f.writer.Write(b)
   411  	}
   412  	for from := 0; ; {
   413  		to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax)))
   414  		if to > len(b) {
   415  			to = len(b)
   416  		}
   417  		n, err := f.writer.Write(b[from:to])
   418  		from += n
   419  		time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond)
   420  		if err != nil {
   421  			return from, err
   422  		}
   423  		if from >= len(b) {
   424  			return from, nil
   425  		}
   426  	}
   427  }
   428  
   429  // stolen from github.com/xmplusdev/xray-core/transport/internet/reality
   430  func randBetween(left int64, right int64) int64 {
   431  	if left == right {
   432  		return left
   433  	}
   434  	bigInt, _ := rand.Int(rand.Reader, big.NewInt(right-left))
   435  	return left + bigInt.Int64()
   436  }