github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/transport/internet/websocket/hub.go (about)

     1  // +build !confonly
     2  
     3  package websocket
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pires/go-proxyproto"
    14  
    15  	"v2ray.com/core/common"
    16  	"v2ray.com/core/common/net"
    17  	http_proto "v2ray.com/core/common/protocol/http"
    18  	"v2ray.com/core/common/session"
    19  	"v2ray.com/core/transport/internet"
    20  	v2tls "v2ray.com/core/transport/internet/tls"
    21  )
    22  
// requestHandler is the http.Handler installed on a Listener's HTTP server.
// It upgrades requests for a single path to WebSocket connections and hands
// them to the owning Listener.
type requestHandler struct {
	// path is the only URL path that is served; any other path gets a 404.
	path string
	// ln is the Listener whose addConn callback receives upgraded connections.
	ln   *Listener
}
    27  
    28  var upgrader = &websocket.Upgrader{
    29  	ReadBufferSize:   4 * 1024,
    30  	WriteBufferSize:  4 * 1024,
    31  	HandshakeTimeout: time.Second * 4,
    32  	CheckOrigin: func(r *http.Request) bool {
    33  		return true
    34  	},
    35  }
    36  
    37  func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    38  	if request.URL.Path != h.path {
    39  		writer.WriteHeader(http.StatusNotFound)
    40  		return
    41  	}
    42  	conn, err := upgrader.Upgrade(writer, request, nil)
    43  	if err != nil {
    44  		newError("failed to convert to WebSocket connection").Base(err).WriteToLog()
    45  		return
    46  	}
    47  
    48  	forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
    49  	remoteAddr := conn.RemoteAddr()
    50  	if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
    51  		remoteAddr.(*net.TCPAddr).IP = forwardedAddrs[0].IP()
    52  	}
    53  
    54  	h.ln.addConn(newConnection(conn, remoteAddr))
    55  }
    56  
// Listener is the WebSocket transport listener returned by ListenWS. It holds
// the HTTP server that performs the WebSocket upgrades together with the
// network listener that server accepts connections from.
type Listener struct {
	// NOTE(review): the embedded mutex is not locked anywhere in the visible
	// code; confirm whether other files rely on it.
	sync.Mutex
	// server is the HTTP server whose handler upgrades requests to WebSocket.
	server   http.Server
	// listener is the (possibly PROXY-protocol and/or TLS wrapped) socket
	// that server serves on.
	listener net.Listener
	// config is the WebSocket settings taken from the stream settings.
	config   *Config
	// addConn receives every successfully upgraded connection.
	addConn  internet.ConnHandler
}
    64  
    65  func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
    66  	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
    67  		IP:   address.IP(),
    68  		Port: int(port),
    69  	}, streamSettings.SocketSettings)
    70  	if err != nil {
    71  		return nil, newError("failed to listen TCP(for WS) on", address, ":", port).Base(err)
    72  	}
    73  	newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
    74  
    75  	wsSettings := streamSettings.ProtocolSettings.(*Config)
    76  
    77  	if wsSettings.AcceptProxyProtocol {
    78  		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
    79  		listener = &proxyproto.Listener{Listener: listener, Policy: policyFunc}
    80  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
    81  	}
    82  
    83  	if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
    84  		if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
    85  			listener = tls.NewListener(listener, tlsConfig)
    86  		}
    87  	}
    88  
    89  	l := &Listener{
    90  		config:   wsSettings,
    91  		addConn:  addConn,
    92  		listener: listener,
    93  	}
    94  
    95  	l.server = http.Server{
    96  		Handler: &requestHandler{
    97  			path: wsSettings.GetNormalizedPath(),
    98  			ln:   l,
    99  		},
   100  		ReadHeaderTimeout: time.Second * 4,
   101  		MaxHeaderBytes:    2048,
   102  	}
   103  
   104  	go func() {
   105  		if err := l.server.Serve(l.listener); err != nil {
   106  			newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   107  		}
   108  	}()
   109  
   110  	return l, err
   111  }
   112  
   113  // Addr implements net.Listener.Addr().
   114  func (ln *Listener) Addr() net.Addr {
   115  	return ln.listener.Addr()
   116  }
   117  
   118  // Close implements net.Listener.Close().
   119  func (ln *Listener) Close() error {
   120  	return ln.listener.Close()
   121  }
   122  
   123  func init() {
   124  	common.Must(internet.RegisterTransportListener(protocolName, ListenWS))
   125  }