github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/tcp/hub.go (about)

     1  package tcp
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"strings"
     7  	"time"
     8  
     9  	goreality "github.com/xtls/reality"
    10  	"github.com/xmplusdev/xmcore/common"
    11  	"github.com/xmplusdev/xmcore/common/net"
    12  	"github.com/xmplusdev/xmcore/common/session"
    13  	"github.com/xmplusdev/xmcore/transport/internet"
    14  	"github.com/xmplusdev/xmcore/transport/internet/reality"
    15  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    16  	"github.com/xmplusdev/xmcore/transport/internet/tls"
    17  )
    18  
    19  // Listener is an internet.Listener that listens for TCP connections.
    20  type Listener struct {
    21  	listener      net.Listener
    22  	tlsConfig     *gotls.Config
    23  	realityConfig *goreality.Config
    24  	authConfig    internet.ConnectionAuthenticator
    25  	config        *Config
    26  	addConn       internet.ConnHandler
    27  }
    28  
    29  // ListenTCP creates a new Listener based on configurations.
    30  func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
    31  	l := &Listener{
    32  		addConn: handler,
    33  	}
    34  	tcpSettings := streamSettings.ProtocolSettings.(*Config)
    35  	l.config = tcpSettings
    36  	if l.config != nil {
    37  		if streamSettings.SocketSettings == nil {
    38  			streamSettings.SocketSettings = &internet.SocketConfig{}
    39  		}
    40  		streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol
    41  	}
    42  	var listener net.Listener
    43  	var err error
    44  	if port == net.Port(0) { // unix
    45  		listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
    46  			Name: address.Domain(),
    47  			Net:  "unix",
    48  		}, streamSettings.SocketSettings)
    49  		if err != nil {
    50  			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
    51  		}
    52  		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
    53  	} else {
    54  		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
    55  			IP:   address.IP(),
    56  			Port: int(port),
    57  		}, streamSettings.SocketSettings)
    58  		if err != nil {
    59  			return nil, newError("failed to listen TCP on ", address, ":", port).Base(err)
    60  		}
    61  		newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
    62  	}
    63  
    64  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
    65  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
    66  	}
    67  
    68  	l.listener = listener
    69  
    70  	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
    71  		l.tlsConfig = config.GetTLSConfig()
    72  	}
    73  	if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
    74  		l.realityConfig = config.GetREALITYConfig()
    75  	}
    76  
    77  	if tcpSettings.HeaderSettings != nil {
    78  		headerConfig, err := tcpSettings.HeaderSettings.GetInstance()
    79  		if err != nil {
    80  			return nil, newError("invalid header settings").Base(err).AtError()
    81  		}
    82  		auth, err := internet.CreateConnectionAuthenticator(headerConfig)
    83  		if err != nil {
    84  			return nil, newError("invalid header settings.").Base(err).AtError()
    85  		}
    86  		l.authConfig = auth
    87  	}
    88  
    89  	go l.keepAccepting()
    90  	return l, nil
    91  }
    92  
    93  func (v *Listener) keepAccepting() {
    94  	for {
    95  		conn, err := v.listener.Accept()
    96  		if err != nil {
    97  			errStr := err.Error()
    98  			if strings.Contains(errStr, "closed") {
    99  				break
   100  			}
   101  			newError("failed to accepted raw connections").Base(err).AtWarning().WriteToLog()
   102  			if strings.Contains(errStr, "too many") {
   103  				time.Sleep(time.Millisecond * 500)
   104  			}
   105  			continue
   106  		}
   107  		go func() {
   108  			if v.tlsConfig != nil {
   109  				conn = tls.Server(conn, v.tlsConfig)
   110  			} else if v.realityConfig != nil {
   111  				if conn, err = reality.Server(conn, v.realityConfig); err != nil {
   112  					newError(err).AtInfo().WriteToLog()
   113  					return
   114  				}
   115  			}
   116  			if v.authConfig != nil {
   117  				conn = v.authConfig.Server(conn)
   118  			}
   119  			v.addConn(stat.Connection(conn))
   120  		}()
   121  	}
   122  }
   123  
   124  // Addr implements internet.Listener.Addr.
   125  func (v *Listener) Addr() net.Addr {
   126  	return v.listener.Addr()
   127  }
   128  
   129  // Close implements internet.Listener.Close.
   130  func (v *Listener) Close() error {
   131  	return v.listener.Close()
   132  }
   133  
   134  func init() {
   135  	common.Must(internet.RegisterTransportListener(protocolName, ListenTCP))
   136  }