github.com/xraypb/xray-core@v1.6.6/proxy/freedom/freedom.go (about)

     1  package freedom
     2  
     3  //go:generate go run github.com/xraypb/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"time"
     8  
     9  	"github.com/xraypb/xray-core/common"
    10  	"github.com/xraypb/xray-core/common/buf"
    11  	"github.com/xraypb/xray-core/common/dice"
    12  	"github.com/xraypb/xray-core/common/net"
    13  	"github.com/xraypb/xray-core/common/retry"
    14  	"github.com/xraypb/xray-core/common/session"
    15  	"github.com/xraypb/xray-core/common/signal"
    16  	"github.com/xraypb/xray-core/common/task"
    17  	"github.com/xraypb/xray-core/core"
    18  	"github.com/xraypb/xray-core/features/dns"
    19  	"github.com/xraypb/xray-core/features/policy"
    20  	"github.com/xraypb/xray-core/features/stats"
    21  	"github.com/xraypb/xray-core/transport"
    22  	"github.com/xraypb/xray-core/transport/internet"
    23  	"github.com/xraypb/xray-core/transport/internet/stat"
    24  )
    25  
    26  func init() {
    27  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    28  		h := new(Handler)
    29  		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
    30  			return h.Init(config.(*Config), pm, d)
    31  		}); err != nil {
    32  			return nil, err
    33  		}
    34  		return h, nil
    35  	}))
    36  }
    37  
// Handler handles Freedom connections. Process dials the outbound target (or
// the configured DestinationOverride) with the supplied dialer and copies data
// in both directions.
type Handler struct {
	// policyManager provides the per-user-level policy session (see policy).
	policyManager policy.Manager
	// dns resolves domain targets when config.useIP() is true (see resolveIP).
	dns           dns.Client
	// config holds the outbound settings passed to Init.
	config        *Config
}
    44  
    45  // Init initializes the Handler with necessary parameters.
    46  func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
    47  	h.config = config
    48  	h.policyManager = pm
    49  	h.dns = d
    50  
    51  	return nil
    52  }
    53  
    54  func (h *Handler) policy() policy.Session {
    55  	p := h.policyManager.ForLevel(h.config.UserLevel)
    56  	if h.config.Timeout > 0 && h.config.UserLevel == 0 {
    57  		p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second
    58  	}
    59  	return p
    60  }
    61  
    62  func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address {
    63  	var option dns.IPOption = dns.IPOption{
    64  		IPv4Enable: true,
    65  		IPv6Enable: true,
    66  		FakeEnable: false,
    67  	}
    68  	if h.config.DomainStrategy == Config_USE_IP4 || (localAddr != nil && localAddr.Family().IsIPv4()) {
    69  		option = dns.IPOption{
    70  			IPv4Enable: true,
    71  			IPv6Enable: false,
    72  			FakeEnable: false,
    73  		}
    74  	} else if h.config.DomainStrategy == Config_USE_IP6 || (localAddr != nil && localAddr.Family().IsIPv6()) {
    75  		option = dns.IPOption{
    76  			IPv4Enable: false,
    77  			IPv6Enable: true,
    78  			FakeEnable: false,
    79  		}
    80  	}
    81  
    82  	ips, err := h.dns.LookupIP(domain, option)
    83  	if err != nil {
    84  		newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx))
    85  	}
    86  	if len(ips) == 0 {
    87  		return nil
    88  	}
    89  	return net.IPAddress(ips[dice.Roll(len(ips))])
    90  }
    91  
    92  func isValidAddress(addr *net.IPOrDomain) bool {
    93  	if addr == nil {
    94  		return false
    95  	}
    96  
    97  	a := addr.AsAddress()
    98  	return a != net.AnyIP
    99  }
   100  
// Process implements proxy.Outbound.
//
// It takes the target from the outbound session in ctx and applies
// DestinationOverride: a valid, non-wildcard address and/or a non-zero port
// replace the target's. It then dials the result with dialer and copies data
// both ways (link.Reader -> connection, connection -> link.Writer). TCP
// targets use stream reader/writers; any other network goes through
// NewPacketReader / NewPacketWriter. An error is returned when no valid target
// is set, when dialing fails after retries, or when either direction fails.
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
	outbound := session.OutboundFromContext(ctx)
	if outbound == nil || !outbound.Target.IsValid() {
		return newError("target not specified.")
	}
	destination := outbound.Target
	// UDPOverride records only the overridden parts of the destination. A nil
	// Address or zero Port tells the packet reader/writer "not overridden".
	UDPOverride := net.UDPDestination(nil, 0)
	if h.config.DestinationOverride != nil {
		// NOTE(review): if Server is a pointer type, a nil Server would panic on
		// server.Address below; confirm config construction always sets it.
		server := h.config.DestinationOverride.Server
		if isValidAddress(server.Address) {
			destination.Address = server.Address.AsAddress()
			UDPOverride.Address = destination.Address
		}
		// Port 0 keeps the original target port.
		if server.Port != 0 {
			destination.Port = net.Port(server.Port)
			UDPOverride.Port = destination.Port
		}
	}
	newError("opening connection to ", destination).WriteToLog(session.ExportIDToError(ctx))

	input := link.Reader
	output := link.Writer

	var conn stat.Connection
	// Dial with retries. NOTE(review): ExponentialBackoff(5, 100) presumably
	// means 5 attempts starting at a 100 ms delay; confirm in common/retry.
	err := retry.ExponentialBackoff(5, 100).On(func() error {
		dialDest := destination
		// When the config asks for IPs, resolve a domain target here, on every
		// attempt. The lookup is restricted to the family of the dialer's local
		// address if it has one. If nothing resolves, the domain itself is dialed.
		if h.config.useIP() && dialDest.Address.Family().IsDomain() {
			ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address())
			if ip != nil {
				dialDest = net.Destination{
					Network: dialDest.Network,
					Address: ip,
					Port:    dialDest.Port,
				}
				newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx))
			}
		}

		rawConn, err := dialer.Dial(ctx, dialDest)
		if err != nil {
			return err
		}
		conn = rawConn
		return nil
	})
	if err != nil {
		return newError("failed to open connection to ", destination).Base(err)
	}
	defer conn.Close()

	plcy := h.policy()
	// cancel is handed to the inactivity timer, which presumably cancels this
	// derived ctx after ConnectionIdle with no activity. The copies below report
	// activity through buf.UpdateActivity(timer).
	// NOTE(review): cancel is not deferred in this function.
	ctx, cancel := context.WithCancel(ctx)
	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)

	// requestDone copies client data (link.Reader) to the remote connection.
	requestDone := func() error {
		// When the uplink finishes, switch the idle timeout to DownlinkOnly.
		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)

		var writer buf.Writer
		if destination.Network == net.Network_TCP {
			writer = buf.NewWriter(conn)
		} else {
			// Packet writer: honors per-packet destinations and UDPOverride.
			writer = NewPacketWriter(conn, h, ctx, UDPOverride)
		}

		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process request").Base(err)
		}

		return nil
	}

	// responseDone copies data from the remote connection back to link.Writer.
	responseDone := func() error {
		// When the downlink finishes, switch the idle timeout to UplinkOnly.
		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

		var reader buf.Reader
		if destination.Network == net.Network_TCP {
			reader = buf.NewReader(conn)
		} else {
			reader = NewPacketReader(conn, UDPOverride)
		}
		if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
			return newError("failed to process response").Base(err)
		}

		return nil
	}

	// Run both directions. task.OnSuccess presumably closes output only after
	// responseDone returns nil; confirm in common/task.
	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
		return newError("connection ends").Base(err)
	}

	return nil
}
   195  
   196  func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader {
   197  	iConn := conn
   198  	statConn, ok := iConn.(*stat.CounterConnection)
   199  	if ok {
   200  		iConn = statConn.Connection
   201  	}
   202  	var counter stats.Counter
   203  	if statConn != nil {
   204  		counter = statConn.ReadCounter
   205  	}
   206  	if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 {
   207  		return &PacketReader{
   208  			PacketConnWrapper: c,
   209  			Counter:           counter,
   210  		}
   211  	}
   212  	return &buf.PacketReader{Reader: conn}
   213  }
   214  
// PacketReader reads datagrams directly from a PacketConnWrapper so that each
// returned buffer carries its sender address (see ReadMultiBuffer).
type PacketReader struct {
	*internet.PacketConnWrapper
	// Counter, when non-nil, is incremented by the number of bytes read.
	stats.Counter
}
   219  
   220  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   221  	b := buf.New()
   222  	b.Resize(0, buf.Size)
   223  	n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes())
   224  	if err != nil {
   225  		b.Release()
   226  		return nil, err
   227  	}
   228  	b.Resize(0, int32(n))
   229  	b.UDP = &net.Destination{
   230  		Address: net.IPAddress(d.(*net.UDPAddr).IP),
   231  		Port:    net.Port(d.(*net.UDPAddr).Port),
   232  		Network: net.Network_UDP,
   233  	}
   234  	if r.Counter != nil {
   235  		r.Counter.Add(int64(n))
   236  	}
   237  	return buf.MultiBuffer{b}, nil
   238  }
   239  
   240  func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer {
   241  	iConn := conn
   242  	statConn, ok := iConn.(*stat.CounterConnection)
   243  	if ok {
   244  		iConn = statConn.Connection
   245  	}
   246  	var counter stats.Counter
   247  	if statConn != nil {
   248  		counter = statConn.WriteCounter
   249  	}
   250  	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
   251  		return &PacketWriter{
   252  			PacketConnWrapper: c,
   253  			Counter:           counter,
   254  			Handler:           h,
   255  			Context:           ctx,
   256  			UDPOverride:       UDPOverride,
   257  		}
   258  	}
   259  	return &buf.SequentialWriter{Writer: conn}
   260  }
   261  
// PacketWriter writes buffers to a PacketConnWrapper, sending each buffer to
// its own UDP destination when it carries one (see WriteMultiBuffer).
type PacketWriter struct {
	*internet.PacketConnWrapper
	// Counter, when non-nil, is incremented by the number of bytes written.
	stats.Counter
	// Handler supplies the config and resolver used to turn domain
	// destinations into IPs.
	*Handler
	// Context is passed to resolveIP, which uses it only to tag log output.
	// NOTE(review): storing a context in a struct goes against Go convention;
	// kept because the field is exported.
	context.Context
	// UDPOverride replaces the address and/or port of each per-packet
	// destination; a nil Address or zero Port leaves that part unchanged.
	UDPOverride net.Destination
}
   269  
   270  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   271  	for {
   272  		mb2, b := buf.SplitFirst(mb)
   273  		mb = mb2
   274  		if b == nil {
   275  			break
   276  		}
   277  		var n int
   278  		var err error
   279  		if b.UDP != nil {
   280  			if w.UDPOverride.Address != nil {
   281  				b.UDP.Address = w.UDPOverride.Address
   282  			}
   283  			if w.UDPOverride.Port != 0 {
   284  				b.UDP.Port = w.UDPOverride.Port
   285  			}
   286  			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
   287  				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
   288  				if ip != nil {
   289  					b.UDP.Address = ip
   290  				}
   291  			}
   292  			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
   293  			if destAddr == nil {
   294  				b.Release()
   295  				continue
   296  			}
   297  			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
   298  		} else {
   299  			n, err = w.PacketConnWrapper.Write(b.Bytes())
   300  		}
   301  		b.Release()
   302  		if err != nil {
   303  			buf.ReleaseMulti(mb)
   304  			return err
   305  		}
   306  		if w.Counter != nil {
   307  			w.Counter.Add(int64(n))
   308  		}
   309  	}
   310  	return nil
   311  }