github.com/xraypb/xray-core@v1.6.6/proxy/shadowsocks_2022/outbound.go (about)

     1  package shadowsocks_2022
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"runtime"
     7  	"time"
     8  
     9  	shadowsocks "github.com/sagernet/sing-shadowsocks"
    10  	"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
    11  	C "github.com/sagernet/sing/common"
    12  	B "github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/bufio"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  	"github.com/sagernet/sing/common/uot"
    17  	"github.com/xraypb/xray-core/common"
    18  	"github.com/xraypb/xray-core/common/buf"
    19  	"github.com/xraypb/xray-core/common/net"
    20  	"github.com/xraypb/xray-core/common/session"
    21  	"github.com/xraypb/xray-core/transport"
    22  	"github.com/xraypb/xray-core/transport/internet"
    23  )
    24  
    25  func init() {
    26  	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    27  		return NewClient(ctx, config.(*ClientConfig))
    28  	}))
    29  }
    30  
    31  type Outbound struct {
    32  	ctx    context.Context
    33  	server net.Destination
    34  	method shadowsocks.Method
    35  	uot    bool
    36  }
    37  
    38  func NewClient(ctx context.Context, config *ClientConfig) (*Outbound, error) {
    39  	o := &Outbound{
    40  		ctx: ctx,
    41  		server: net.Destination{
    42  			Address: config.Address.AsAddress(),
    43  			Port:    net.Port(config.Port),
    44  			Network: net.Network_TCP,
    45  		},
    46  		uot: config.UdpOverTcp,
    47  	}
    48  	if C.Contains(shadowaead_2022.List, config.Method) {
    49  		if config.Key == "" {
    50  			return nil, newError("missing psk")
    51  		}
    52  		method, err := shadowaead_2022.NewWithPassword(config.Method, config.Key)
    53  		if err != nil {
    54  			return nil, newError("create method").Base(err)
    55  		}
    56  		o.method = method
    57  	} else {
    58  		return nil, newError("unknown method ", config.Method)
    59  	}
    60  	return o, nil
    61  }
    62  
    63  func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    64  	var inboundConn net.Conn
    65  	inbound := session.InboundFromContext(ctx)
    66  	if inbound != nil {
    67  		inboundConn = inbound.Conn
    68  	}
    69  
    70  	outbound := session.OutboundFromContext(ctx)
    71  	if outbound == nil || !outbound.Target.IsValid() {
    72  		return newError("target not specified")
    73  	}
    74  	destination := outbound.Target
    75  	network := destination.Network
    76  
    77  	newError("tunneling request to ", destination, " via ", o.server.NetAddr()).WriteToLog(session.ExportIDToError(ctx))
    78  
    79  	serverDestination := o.server
    80  	if o.uot {
    81  		serverDestination.Network = net.Network_TCP
    82  	} else {
    83  		serverDestination.Network = network
    84  	}
    85  	connection, err := dialer.Dial(ctx, serverDestination)
    86  	if err != nil {
    87  		return newError("failed to connect to server").Base(err)
    88  	}
    89  
    90  	if network == net.Network_TCP {
    91  		serverConn := o.method.DialEarlyConn(connection, toSocksaddr(destination))
    92  		var handshake bool
    93  		if timeoutReader, isTimeoutReader := link.Reader.(buf.TimeoutReader); isTimeoutReader {
    94  			mb, err := timeoutReader.ReadMultiBufferTimeout(time.Millisecond * 100)
    95  			if err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
    96  				return newError("read payload").Base(err)
    97  			}
    98  			_payload := B.StackNew()
    99  			payload := C.Dup(_payload)
   100  			defer payload.Release()
   101  			for {
   102  				payload.FullReset()
   103  				nb, n := buf.SplitBytes(mb, payload.FreeBytes())
   104  				if n > 0 {
   105  					payload.Truncate(n)
   106  					_, err = serverConn.Write(payload.Bytes())
   107  					if err != nil {
   108  						return newError("write payload").Base(err)
   109  					}
   110  					handshake = true
   111  				}
   112  				if nb.IsEmpty() {
   113  					break
   114  				} else {
   115  					mb = nb
   116  				}
   117  			}
   118  			runtime.KeepAlive(_payload)
   119  		}
   120  		if !handshake {
   121  			_, err = serverConn.Write(nil)
   122  			if err != nil {
   123  				return newError("client handshake").Base(err)
   124  			}
   125  		}
   126  		conn := &pipeConnWrapper{
   127  			W:    link.Writer,
   128  			Conn: inboundConn,
   129  		}
   130  		if ir, ok := link.Reader.(io.Reader); ok {
   131  			conn.R = ir
   132  		} else {
   133  			conn.R = &buf.BufferedReader{Reader: link.Reader}
   134  		}
   135  
   136  		return returnError(bufio.CopyConn(ctx, conn, serverConn))
   137  	} else {
   138  		var packetConn N.PacketConn
   139  		if pc, isPacketConn := inboundConn.(N.PacketConn); isPacketConn {
   140  			packetConn = pc
   141  		} else if nc, isNetPacket := inboundConn.(net.PacketConn); isNetPacket {
   142  			packetConn = bufio.NewPacketConn(nc)
   143  		} else {
   144  			packetConn = &packetConnWrapper{
   145  				Reader: link.Reader,
   146  				Writer: link.Writer,
   147  				Conn:   inboundConn,
   148  				Dest:   destination,
   149  			}
   150  		}
   151  
   152  		if o.uot {
   153  			serverConn := o.method.DialEarlyConn(connection, M.Socksaddr{Fqdn: uot.UOTMagicAddress})
   154  			return returnError(bufio.CopyPacketConn(ctx, packetConn, uot.NewClientConn(serverConn)))
   155  		} else {
   156  			serverConn := o.method.DialPacketConn(connection)
   157  			return returnError(bufio.CopyPacketConn(ctx, packetConn, serverConn))
   158  		}
   159  	}
   160  }