github.com/xraypb/Xray-core@v1.8.1/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/xraypb/Xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"time"
     8  
     9  	"github.com/xraypb/Xray-core/common"
    10  	"github.com/xraypb/Xray-core/common/buf"
    11  	"github.com/xraypb/Xray-core/common/dice"
    12  	"github.com/xraypb/Xray-core/common/net"
    13  	"github.com/xraypb/Xray-core/common/retry"
    14  	"github.com/xraypb/Xray-core/common/session"
    15  	"github.com/xraypb/Xray-core/common/signal"
    16  	"github.com/xraypb/Xray-core/common/task"
    17  	"github.com/xraypb/Xray-core/core"
    18  	"github.com/xraypb/Xray-core/features/dns"
    19  	"github.com/xraypb/Xray-core/features/policy"
    20  	"github.com/xraypb/Xray-core/features/stats"
    21  	"github.com/xraypb/Xray-core/transport"
    22  	"github.com/xraypb/Xray-core/transport/internet"
    23  	"github.com/xraypb/Xray-core/transport/internet/stat"
    24  )
    25  
    26  func init() {
    27  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    28  		h := new(Handler)
    29  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    30  			return h.Init(config.(*Config), pm, d)
    31  		}); err != nil {
    32  			return nil, err
    33  		}
    34  		return h, nil
    35  	}))
    36  }
    37  
// Handler handles Freedom connections.
//
// All fields are populated by Init; the zero value is not ready for use
// because the methods dereference config and call into policyManager and dns.
type Handler struct {
	// policyManager supplies the per-user-level session policy (see policy()).
	policyManager policy.Manager
	// dns is the client used by resolveIP to turn domains into IP addresses.
	dns           dns.Client
	// config holds the outbound settings: user level, timeout, domain
	// strategy and the optional destination override.
	config        *Config
}
    44  
    45  // Init initializes the Handler with necessary parameters.
    46  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    47  	h.config = config
    48  	h.policyManager = pm
    49  	h.dns = d
    50  
    51  	return nil
    52  }
    53  
    54  func (h *Handler) policy() policy.Session {
    55  	p := h.policyManager.ForLevel(h.config.UserLevel)
    56  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    57  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    58  	}
    59  	return p
    60  }
    61  
    62  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    63  	var option dns.IPOption = dns.IPOption{
    64  		IPv4Enable: true,
    65  		IPv6Enable: true,
    66  		FakeEnable: false,
    67  	}
    68  	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
    69  		option = dns.IPOption{
    70  			IPv4Enable: true,
    71  			IPv6Enable: false,
    72  			FakeEnable: false,
    73  		}
    74  	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
    75  		option = dns.IPOption{
    76  			IPv4Enable: false,
    77  			IPv6Enable: true,
    78  			FakeEnable: false,
    79  		}
    80  	}
    81  
    82  	ips, err := h.dns.LookupIP(domain, option)
    83  	if err != nil {
    84  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    85  	}
    86  	if len(ips) == 0 {
    87  		return nil
    88  	}
    89  	return net.IPAddress(ips[dice.Roll(len(ips))])
    90  }
    91  
    92  func isValidAddress(addr *net.IPOrDomain) bool {
    93  	if addr == nil {
    94  		return false
    95  	}
    96  
    97  	a := addr.AsAddress()
    98  	return a != net.AnyIP
    99  }
   100  
// Process implements proxy.Outbound.
//
// It dials the outbound target found in ctx (after applying the configured
// destination override and, when enabled, resolving a domain to an IP), then
// copies link.Reader to the connection and the connection to link.Writer
// until both directions finish or the inactivity timer fires.
//
// An error is returned when ctx carries no valid target, when every dial
// attempt fails, or when the transfer task reports an error.
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
	outbound := session.OutboundFromContext(ctx)
	if outbound == nil || !outbound.Target.IsValid() {
		return newError("target not specified.")
	}
	destination := outbound.Target
	// UDPOverride starts empty (nil address, port 0) and only receives the
	// parts of the destination that the config explicitly overrides; the
	// packet reader/writer use it to tell whether an override is active.
	UDPOverride := net.UDPDestination(nil, 0)
	if h.config.DestinationOverride != nil {
		server := h.config.DestinationOverride.Server
		if isValidAddress(server.Address) {
			destination.Address = server.Address.AsAddress()
			UDPOverride.Address = destination.Address
		}
		if server.Port != 0 {
			destination.Port = net.Port(server.Port)
			UDPOverride.Port = destination.Port
		}
	}
	newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx))

	input := link.Reader
	output := link.Writer

	var conn stat.Connection
	// NOTE(review): presumably up to 5 attempts with a 100 ms base delay —
	// confirm against the retry package.
	err := retry.ExponentialBackoff(5, 100).On(func() error {
		// Resolve again on every attempt so a retry may pick a different IP.
		dialDest := destination
		if h.config.useIP() && dialDest.Address.Family().IsDomain() {
			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
			if ip != nil {
				dialDest = net.Destination{
					Network: dialDest.Network,
					Address: ip,
					Port:    dialDest.Port,
				}
				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
			}
			// When resolution yields nothing, the domain itself is dialed.
		}

		rawConn, err := dialer.Dial(ctx, dialDest)
		if err != nil {
			return err
		}
		conn = rawConn
		return nil
	})
	if err != nil {
		return newError("failed to open connection to ", destination).Base(err)
	}
	defer conn.Close()

	// When the session is flagged timeout-only, the transfer below runs under
	// a context derived from context.Background() instead of the caller's
	// ctx, so it is cancelled only through newCancel (i.e. by the timer).
	var newCtx context.Context
	var newCancel context.CancelFunc
	if session.TimeoutOnlyFromContext(ctx) {
		newCtx, newCancel = context.WithCancel(context.Background())
	}

	plcy := h.policy()
	ctx, cancel := context.WithCancel(ctx)
	// The timer's callback cancels both contexts; cancel is not called
	// anywhere else in this function.
	timer := signal.CancelAfterInactivity(ctx, func() {
		cancel()
		if newCancel != nil {
			newCancel()
		}
	}, plcy.Timeouts.ConnectionIdle)

	// requestDone copies the uplink direction: link.Reader -> remote.
	requestDone := func() error {
		// Once the uplink is finished, only downlink traffic remains, so the
		// timer switches to the downlink-only timeout.
		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)

		var writer buf.Writer
		if destination.Network == net.Network_TCP {
			writer = buf.NewWriter(conn)
		} else {
			// ctx is captured by reference: this closure runs after the
			// "ctx = newCtx" assignment below, so the writer receives newCtx
			// when the timeout-only path is active.
			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
		}

		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process request").Base(err)
		}

		return nil
	}

	// responseDone copies the downlink direction: remote -> link.Writer.
	responseDone := func() error {
		// Once the downlink is finished, only uplink traffic remains, so the
		// timer switches to the uplink-only timeout.
		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

		var reader buf.Reader
		if destination.Network == net.Network_TCP {
			reader = buf.NewReader(conn)
		} else {
			reader = NewPacketReader(conn, UDPOverride)
		}
		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process response").Base(err)
		}

		return nil
	}

	// Swap in the background-derived context (if any) for the transfer tasks.
	if newCtx != nil {
		ctx = newCtx
	}

	// task.OnSuccess chains task.Close(output) after responseDone.
	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
		return newError("connection ends").Base(err)
	}

	return nil
}
   210  
   211  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   212  	iConn := conn
   213  	statConn, ok := iConn.(*stat.CounterConnection)
   214  	if ok {
   215  		iConn = statConn.Connection
   216  	}
   217  	var counter stats.Counter
   218  	if statConn != nil {
   219  		counter = statConn.ReadCounter
   220  	}
   221  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   222  		return &PacketReader{
   223  			PacketConnWrapper: c,
   224  			Counter:           counter,
   225  		}
   226  	}
   227  	return &buf.PacketReader{Reader: conn}
   228  }
   229  
// PacketReader reads datagrams from the embedded PacketConnWrapper; its
// ReadMultiBuffer records each datagram's source address on the returned
// buffer.
type PacketReader struct {
	// PacketConnWrapper is the packet connection that datagrams are read from.
	*internet.PacketConnWrapper
	// Counter, when non-nil, is incremented by the number of bytes read.
	stats.Counter
}
   234  
   235  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   236  	b := buf.New()
   237  	b.Resize(0, buf.Size)
   238  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   239  	if err != nil {
   240  		b.Release()
   241  		return nil, err
   242  	}
   243  	b.Resize(0, int32(n))
   244  	b.UDP = &net.Destination{
   245  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   246  		Port:    net.Port(d.(*net.UDPAddr).Port),
   247  		Network: net.Network_UDP,
   248  	}
   249  	if r.Counter != nil {
   250  		r.Counter.Add(int64(n))
   251  	}
   252  	return buf.MultiBuffer{b}, nil
   253  }
   254  
   255  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   256  	iConn := conn
   257  	statConn, ok := iConn.(*stat.CounterConnection)
   258  	if ok {
   259  		iConn = statConn.Connection
   260  	}
   261  	var counter stats.Counter
   262  	if statConn != nil {
   263  		counter = statConn.WriteCounter
   264  	}
   265  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   266  		return &PacketWriter{
   267  			PacketConnWrapper: c,
   268  			Counter:           counter,
   269  			Handler:           h,
   270  			Context:           ctx,
   271  			UDPOverride:       UDPOverride,
   272  		}
   273  	}
   274  	return &buf.SequentialWriter{Writer: conn}
   275  }
   276  
// PacketWriter writes buffers as datagrams to the embedded PacketConnWrapper,
// applying the UDP destination override and optional domain resolution per
// buffer (see WriteMultiBuffer).
type PacketWriter struct {
	// PacketConnWrapper is the packet connection that datagrams are sent on.
	*internet.PacketConnWrapper
	// Counter, when non-nil, is incremented by the number of bytes written.
	stats.Counter
	// Handler provides the config and the resolveIP helper used for
	// per-packet domain resolution.
	*Handler
	// Context is passed to resolveIP, which uses it when logging lookup
	// failures.
	context.Context
	// UDPOverride replaces the address and/or port of each buffer's UDP
	// destination when its Address is non-nil and/or its Port is non-zero.
	UDPOverride net.Destination
}
   284  
   285  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   286  	for {
   287  		mb2, b := buf.SplitFirst(mb)
   288  		mb = mb2
   289  		if b == nil {
   290  			break
   291  		}
   292  		var n int
   293  		var err error
   294  		if b.UDP != nil {
   295  			if w.UDPOverride.Address != nil {
   296  				b.UDP.Address = w.UDPOverride.Address
   297  			}
   298  			if w.UDPOverride.Port != 0 {
   299  				b.UDP.Port = w.UDPOverride.Port
   300  			}
   301  			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
   302  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   303  				if ip != nil {
   304  					b.UDP.Address = ip
   305  				}
   306  			}
   307  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   308  			if destAddr == nil {
   309  				b.Release()
   310  				continue
   311  			}
   312  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   313  		} else {
   314  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   315  		}
   316  		b.Release()
   317  		if err != nil {
   318  			buf.ReleaseMulti(mb)
   319  			return err
   320  		}
   321  		if w.Counter != nil {
   322  			w.Counter.Add(int64(n))
   323  		}
   324  	}
   325  	return nil
   326  }