github.com/moqsien/xraycore@v1.8.5/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"crypto/rand"
     8  	"io"
     9  	"math/big"
    10  	"time"
    11  
    12  	"github.com/moqsien/xraycore/common"
    13  	"github.com/moqsien/xraycore/common/buf"
    14  	"github.com/moqsien/xraycore/common/dice"
    15  	"github.com/moqsien/xraycore/common/net"
    16  	"github.com/moqsien/xraycore/common/retry"
    17  	"github.com/moqsien/xraycore/common/session"
    18  	"github.com/moqsien/xraycore/common/signal"
    19  	"github.com/moqsien/xraycore/common/task"
    20  	"github.com/moqsien/xraycore/core"
    21  	"github.com/moqsien/xraycore/features/dns"
    22  	"github.com/moqsien/xraycore/features/policy"
    23  	"github.com/moqsien/xraycore/features/stats"
    24  	"github.com/moqsien/xraycore/transport"
    25  	"github.com/moqsien/xraycore/transport/internet"
    26  	"github.com/moqsien/xraycore/transport/internet/stat"
    27  )
    28  
    29  func init() {
    30  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    31  		h := new(Handler)
    32  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    33  			return h.Init(config.(*Config), pm, d)
    34  		}); err != nil {
    35  			return nil, err
    36  		}
    37  		return h, nil
    38  	}))
    39  }
    40  
    41  // Handler handles Freedom connections.
    42  type Handler struct {
    43  	policyManager policy.Manager
    44  	dns           dns.Client
    45  	config        *Config
    46  }
    47  
    48  // Init initializes the Handler with necessary parameters.
    49  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    50  	h.config = config
    51  	h.policyManager = pm
    52  	h.dns = d
    53  
    54  	return nil
    55  }
    56  
    57  func (h *Handler) policy() policy.Session {
    58  	p := h.policyManager.ForLevel(h.config.UserLevel)
    59  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    60  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    61  	}
    62  	return p
    63  }
    64  
    65  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    66  	var option dns.IPOption = dns.IPOption{
    67  		IPv4Enable: true,
    68  		IPv6Enable: true,
    69  		FakeEnable: false,
    70  	}
    71  	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
    72  		option = dns.IPOption{
    73  			IPv4Enable: true,
    74  			IPv6Enable: false,
    75  			FakeEnable: false,
    76  		}
    77  	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
    78  		option = dns.IPOption{
    79  			IPv4Enable: false,
    80  			IPv6Enable: true,
    81  			FakeEnable: false,
    82  		}
    83  	}
    84  
    85  	ips, err := h.dns.LookupIP(domain, option)
    86  	if err != nil {
    87  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    88  	}
    89  	if len(ips) == 0 {
    90  		return nil
    91  	}
    92  	return net.IPAddress(ips[dice.Roll(len(ips))])
    93  }
    94  
    95  func isValidAddress(addr *net.IPOrDomain) bool {
    96  	if addr == nil {
    97  		return false
    98  	}
    99  
   100  	a := addr.AsAddress()
   101  	return a != net.AnyIP
   102  }
   103  
// Process implements proxy.Outbound.
//
// It takes the target from the outbound session in ctx and applies any
// configured DestinationOverride address/port. It dials the (possibly
// DNS-resolved) destination, retrying via retry.ExponentialBackoff(5, 100).
// It then pumps link.Reader -> conn and conn -> link.Writer concurrently
// under an inactivity timer derived from h.policy().
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
	outbound := session.OutboundFromContext(ctx)
	if outbound == nil || !outbound.Target.IsValid() {
		return newError("target not specified.")
	}
	destination := outbound.Target
	// UDPOverride starts empty (nil address, port 0). It only receives the
	// override fields that are actually set, so the UDP packet reader/writer
	// can tell which parts of per-packet destinations to replace.
	UDPOverride := net.UDPDestination(nil, 0)
	if h.config.DestinationOverride != nil {
		// NOTE(review): assumes DestinationOverride.Server is non-nil whenever
		// DestinationOverride is set; a nil Server would panic here.
		server := h.config.DestinationOverride.Server
		if isValidAddress(server.Address) {
			destination.Address = server.Address.AsAddress()
			UDPOverride.Address = destination.Address
		}
		if server.Port != 0 {
			destination.Port = net.Port(server.Port)
			UDPOverride.Port = destination.Port
		}
	}

	input := link.Reader
	output := link.Writer

	var conn stat.Connection
	err := retry.ExponentialBackoff(5, 100).On(func() error {
		dialDest := destination
		// With an IP-based domain strategy, resolve the domain ourselves,
		// preferring the family of the dialer's local address. If resolution
		// yields nothing, the domain is dialed as-is.
		if h.config.useIP() && dialDest.Address.Family().IsDomain() {
			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
			if ip != nil {
				dialDest = net.Destination{
					Network: dialDest.Network,
					Address: ip,
					Port:    dialDest.Port,
				}
				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
			}
		}

		rawConn, err := dialer.Dial(ctx, dialDest)
		if err != nil {
			return err
		}
		conn = rawConn
		return nil
	})
	if err != nil {
		return newError("failed to open connection to ", destination).Base(err)
	}
	defer conn.Close()
	newError("connection opened to ", destination, ", local endpoint ", conn.LocalAddr(), ", remote endpoint ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx))

	// When the session is flagged "timeout only", the copy tasks run under a
	// fresh context rooted at context.Background(). Only the inactivity timer
	// (via newCancel) cancels it, not the caller's ctx.
	var newCtx context.Context
	var newCancel context.CancelFunc
	if session.TimeoutOnlyFromContext(ctx) {
		newCtx, newCancel = context.WithCancel(context.Background())
	}

	plcy := h.policy()
	ctx, cancel := context.WithCancel(ctx)
	// After ConnectionIdle without activity, cancel ctx (and newCtx, if any).
	timer := signal.CancelAfterInactivity(ctx, func() {
		cancel()
		if newCancel != nil {
			newCancel()
		}
	}, plcy.Timeouts.ConnectionIdle)

	// requestDone copies client data to the remote side. Once it finishes,
	// the timer is shortened to DownlinkOnly.
	// NOTE(review): this closure reads the variable ctx when it runs, which is
	// after the `ctx = newCtx` reassignment below. In the timeout-only case
	// the packet writer and the FRAGMENT log therefore receive newCtx,
	// which is derived from context.Background(), not the session context.
	requestDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)

		var writer buf.Writer
		if destination.Network == net.Network_TCP {
			if h.config.Fragment != nil {
				newError("FRAGMENT", h.config.Fragment.PacketsFrom, h.config.Fragment.PacketsTo, h.config.Fragment.LengthMin, h.config.Fragment.LengthMax,
					h.config.Fragment.IntervalMin, h.config.Fragment.IntervalMax).AtDebug().WriteToLog(session.ExportIDToError(ctx))
				writer = buf.NewWriter(&FragmentWriter{
					fragment: h.config.Fragment,
					writer:   conn,
				})
			} else {
				writer = buf.NewWriter(conn)
			}
		} else {
			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
		}

		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process request").Base(err)
		}

		return nil
	}

	// responseDone copies remote data back to the client. Once it finishes,
	// the timer is shortened to UplinkOnly.
	responseDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

		var reader buf.Reader
		if destination.Network == net.Network_TCP {
			reader = buf.NewReader(conn)
		} else {
			reader = NewPacketReader(conn, UDPOverride)
		}
		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process response").Base(err)
		}

		return nil
	}

	// Must stay after the closures are defined and before task.Run: see the
	// NOTE on requestDone.
	if newCtx != nil {
		ctx = newCtx
	}

	// Run both directions. output is passed to task.Close inside
	// task.OnSuccess, presumably so it is closed once responseDone succeeds.
	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
		return newError("connection ends").Base(err)
	}

	return nil
}
   222  
   223  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   224  	iConn := conn
   225  	statConn, ok := iConn.(*stat.CounterConnection)
   226  	if ok {
   227  		iConn = statConn.Connection
   228  	}
   229  	var counter stats.Counter
   230  	if statConn != nil {
   231  		counter = statConn.ReadCounter
   232  	}
   233  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   234  		return &PacketReader{
   235  			PacketConnWrapper: c,
   236  			Counter:           counter,
   237  		}
   238  	}
   239  	return &buf.PacketReader{Reader: conn}
   240  }
   241  
   242  type PacketReader struct {
   243  	*internet.PacketConnWrapper
   244  	stats.Counter
   245  }
   246  
   247  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   248  	b := buf.New()
   249  	b.Resize(0, buf.Size)
   250  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   251  	if err != nil {
   252  		b.Release()
   253  		return nil, err
   254  	}
   255  	b.Resize(0, int32(n))
   256  	b.UDP = &net.Destination{
   257  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   258  		Port:    net.Port(d.(*net.UDPAddr).Port),
   259  		Network: net.Network_UDP,
   260  	}
   261  	if r.Counter != nil {
   262  		r.Counter.Add(int64(n))
   263  	}
   264  	return buf.MultiBuffer{b}, nil
   265  }
   266  
   267  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   268  	iConn := conn
   269  	statConn, ok := iConn.(*stat.CounterConnection)
   270  	if ok {
   271  		iConn = statConn.Connection
   272  	}
   273  	var counter stats.Counter
   274  	if statConn != nil {
   275  		counter = statConn.WriteCounter
   276  	}
   277  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   278  		return &PacketWriter{
   279  			PacketConnWrapper: c,
   280  			Counter:           counter,
   281  			Handler:           h,
   282  			Context:           ctx,
   283  			UDPOverride:       UDPOverride,
   284  		}
   285  	}
   286  	return &buf.SequentialWriter{Writer: conn}
   287  }
   288  
   289  type PacketWriter struct {
   290  	*internet.PacketConnWrapper
   291  	stats.Counter
   292  	*Handler
   293  	context.Context
   294  	UDPOverride net.Destination
   295  }
   296  
   297  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   298  	for {
   299  		mb2, b := buf.SplitFirst(mb)
   300  		mb = mb2
   301  		if b == nil {
   302  			break
   303  		}
   304  		var n int
   305  		var err error
   306  		if b.UDP != nil {
   307  			if w.UDPOverride.Address != nil {
   308  				b.UDP.Address = w.UDPOverride.Address
   309  			}
   310  			if w.UDPOverride.Port != 0 {
   311  				b.UDP.Port = w.UDPOverride.Port
   312  			}
   313  			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
   314  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   315  				if ip != nil {
   316  					b.UDP.Address = ip
   317  				}
   318  			}
   319  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   320  			if destAddr == nil {
   321  				b.Release()
   322  				continue
   323  			}
   324  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   325  		} else {
   326  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   327  		}
   328  		b.Release()
   329  		if err != nil {
   330  			buf.ReleaseMulti(mb)
   331  			return err
   332  		}
   333  		if w.Counter != nil {
   334  			w.Counter.Add(int64(n))
   335  		}
   336  	}
   337  	return nil
   338  }
   339  
   340  type FragmentWriter struct {
   341  	fragment *Fragment
   342  	writer   io.Writer
   343  	count    uint64
   344  }
   345  
   346  func (f *FragmentWriter) Write(b []byte) (int, error) {
   347  	f.count++
   348  
   349  	if f.fragment.PacketsFrom == 0 && f.fragment.PacketsTo == 1 {
   350  		if f.count != 1 || len(b) <= 5 || b[0] != 22 {
   351  			return f.writer.Write(b)
   352  		}
   353  		recordLen := 5 + ((int(b[3]) << 8) | int(b[4]))
   354  		data := b[5:recordLen]
   355  		buf := make([]byte, 1024)
   356  		for from := 0; ; {
   357  			to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax)))
   358  			if to > len(data) {
   359  				to = len(data)
   360  			}
   361  			copy(buf[:3], b)
   362  			copy(buf[5:], data[from:to])
   363  			l := to - from
   364  			from = to
   365  			buf[3] = byte(l >> 8)
   366  			buf[4] = byte(l)
   367  			_, err := f.writer.Write(buf[:5+l])
   368  			time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond)
   369  			if err != nil {
   370  				return 0, err
   371  			}
   372  			if from == len(data) {
   373  				if len(b) > recordLen {
   374  					n, err := f.writer.Write(b[recordLen:])
   375  					if err != nil {
   376  						return recordLen + n, err
   377  					}
   378  				}
   379  				return len(b), nil
   380  			}
   381  		}
   382  	}
   383  
   384  	if f.fragment.PacketsFrom != 0 && (f.count < f.fragment.PacketsFrom || f.count > f.fragment.PacketsTo) {
   385  		return f.writer.Write(b)
   386  	}
   387  	for from := 0; ; {
   388  		to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax)))
   389  		if to > len(b) {
   390  			to = len(b)
   391  		}
   392  		n, err := f.writer.Write(b[from:to])
   393  		from += n
   394  		time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond)
   395  		if err != nil {
   396  			return from, err
   397  		}
   398  		if from >= len(b) {
   399  			return from, nil
   400  		}
   401  	}
   402  }
   403  
   404  // stolen from github.com/moqsien/xraycore/transport/internet/reality
   405  func randBetween(left int64, right int64) int64 {
   406  	if left == right {
   407  		return left
   408  	}
   409  	bigInt, _ := rand.Int(rand.Reader, big.NewInt(right-left))
   410  	return left + bigInt.Int64()
   411  }