github.com/EagleQL/Xray-core@v1.4.3/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/buf"
    11  	"github.com/xtls/xray-core/common/dice"
    12  	"github.com/xtls/xray-core/common/net"
    13  	"github.com/xtls/xray-core/common/retry"
    14  	"github.com/xtls/xray-core/common/session"
    15  	"github.com/xtls/xray-core/common/signal"
    16  	"github.com/xtls/xray-core/common/task"
    17  	"github.com/xtls/xray-core/core"
    18  	"github.com/xtls/xray-core/features/dns"
    19  	"github.com/xtls/xray-core/features/policy"
    20  	"github.com/xtls/xray-core/features/stats"
    21  	"github.com/xtls/xray-core/transport"
    22  	"github.com/xtls/xray-core/transport/internet"
    23  )
    24  
    25  func init() {
    26  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    27  		h := new(Handler)
    28  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    29  			return h.Init(config.(*Config), pm, d)
    30  		}); err != nil {
    31  			return nil, err
    32  		}
    33  		return h, nil
    34  	}))
    35  }
    36  
    37  // Handler handles Freedom connections.
    38  type Handler struct {
    39  	policyManager policy.Manager
    40  	dns           dns.Client
    41  	config        *Config
    42  }
    43  
    44  // Init initializes the Handler with necessary parameters.
    45  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    46  	h.config = config
    47  	h.policyManager = pm
    48  	h.dns = d
    49  
    50  	return nil
    51  }
    52  
    53  func (h *Handler) policy() policy.Session {
    54  	p := h.policyManager.ForLevel(h.config.UserLevel)
    55  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    56  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    57  	}
    58  	return p
    59  }
    60  
    61  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    62  	var option dns.IPOption = dns.IPOption{
    63  		IPv4Enable: true,
    64  		IPv6Enable: true,
    65  		FakeEnable: false,
    66  	}
    67  	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
    68  		option = dns.IPOption{
    69  			IPv4Enable: true,
    70  			IPv6Enable: false,
    71  			FakeEnable: false,
    72  		}
    73  	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
    74  		option = dns.IPOption{
    75  			IPv4Enable: false,
    76  			IPv6Enable: true,
    77  			FakeEnable: false,
    78  		}
    79  	}
    80  
    81  	ips, err := h.dns.LookupIP(domain, option)
    82  	if err != nil {
    83  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    84  	}
    85  	if len(ips) == 0 {
    86  		return nil
    87  	}
    88  	return net.IPAddress(ips[dice.Roll(len(ips))])
    89  }
    90  
    91  func isValidAddress(addr *net.IPOrDomain) bool {
    92  	if addr == nil {
    93  		return false
    94  	}
    95  
    96  	a := addr.AsAddress()
    97  	return a != net.AnyIP
    98  }
    99  
   100  // Process implements proxy.Outbound.
   101  func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
   102  	outbound := session.OutboundFromContext(ctx)
   103  	if outbound == nil || !outbound.Target.IsValid() {
   104  		return newError("target not specified.")
   105  	}
   106  	destination := outbound.Target
   107  	UDPOverride := net.UDPDestination(nil, 0)
   108  	if h.config.DestinationOverride != nil {
   109  		server := h.config.DestinationOverride.Server
   110  		if isValidAddress(server.Address) {
   111  			destination.Address = server.Address.AsAddress()
   112  			UDPOverride.Address = destination.Address
   113  		}
   114  		if server.Port != 0 {
   115  			destination.Port = net.Port(server.Port)
   116  			UDPOverride.Port = destination.Port
   117  		}
   118  	}
   119  	newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx))
   120  
   121  	input := link.Reader
   122  	output := link.Writer
   123  
   124  	var conn internet.Connection
   125  	err := retry.ExponentialBackoff(5, 100).On(func() error {
   126  		dialDest := destination
   127  		if h.config.useIP() && dialDest.Address.Family().IsDomain() {
   128  			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
   129  			if ip != nil {
   130  				dialDest = net.Destination{
   131  					Network: dialDest.Network,
   132  					Address: ip,
   133  					Port:    dialDest.Port,
   134  				}
   135  				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
   136  			}
   137  		}
   138  
   139  		rawConn, err := dialer.Dial(ctx, dialDest)
   140  		if err != nil {
   141  			return err
   142  		}
   143  		conn = rawConn
   144  		return nil
   145  	})
   146  	if err != nil {
   147  		return newError("failed to open connection to ", destination).Base(err)
   148  	}
   149  	defer conn.Close()
   150  
   151  	plcy := h.policy()
   152  	ctx, cancel := context.WithCancel(ctx)
   153  	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
   154  
   155  	requestDone := func() error {
   156  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   157  
   158  		var writer buf.Writer
   159  		if destination.Network == net.Network_TCP {
   160  			writer = buf.NewWriter(conn)
   161  		} else {
   162  			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
   163  		}
   164  
   165  		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
   166  			return newError("failed to process request").Base(err)
   167  		}
   168  
   169  		return nil
   170  	}
   171  
   172  	responseDone := func() error {
   173  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   174  
   175  		var reader buf.Reader
   176  		if destination.Network == net.Network_TCP {
   177  			reader = buf.NewReader(conn)
   178  		} else {
   179  			reader = NewPacketReader(conn, UDPOverride)
   180  		}
   181  		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
   182  			return newError("failed to process response").Base(err)
   183  		}
   184  
   185  		return nil
   186  	}
   187  
   188  	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
   189  		return newError("connection ends").Base(err)
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   196  	iConn := conn
   197  	statConn, ok := iConn.(*internet.StatCouterConnection)
   198  	if ok {
   199  		iConn = statConn.Connection
   200  	}
   201  	var counter stats.Counter
   202  	if statConn != nil {
   203  		counter = statConn.ReadCounter
   204  	}
   205  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   206  		return &PacketReader{
   207  			PacketConnWrapper: c,
   208  			Counter:           counter,
   209  		}
   210  	}
   211  	return &buf.PacketReader{Reader: conn}
   212  }
   213  
   214  type PacketReader struct {
   215  	*internet.PacketConnWrapper
   216  	stats.Counter
   217  }
   218  
   219  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   220  	b := buf.New()
   221  	b.Resize(0, buf.Size)
   222  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   223  	if err != nil {
   224  		b.Release()
   225  		return nil, err
   226  	}
   227  	b.Resize(0, int32(n))
   228  	b.UDP = &net.Destination{
   229  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   230  		Port:    net.Port(d.(*net.UDPAddr).Port),
   231  		Network: net.Network_UDP,
   232  	}
   233  	if r.Counter != nil {
   234  		r.Counter.Add(int64(n))
   235  	}
   236  	return buf.MultiBuffer{b}, nil
   237  }
   238  
   239  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   240  	iConn := conn
   241  	statConn, ok := iConn.(*internet.StatCouterConnection)
   242  	if ok {
   243  		iConn = statConn.Connection
   244  	}
   245  	var counter stats.Counter
   246  	if statConn != nil {
   247  		counter = statConn.WriteCounter
   248  	}
   249  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   250  		return &PacketWriter{
   251  			PacketConnWrapper: c,
   252  			Counter:           counter,
   253  			Handler:           h,
   254  			Context:           ctx,
   255  			UDPOverride:       UDPOverride,
   256  		}
   257  	}
   258  	return &buf.SequentialWriter{Writer: conn}
   259  }
   260  
   261  type PacketWriter struct {
   262  	*internet.PacketConnWrapper
   263  	stats.Counter
   264  	*Handler
   265  	context.Context
   266  	UDPOverride net.Destination
   267  }
   268  
   269  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   270  	for {
   271  		mb2, b := buf.SplitFirst(mb)
   272  		mb = mb2
   273  		if b == nil {
   274  			break
   275  		}
   276  		var n int
   277  		var err error
   278  		if b.UDP != nil {
   279  			if w.UDPOverride.Address != nil {
   280  				b.UDP.Address = w.UDPOverride.Address
   281  			}
   282  			if w.UDPOverride.Port != 0 {
   283  				b.UDP.Port = w.UDPOverride.Port
   284  			}
   285  			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
   286  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   287  				if ip != nil {
   288  					b.UDP.Address = ip
   289  				}
   290  			}
   291  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   292  			if destAddr == nil {
   293  				b.Release()
   294  				continue
   295  			}
   296  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   297  		} else {
   298  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   299  		}
   300  		b.Release()
   301  		if err != nil {
   302  			buf.ReleaseMulti(mb)
   303  			return err
   304  		}
   305  		if w.Counter != nil {
   306  			w.Counter.Add(int64(n))
   307  		}
   308  	}
   309  	return nil
   310  }