github.com/xtls/xray-core@v1.8.3/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"crypto/rand"
     8  	"encoding/binary"
     9  	"io"
    10  	"math/big"
    11  	"time"
    12  
    13  	"github.com/xtls/xray-core/common"
    14  	"github.com/xtls/xray-core/common/buf"
    15  	"github.com/xtls/xray-core/common/dice"
    16  	"github.com/xtls/xray-core/common/errors"
    17  	"github.com/xtls/xray-core/common/net"
    18  	"github.com/xtls/xray-core/common/retry"
    19  	"github.com/xtls/xray-core/common/session"
    20  	"github.com/xtls/xray-core/common/signal"
    21  	"github.com/xtls/xray-core/common/task"
    22  	"github.com/xtls/xray-core/core"
    23  	"github.com/xtls/xray-core/features/dns"
    24  	"github.com/xtls/xray-core/features/policy"
    25  	"github.com/xtls/xray-core/features/stats"
    26  	"github.com/xtls/xray-core/transport"
    27  	"github.com/xtls/xray-core/transport/internet"
    28  	"github.com/xtls/xray-core/transport/internet/stat"
    29  )
    30  
    31  func init() {
    32  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    33  		h := new(Handler)
    34  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    35  			return h.Init(config.(*Config), pm, d)
    36  		}); err != nil {
    37  			return nil, err
    38  		}
    39  		return h, nil
    40  	}))
    41  }
    42  
// Handler handles Freedom connections: it dials the requested destination
// with the dialer passed to Process and relays traffic to and from it.
type Handler struct {
	policyManager policy.Manager // per-level policies, read by policy()
	dns           dns.Client     // used by resolveIP for domain lookups
	config        *Config        // outbound settings, set by Init
}
    49  
    50  // Init initializes the Handler with necessary parameters.
    51  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    52  	h.config = config
    53  	h.policyManager = pm
    54  	h.dns = d
    55  
    56  	return nil
    57  }
    58  
    59  func (h *Handler) policy() policy.Session {
    60  	p := h.policyManager.ForLevel(h.config.UserLevel)
    61  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    62  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    63  	}
    64  	return p
    65  }
    66  
    67  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    68  	var option dns.IPOption = dns.IPOption{
    69  		IPv4Enable: true,
    70  		IPv6Enable: true,
    71  		FakeEnable: false,
    72  	}
    73  	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
    74  		option = dns.IPOption{
    75  			IPv4Enable: true,
    76  			IPv6Enable: false,
    77  			FakeEnable: false,
    78  		}
    79  	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
    80  		option = dns.IPOption{
    81  			IPv4Enable: false,
    82  			IPv6Enable: true,
    83  			FakeEnable: false,
    84  		}
    85  	}
    86  
    87  	ips, err := h.dns.LookupIP(domain, option)
    88  	if err != nil {
    89  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    90  	}
    91  	if len(ips) == 0 {
    92  		return nil
    93  	}
    94  	return net.IPAddress(ips[dice.Roll(len(ips))])
    95  }
    96  
    97  func isValidAddress(addr *net.IPOrDomain) bool {
    98  	if addr == nil {
    99  		return false
   100  	}
   101  
   102  	a := addr.AsAddress()
   103  	return a != net.AnyIP
   104  }
   105  
// Process implements proxy.Outbound.
//
// It takes the target from the outbound session in ctx and applies
// h.config.DestinationOverride (address and/or port). It dials the result with
// dialer, then relays link.Reader -> connection (request) and
// connection -> link.Writer (response). TCP traffic may be fragmented according
// to h.config.Fragment. UDP traffic goes through NewPacketWriter and
// NewPacketReader.
//
// It returns an error if no valid target is present, if dialing still fails
// after the retries, or if the relay reports an error.
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
	outbound := session.OutboundFromContext(ctx)
	if outbound == nil || !outbound.Target.IsValid() {
		return newError("target not specified.")
	}
	destination := outbound.Target
	// UDPOverride starts empty (nil address, port 0). It only receives the
	// parts that DestinationOverride actually replaces; the UDP packet
	// reader/writer use it to rewrite per-packet destinations.
	UDPOverride := net.UDPDestination(nil, 0)
	if h.config.DestinationOverride != nil {
		server := h.config.DestinationOverride.Server
		if isValidAddress(server.Address) {
			destination.Address = server.Address.AsAddress()
			UDPOverride.Address = destination.Address
		}
		if server.Port != 0 {
			destination.Port = net.Port(server.Port)
			UDPOverride.Port = destination.Port
		}
	}

	input := link.Reader
	output := link.Writer

	var conn stat.Connection
	// NOTE(review): ExponentialBackoff(5, 100) is interpreted by the retry
	// package — presumably 5 attempts with a 100ms base delay; confirm.
	err := retry.ExponentialBackoff(5, 100).On(func() error {
		dialDest := destination
		// With an IP-based domain strategy, the domain is re-resolved on every
		// attempt. The dialer's local address is passed so resolveIP can
		// restrict the IP family. If no IP is found, the domain itself is dialed.
		if h.config.useIP() && dialDest.Address.Family().IsDomain() {
			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
			if ip != nil {
				dialDest = net.Destination{
					Network: dialDest.Network,
					Address: ip,
					Port:    dialDest.Port,
				}
				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
			}
		}

		rawConn, err := dialer.Dial(ctx, dialDest)
		if err != nil {
			return err
		}
		conn = rawConn
		return nil
	})
	if err != nil {
		return newError("failed to open connection to ", destination).Base(err)
	}
	defer conn.Close()
	newError("connection opened to ", destination, ", local endpoint ", conn.LocalAddr(), ", remote endpoint ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx))

	// For timeout-only sessions, the relay below runs under a context derived
	// from context.Background() instead of ctx, so it is not canceled together
	// with ctx. The inactivity callback cancels it explicitly via newCancel.
	var newCtx context.Context
	var newCancel context.CancelFunc
	if session.TimeoutOnlyFromContext(ctx) {
		newCtx, newCancel = context.WithCancel(context.Background())
	}

	plcy := h.policy()
	ctx, cancel := context.WithCancel(ctx)
	// The callback cancels both contexts when the inactivity timer fires after
	// ConnectionIdle. buf.UpdateActivity(timer) in the copies below keeps the
	// timer alive while data flows.
	timer := signal.CancelAfterInactivity(ctx, func() {
		cancel()
		if newCancel != nil {
			newCancel()
		}
	}, plcy.Timeouts.ConnectionIdle)

	// requestDone copies client data (input) to the remote connection. When it
	// returns, the idle timeout is switched to DownlinkOnly.
	requestDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)

		var writer buf.Writer
		if destination.Network == net.Network_TCP {
			if h.config.Fragment != nil {
				// StartPacket == 0 && EndPacket == 1 selects TLS-record-aware
				// ClientHello fragmentation. Any other packet range uses
				// FragmentWriter, which splits raw writes without parsing them.
				if h.config.Fragment.StartPacket == 0 && h.config.Fragment.EndPacket == 1 {
					newError("FRAGMENT", int(h.config.Fragment.MaxLength)).WriteToLog(session.ExportIDToError(ctx))
					writer = buf.NewWriter(
						&FragmentedClientHelloConn{
							Conn:        conn,
							maxLength:   int(h.config.Fragment.MaxLength),
							minInterval: time.Duration(h.config.Fragment.MinInterval) * time.Millisecond,
							maxInterval: time.Duration(h.config.Fragment.MaxInterval) * time.Millisecond,
						})
				} else {
					writer = buf.NewWriter(
						&FragmentWriter{
							Writer:      conn,
							minLength:   int(h.config.Fragment.MinLength),
							maxLength:   int(h.config.Fragment.MaxLength),
							minInterval: time.Duration(h.config.Fragment.MinInterval) * time.Millisecond,
							maxInterval: time.Duration(h.config.Fragment.MaxInterval) * time.Millisecond,
							startPacket: int(h.config.Fragment.StartPacket),
							endPacket:   int(h.config.Fragment.EndPacket),
							PacketCount: 0,
						})
				}
			} else {
				writer = buf.NewWriter(conn)
			}
		} else {
			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
		}

		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process request").Base(err)
		}

		return nil
	}

	// responseDone copies data from the remote connection back to output. When
	// it returns, the idle timeout is switched to UplinkOnly.
	responseDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

		var reader buf.Reader
		if destination.Network == net.Network_TCP {
			reader = buf.NewReader(conn)
		} else {
			reader = NewPacketReader(conn, UDPOverride)
		}
		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process response").Base(err)
		}

		return nil
	}

	// Use the detached context (timeout-only sessions) for the relay.
	if newCtx != nil {
		ctx = newCtx
	}

	// task.Run drives both directions. task.OnSuccess closes output once
	// responseDone has returned without error.
	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
		return newError("connection ends").Base(err)
	}

	return nil
}
   240  
   241  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   242  	iConn := conn
   243  	statConn, ok := iConn.(*stat.CounterConnection)
   244  	if ok {
   245  		iConn = statConn.Connection
   246  	}
   247  	var counter stats.Counter
   248  	if statConn != nil {
   249  		counter = statConn.ReadCounter
   250  	}
   251  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   252  		return &PacketReader{
   253  			PacketConnWrapper: c,
   254  			Counter:           counter,
   255  		}
   256  	}
   257  	return &buf.PacketReader{Reader: conn}
   258  }
   259  
// PacketReader reads UDP datagrams from a PacketConnWrapper. Each buffer it
// returns is tagged with the datagram's source address. When Counter is
// non-nil, it is increased by the number of bytes read.
type PacketReader struct {
	*internet.PacketConnWrapper
	stats.Counter
}
   264  
   265  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   266  	b := buf.New()
   267  	b.Resize(0, buf.Size)
   268  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   269  	if err != nil {
   270  		b.Release()
   271  		return nil, err
   272  	}
   273  	b.Resize(0, int32(n))
   274  	b.UDP = &net.Destination{
   275  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   276  		Port:    net.Port(d.(*net.UDPAddr).Port),
   277  		Network: net.Network_UDP,
   278  	}
   279  	if r.Counter != nil {
   280  		r.Counter.Add(int64(n))
   281  	}
   282  	return buf.MultiBuffer{b}, nil
   283  }
   284  
   285  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   286  	iConn := conn
   287  	statConn, ok := iConn.(*stat.CounterConnection)
   288  	if ok {
   289  		iConn = statConn.Connection
   290  	}
   291  	var counter stats.Counter
   292  	if statConn != nil {
   293  		counter = statConn.WriteCounter
   294  	}
   295  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   296  		return &PacketWriter{
   297  			PacketConnWrapper: c,
   298  			Counter:           counter,
   299  			Handler:           h,
   300  			Context:           ctx,
   301  			UDPOverride:       UDPOverride,
   302  		}
   303  	}
   304  	return &buf.SequentialWriter{Writer: conn}
   305  }
   306  
// PacketWriter writes buffers to a PacketConnWrapper.
//
// Buffers that carry a UDP destination are sent with WriteTo. Before sending,
// UDPOverride is applied, and domains are resolved through Handler when its
// domain strategy requires IPs. Buffers without a destination are sent with
// Write. When Counter is non-nil, it is increased by the bytes written.
// Context is passed to resolveIP (used there for log correlation).
//
// NOTE(review): storing a context.Context in a struct is discouraged Go style;
// kept because the field is part of this type's shape.
type PacketWriter struct {
	*internet.PacketConnWrapper
	stats.Counter
	*Handler
	context.Context
	UDPOverride net.Destination
}
   314  
   315  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   316  	for {
   317  		mb2, b := buf.SplitFirst(mb)
   318  		mb = mb2
   319  		if b == nil {
   320  			break
   321  		}
   322  		var n int
   323  		var err error
   324  		if b.UDP != nil {
   325  			if w.UDPOverride.Address != nil {
   326  				b.UDP.Address = w.UDPOverride.Address
   327  			}
   328  			if w.UDPOverride.Port != 0 {
   329  				b.UDP.Port = w.UDPOverride.Port
   330  			}
   331  			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
   332  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   333  				if ip != nil {
   334  					b.UDP.Address = ip
   335  				}
   336  			}
   337  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   338  			if destAddr == nil {
   339  				b.Release()
   340  				continue
   341  			}
   342  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   343  		} else {
   344  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   345  		}
   346  		b.Release()
   347  		if err != nil {
   348  			buf.ReleaseMulti(mb)
   349  			return err
   350  		}
   351  		if w.Counter != nil {
   352  			w.Counter.Add(int64(n))
   353  		}
   354  	}
   355  	return nil
   356  }
   357  
   358  type FragmentWriter struct {
   359  	io.Writer
   360  	minLength   int
   361  	maxLength   int
   362  	minInterval time.Duration
   363  	maxInterval time.Duration
   364  	startPacket int
   365  	endPacket   int
   366  	PacketCount int
   367  }
   368  
   369  func (w *FragmentWriter) Write(buf []byte) (int, error) {
   370  	w.PacketCount += 1
   371  	if (w.startPacket != 0 && (w.PacketCount < w.startPacket || w.PacketCount > w.endPacket)) || len(buf) <= w.minLength {
   372  		return w.Writer.Write(buf)
   373  	}
   374  
   375  	nTotal := 0
   376  	for {
   377  		randomBytesTo := int(randBetween(int64(w.minLength), int64(w.maxLength))) + nTotal
   378  		if randomBytesTo > len(buf) {
   379  			randomBytesTo = len(buf)
   380  		}
   381  		n, err := w.Writer.Write(buf[nTotal:randomBytesTo])
   382  		if err != nil {
   383  			return nTotal + n, err
   384  		}
   385  		nTotal += n
   386  
   387  		if nTotal >= len(buf) {
   388  			return nTotal, nil
   389  		}
   390  
   391  		randomInterval := randBetween(int64(w.minInterval), int64(w.maxInterval))
   392  		time.Sleep(time.Duration(randomInterval))
   393  	}
   394  }
   395  
   396  // stolen from github.com/xtls/xray-core/transport/internet/reality
   397  func randBetween(left int64, right int64) int64 {
   398  	if left == right {
   399  		return left
   400  	}
   401  	bigInt, _ := rand.Int(rand.Reader, big.NewInt(right-left))
   402  	return left + bigInt.Int64()
   403  }
   404  
// FragmentedClientHelloConn wraps a connection and re-frames the first TLS
// handshake record written through it as several smaller records (see Write).
//
// PacketCount becomes non-zero once that record has been fragmented
// successfully. minLength/maxLength bound the payload size of each record, and
// minInterval/maxInterval bound the random pause after each record is written.
type FragmentedClientHelloConn struct {
	net.Conn
	PacketCount int
	minLength   int
	maxLength   int
	minInterval time.Duration
	maxInterval time.Duration
}
   413  
   414  func (c *FragmentedClientHelloConn) Write(b []byte) (n int, err error) {
   415  	if len(b) >= 5 && b[0] == 22 && c.PacketCount == 0 {
   416  		n, err = sendFragmentedClientHello(c, b, c.minLength, c.maxLength)
   417  
   418  		if err == nil {
   419  			c.PacketCount++
   420  			return n, err
   421  		}
   422  	}
   423  
   424  	return c.Conn.Write(b)
   425  }
   426  
   427  func sendFragmentedClientHello(conn *FragmentedClientHelloConn, clientHello []byte, minFragmentSize, maxFragmentSize int) (n int, err error) {
   428  	if len(clientHello) < 5 || clientHello[0] != 22 {
   429  		return 0, errors.New("not a valid TLS ClientHello message")
   430  	}
   431  
   432  	clientHelloLen := (int(clientHello[3]) << 8) | int(clientHello[4])
   433  
   434  	clientHelloData := clientHello[5:]
   435  	for i := 0; i < clientHelloLen; {
   436  		fragmentEnd := i + int(randBetween(int64(minFragmentSize), int64(maxFragmentSize)))
   437  		if fragmentEnd > clientHelloLen {
   438  			fragmentEnd = clientHelloLen
   439  		}
   440  
   441  		fragment := clientHelloData[i:fragmentEnd]
   442  		i = fragmentEnd
   443  
   444  		err = writeFragmentedRecord(conn, 22, fragment, clientHello)
   445  		if err != nil {
   446  			return 0, err
   447  		}
   448  	}
   449  
   450  	return len(clientHello), nil
   451  }
   452  
   453  func writeFragmentedRecord(c *FragmentedClientHelloConn, contentType uint8, data []byte, clientHello []byte) error {
   454  	header := make([]byte, 5)
   455  	header[0] = byte(clientHello[0])
   456  
   457  	tlsVersion := (int(clientHello[1]) << 8) | int(clientHello[2])
   458  	binary.BigEndian.PutUint16(header[1:], uint16(tlsVersion))
   459  
   460  	binary.BigEndian.PutUint16(header[3:], uint16(len(data)))
   461  	_, err := c.Conn.Write(append(header, data...))
   462  	randomInterval := randBetween(int64(c.minInterval), int64(c.maxInterval))
   463  	time.Sleep(time.Duration(randomInterval))
   464  
   465  	return err
   466  }