github.com/imannamdari/v2ray-core/v5@v5.0.5/transport/internet/tcp/hub.go (about)

     1  package tcp
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/imannamdari/v2ray-core/v5/common"
    10  	"github.com/imannamdari/v2ray-core/v5/common/net"
    11  	"github.com/imannamdari/v2ray-core/v5/common/serial"
    12  	"github.com/imannamdari/v2ray-core/v5/common/session"
    13  	"github.com/imannamdari/v2ray-core/v5/transport/internet"
    14  	"github.com/imannamdari/v2ray-core/v5/transport/internet/tls"
    15  )
    16  
    17  // Listener is an internet.Listener that listens for TCP connections.
    18  type Listener struct {
    19  	listener   net.Listener
    20  	tlsConfig  *gotls.Config
    21  	authConfig internet.ConnectionAuthenticator
    22  	config     *Config
    23  	addConn    internet.ConnHandler
    24  	locker     *internet.FileLocker // for unix domain socket
    25  }
    26  
    27  // ListenTCP creates a new Listener based on configurations.
    28  func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
    29  	l := &Listener{
    30  		addConn: handler,
    31  	}
    32  	tcpSettings := streamSettings.ProtocolSettings.(*Config)
    33  	l.config = tcpSettings
    34  	if l.config != nil {
    35  		if streamSettings.SocketSettings == nil {
    36  			streamSettings.SocketSettings = &internet.SocketConfig{}
    37  		}
    38  		streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol
    39  	}
    40  	var listener net.Listener
    41  	var err error
    42  	if address.Family().IsDomain() {
    43  		listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
    44  			Name: address.Domain(),
    45  			Net:  "unix",
    46  		}, streamSettings.SocketSettings)
    47  		if err != nil {
    48  			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
    49  		}
    50  		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
    51  		locker := ctx.Value(address.Domain())
    52  		if locker != nil {
    53  			l.locker = locker.(*internet.FileLocker)
    54  		}
    55  	} else {
    56  		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
    57  			IP:   address.IP(),
    58  			Port: int(port),
    59  		}, streamSettings.SocketSettings)
    60  		if err != nil {
    61  			return nil, newError("failed to listen TCP on ", address, ":", port).Base(err)
    62  		}
    63  		newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
    64  	}
    65  
    66  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
    67  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
    68  	}
    69  
    70  	l.listener = listener
    71  
    72  	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
    73  		l.tlsConfig = config.GetTLSConfig()
    74  	}
    75  
    76  	if tcpSettings.HeaderSettings != nil {
    77  		headerConfig, err := serial.GetInstanceOf(tcpSettings.HeaderSettings)
    78  		if err != nil {
    79  			return nil, newError("invalid header settings").Base(err).AtError()
    80  		}
    81  		auth, err := internet.CreateConnectionAuthenticator(headerConfig)
    82  		if err != nil {
    83  			return nil, newError("invalid header settings.").Base(err).AtError()
    84  		}
    85  		l.authConfig = auth
    86  	}
    87  
    88  	go l.keepAccepting()
    89  	return l, nil
    90  }
    91  
    92  func (v *Listener) keepAccepting() {
    93  	for {
    94  		conn, err := v.listener.Accept()
    95  		if err != nil {
    96  			errStr := err.Error()
    97  			if strings.Contains(errStr, "closed") {
    98  				break
    99  			}
   100  			newError("failed to accepted raw connections").Base(err).AtWarning().WriteToLog()
   101  			if strings.Contains(errStr, "too many") {
   102  				time.Sleep(time.Millisecond * 500)
   103  			}
   104  			continue
   105  		}
   106  
   107  		if v.tlsConfig != nil {
   108  			conn = tls.Server(conn, v.tlsConfig)
   109  		}
   110  		if v.authConfig != nil {
   111  			conn = v.authConfig.Server(conn)
   112  		}
   113  
   114  		v.addConn(internet.Connection(conn))
   115  	}
   116  }
   117  
   118  // Addr implements internet.Listener.Addr.
   119  func (v *Listener) Addr() net.Addr {
   120  	return v.listener.Addr()
   121  }
   122  
   123  // Close implements internet.Listener.Close.
   124  func (v *Listener) Close() error {
   125  	if v.locker != nil {
   126  		v.locker.Release()
   127  	}
   128  	return v.listener.Close()
   129  }
   130  
   131  func init() {
   132  	common.Must(internet.RegisterTransportListener(protocolName, ListenTCP))
   133  }