github.com/xmplusdev/xray-core@v1.8.10/proxy/freedom/freedom.go (about) 1 package freedom 2 3 //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 "crypto/rand" 8 "io" 9 "math/big" 10 "time" 11 12 "github.com/pires/go-proxyproto" 13 "github.com/xmplusdev/xray-core/common" 14 "github.com/xmplusdev/xray-core/common/buf" 15 "github.com/xmplusdev/xray-core/common/dice" 16 "github.com/xmplusdev/xray-core/common/net" 17 "github.com/xmplusdev/xray-core/common/platform" 18 "github.com/xmplusdev/xray-core/common/retry" 19 "github.com/xmplusdev/xray-core/common/session" 20 "github.com/xmplusdev/xray-core/common/signal" 21 "github.com/xmplusdev/xray-core/common/task" 22 "github.com/xmplusdev/xray-core/core" 23 "github.com/xmplusdev/xray-core/features/dns" 24 "github.com/xmplusdev/xray-core/features/policy" 25 "github.com/xmplusdev/xray-core/features/stats" 26 "github.com/xmplusdev/xray-core/proxy" 27 "github.com/xmplusdev/xray-core/transport" 28 "github.com/xmplusdev/xray-core/transport/internet" 29 "github.com/xmplusdev/xray-core/transport/internet/stat" 30 ) 31 32 var useSplice bool 33 34 func init() { 35 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 36 h := new(Handler) 37 if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error { 38 return h.Init(config.(*Config), pm, d) 39 }); err != nil { 40 return nil, err 41 } 42 return h, nil 43 })) 44 const defaultFlagValue = "NOT_DEFINED_AT_ALL" 45 value := platform.NewEnvFlag(platform.UseFreedomSplice).GetValue(func() string { return defaultFlagValue }) 46 switch value { 47 case defaultFlagValue, "auto", "enable": 48 useSplice = true 49 } 50 } 51 52 // Handler handles Freedom connections. 53 type Handler struct { 54 policyManager policy.Manager 55 dns dns.Client 56 config *Config 57 } 58 59 // Init initializes the Handler with necessary parameters. 60 func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error { 61 h.config = config 62 h.policyManager = pm 63 h.dns = d 64 65 return nil 66 } 67 68 func (h *Handler) policy() policy.Session { 69 p := h.policyManager.ForLevel(h.config.UserLevel) 70 if h.config.Timeout > 0 && h.config.UserLevel == 0 { 71 p.Timeouts.ConnectionIdle = time.Duration(h.config.Timeout) * time.Second 72 } 73 return p 74 } 75 76 func (h *Handler) resolveIP(ctx context.Context, domain string, localAddr net.Address) net.Address { 77 ips, err := h.dns.LookupIP(domain, dns.IPOption{ 78 IPv4Enable: (localAddr == nil || localAddr.Family().IsIPv4()) && h.config.preferIP4(), 79 IPv6Enable: (localAddr == nil || localAddr.Family().IsIPv6()) && h.config.preferIP6(), 80 }) 81 { // Resolve fallback 82 if (len(ips) == 0 || err != nil) && h.config.hasFallback() && localAddr == nil { 83 ips, err = h.dns.LookupIP(domain, dns.IPOption{ 84 IPv4Enable: h.config.fallbackIP4(), 85 IPv6Enable: h.config.fallbackIP6(), 86 }) 87 } 88 } 89 if err != nil { 90 newError("failed to get IP address for domain ", domain).Base(err).WriteToLog(session.ExportIDToError(ctx)) 91 } 92 if len(ips) == 0 { 93 return nil 94 } 95 return net.IPAddress(ips[dice.Roll(len(ips))]) 96 } 97 98 func isValidAddress(addr *net.IPOrDomain) bool { 99 if addr == nil { 100 return false 101 } 102 103 a := addr.AsAddress() 104 return a != net.AnyIP 105 } 106 107 // Process implements proxy.Outbound. 108 func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 109 outbound := session.OutboundFromContext(ctx) 110 if outbound == nil || !outbound.Target.IsValid() { 111 return newError("target not specified.") 112 } 113 outbound.Name = "freedom" 114 inbound := session.InboundFromContext(ctx) 115 if inbound != nil { 116 inbound.SetCanSpliceCopy(1) 117 } 118 destination := outbound.Target 119 UDPOverride := net.UDPDestination(nil, 0) 120 if h.config.DestinationOverride != nil { 121 server := h.config.DestinationOverride.Server 122 if isValidAddress(server.Address) { 123 destination.Address = server.Address.AsAddress() 124 UDPOverride.Address = destination.Address 125 } 126 if server.Port != 0 { 127 destination.Port = net.Port(server.Port) 128 UDPOverride.Port = destination.Port 129 } 130 } 131 132 input := link.Reader 133 output := link.Writer 134 135 var conn stat.Connection 136 err := retry.ExponentialBackoff(5, 100).On(func() error { 137 dialDest := destination 138 if h.config.hasStrategy() && dialDest.Address.Family().IsDomain() { 139 ip := h.resolveIP(ctx, dialDest.Address.Domain(), dialer.Address()) 140 if ip != nil { 141 dialDest = net.Destination{ 142 Network: dialDest.Network, 143 Address: ip, 144 Port: dialDest.Port, 145 } 146 newError("dialing to ", dialDest).WriteToLog(session.ExportIDToError(ctx)) 147 } else if h.config.forceIP() { 148 return dns.ErrEmptyResponse 149 } 150 } 151 152 rawConn, err := dialer.Dial(ctx, dialDest) 153 if err != nil { 154 return err 155 } 156 157 if h.config.ProxyProtocol > 0 && h.config.ProxyProtocol <= 2 { 158 version := byte(h.config.ProxyProtocol) 159 srcAddr := inbound.Source.RawNetAddr() 160 dstAddr := rawConn.RemoteAddr() 161 header := proxyproto.HeaderProxyFromAddrs(version, srcAddr, dstAddr) 162 if _, err = header.WriteTo(rawConn); err != nil { 163 rawConn.Close() 164 return err 165 } 166 } 167 168 conn = rawConn 169 return nil 170 }) 171 if err != nil { 172 return newError("failed to open connection to ", destination).Base(err) 173 } 174 defer conn.Close() 175 newError("connection opened to ", destination, ", local endpoint ", conn.LocalAddr(), ", remote endpoint ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx)) 176 177 var newCtx context.Context 178 var newCancel context.CancelFunc 179 if session.TimeoutOnlyFromContext(ctx) { 180 newCtx, newCancel = context.WithCancel(context.Background()) 181 } 182 183 plcy := h.policy() 184 ctx, cancel := context.WithCancel(ctx) 185 timer := signal.CancelAfterInactivity(ctx, func() { 186 cancel() 187 if newCancel != nil { 188 newCancel() 189 } 190 }, plcy.Timeouts.ConnectionIdle) 191 192 requestDone := func() error { 193 defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) 194 195 var writer buf.Writer 196 if destination.Network == net.Network_TCP { 197 if h.config.Fragment != nil { 198 newError("FRAGMENT", h.config.Fragment.PacketsFrom, h.config.Fragment.PacketsTo, h.config.Fragment.LengthMin, h.config.Fragment.LengthMax, 199 h.config.Fragment.IntervalMin, h.config.Fragment.IntervalMax).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 200 writer = buf.NewWriter(&FragmentWriter{ 201 fragment: h.config.Fragment, 202 writer: conn, 203 }) 204 } else { 205 writer = buf.NewWriter(conn) 206 } 207 } else { 208 writer = NewPacketWriter(conn, h, ctx, UDPOverride) 209 } 210 211 if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil { 212 return newError("failed to process request").Base(err) 213 } 214 215 return nil 216 } 217 218 responseDone := func() error { 219 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) 220 if destination.Network == net.Network_TCP { 221 var writeConn net.Conn 222 if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil && useSplice { 223 writeConn = inbound.Conn 224 } 225 return proxy.CopyRawConnIfExist(ctx, conn, writeConn, link.Writer, timer) 226 } 227 reader := NewPacketReader(conn, UDPOverride) 228 if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil { 229 return newError("failed to process response").Base(err) 230 } 231 return nil 232 } 233 234 if newCtx != nil { 235 ctx = newCtx 236 } 237 238 if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil { 239 return newError("connection ends").Base(err) 240 } 241 242 return nil 243 } 244 245 func NewPacketReader(conn net.Conn, UDPOverride net.Destination) buf.Reader { 246 iConn := conn 247 statConn, ok := iConn.(*stat.CounterConnection) 248 if ok { 249 iConn = statConn.Connection 250 } 251 var counter stats.Counter 252 if statConn != nil { 253 counter = statConn.ReadCounter 254 } 255 if c, ok := iConn.(*internet.PacketConnWrapper); ok && UDPOverride.Address == nil && UDPOverride.Port == 0 { 256 return &PacketReader{ 257 PacketConnWrapper: c, 258 Counter: counter, 259 } 260 } 261 return &buf.PacketReader{Reader: conn} 262 } 263 264 type PacketReader struct { 265 *internet.PacketConnWrapper 266 stats.Counter 267 } 268 269 func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 270 b := buf.New() 271 b.Resize(0, buf.Size) 272 n, d, err := r.PacketConnWrapper.ReadFrom(b.Bytes()) 273 if err != nil { 274 b.Release() 275 return nil, err 276 } 277 b.Resize(0, int32(n)) 278 b.UDP = &net.Destination{ 279 Address: net.IPAddress(d.(*net.UDPAddr).IP), 280 Port: net.Port(d.(*net.UDPAddr).Port), 281 Network: net.Network_UDP, 282 } 283 if r.Counter != nil { 284 r.Counter.Add(int64(n)) 285 } 286 return buf.MultiBuffer{b}, nil 287 } 288 289 func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context, UDPOverride net.Destination) buf.Writer { 290 iConn := conn 291 statConn, ok := iConn.(*stat.CounterConnection) 292 if ok { 293 iConn = statConn.Connection 294 } 295 var counter stats.Counter 296 if statConn != nil { 297 counter = statConn.WriteCounter 298 } 299 if c, ok := iConn.(*internet.PacketConnWrapper); ok { 300 return &PacketWriter{ 301 PacketConnWrapper: c, 302 Counter: counter, 303 Handler: h, 304 Context: ctx, 305 UDPOverride: UDPOverride, 306 } 307 } 308 return &buf.SequentialWriter{Writer: conn} 309 } 310 311 type PacketWriter struct { 312 *internet.PacketConnWrapper 313 stats.Counter 314 *Handler 315 context.Context 316 UDPOverride net.Destination 317 } 318 319 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { 320 for { 321 mb2, b := buf.SplitFirst(mb) 322 mb = mb2 323 if b == nil { 324 break 325 } 326 var n int 327 var err error 328 if b.UDP != nil { 329 if w.UDPOverride.Address != nil { 330 b.UDP.Address = w.UDPOverride.Address 331 } 332 if w.UDPOverride.Port != 0 { 333 b.UDP.Port = w.UDPOverride.Port 334 } 335 if w.Handler.config.hasStrategy() && b.UDP.Address.Family().IsDomain() { 336 ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil) 337 if ip != nil { 338 b.UDP.Address = ip 339 } 340 } 341 destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr()) 342 if destAddr == nil { 343 b.Release() 344 continue 345 } 346 n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr) 347 } else { 348 n, err = w.PacketConnWrapper.Write(b.Bytes()) 349 } 350 b.Release() 351 if err != nil { 352 buf.ReleaseMulti(mb) 353 return err 354 } 355 if w.Counter != nil { 356 w.Counter.Add(int64(n)) 357 } 358 } 359 return nil 360 } 361 362 type FragmentWriter struct { 363 fragment *Fragment 364 writer io.Writer 365 count uint64 366 } 367 368 func (f *FragmentWriter) Write(b []byte) (int, error) { 369 f.count++ 370 371 if f.fragment.PacketsFrom == 0 && f.fragment.PacketsTo == 1 { 372 if f.count != 1 || len(b) <= 5 || b[0] != 22 { 373 return f.writer.Write(b) 374 } 375 recordLen := 5 + ((int(b[3]) << 8) | int(b[4])) 376 if len(b) < recordLen { // maybe already fragmented somehow 377 return f.writer.Write(b) 378 } 379 data := b[5:recordLen] 380 buf := make([]byte, 1024) 381 for from := 0; ; { 382 to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax))) 383 if to > len(data) { 384 to = len(data) 385 } 386 copy(buf[:3], b) 387 copy(buf[5:], data[from:to]) 388 l := to - from 389 from = to 390 buf[3] = byte(l >> 8) 391 buf[4] = byte(l) 392 _, err := f.writer.Write(buf[:5+l]) 393 time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond) 394 if err != nil { 395 return 0, err 396 } 397 if from == len(data) { 398 if len(b) > recordLen { 399 n, err := f.writer.Write(b[recordLen:]) 400 if err != nil { 401 return recordLen + n, err 402 } 403 } 404 return len(b), nil 405 } 406 } 407 } 408 409 if f.fragment.PacketsFrom != 0 && (f.count < f.fragment.PacketsFrom || f.count > f.fragment.PacketsTo) { 410 return f.writer.Write(b) 411 } 412 for from := 0; ; { 413 to := from + int(randBetween(int64(f.fragment.LengthMin), int64(f.fragment.LengthMax))) 414 if to > len(b) { 415 to = len(b) 416 } 417 n, err := f.writer.Write(b[from:to]) 418 from += n 419 time.Sleep(time.Duration(randBetween(int64(f.fragment.IntervalMin), int64(f.fragment.IntervalMax))) * time.Millisecond) 420 if err != nil { 421 return from, err 422 } 423 if from >= len(b) { 424 return from, nil 425 } 426 } 427 } 428 429 // stolen from github.com/xmplusdev/xray-core/transport/internet/reality 430 func randBetween(left int64, right int64) int64 { 431 if left == right { 432 return left 433 } 434 bigInt, _ := rand.Int(rand.Reader, big.NewInt(right-left)) 435 return left + bigInt.Int64() 436 }