github.com/EagleQL/Xray-core@v1.4.3/transport/internet/http/hub.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"golang.org/x/net/http2"
    11  	"golang.org/x/net/http2/h2c"
    12  
    13  	"github.com/xtls/xray-core/common"
    14  	"github.com/xtls/xray-core/common/net"
    15  	"github.com/xtls/xray-core/common/net/cnc"
    16  	http_proto "github.com/xtls/xray-core/common/protocol/http"
    17  	"github.com/xtls/xray-core/common/serial"
    18  	"github.com/xtls/xray-core/common/session"
    19  	"github.com/xtls/xray-core/common/signal/done"
    20  	"github.com/xtls/xray-core/transport/internet"
    21  	"github.com/xtls/xray-core/transport/internet/tls"
    22  )
    23  
    24  type Listener struct {
    25  	server  *http.Server
    26  	handler internet.ConnHandler
    27  	local   net.Addr
    28  	config  *Config
    29  	locker  *internet.FileLocker // for unix domain socket
    30  }
    31  
    32  func (l *Listener) Addr() net.Addr {
    33  	return l.local
    34  }
    35  
    36  func (l *Listener) Close() error {
    37  	if l.locker != nil {
    38  		l.locker.Release()
    39  	}
    40  	return l.server.Close()
    41  }
    42  
    43  type flushWriter struct {
    44  	w io.Writer
    45  	d *done.Instance
    46  }
    47  
    48  func (fw flushWriter) Write(p []byte) (n int, err error) {
    49  	if fw.d.Done() {
    50  		return 0, io.ErrClosedPipe
    51  	}
    52  
    53  	n, err = fw.w.Write(p)
    54  	if f, ok := fw.w.(http.Flusher); ok && err == nil {
    55  		f.Flush()
    56  	}
    57  	return
    58  }
    59  
    60  func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    61  	host := request.Host
    62  	if !l.config.isValidHost(host) {
    63  		writer.WriteHeader(404)
    64  		return
    65  	}
    66  	path := l.config.getNormalizedPath()
    67  	if !strings.HasPrefix(request.URL.Path, path) {
    68  		writer.WriteHeader(404)
    69  		return
    70  	}
    71  
    72  	writer.Header().Set("Cache-Control", "no-store")
    73  	writer.WriteHeader(200)
    74  	if f, ok := writer.(http.Flusher); ok {
    75  		f.Flush()
    76  	}
    77  
    78  	remoteAddr := l.Addr()
    79  	dest, err := net.ParseDestination(request.RemoteAddr)
    80  	if err != nil {
    81  		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
    82  	} else {
    83  		remoteAddr = &net.TCPAddr{
    84  			IP:   dest.Address.IP(),
    85  			Port: int(dest.Port),
    86  		}
    87  	}
    88  
    89  	forwardedAddress := http_proto.ParseXForwardedFor(request.Header)
    90  	if len(forwardedAddress) > 0 && forwardedAddress[0].Family().IsIP() {
    91  		remoteAddr = &net.TCPAddr{
    92  			IP:   forwardedAddress[0].IP(),
    93  			Port: 0,
    94  		}
    95  	}
    96  
    97  	done := done.New()
    98  	conn := cnc.NewConnection(
    99  		cnc.ConnectionOutput(request.Body),
   100  		cnc.ConnectionInput(flushWriter{w: writer, d: done}),
   101  		cnc.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
   102  		cnc.ConnectionLocalAddr(l.Addr()),
   103  		cnc.ConnectionRemoteAddr(remoteAddr),
   104  	)
   105  	l.handler(conn)
   106  	<-done.Wait()
   107  }
   108  
   109  func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
   110  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   111  	var listener *Listener
   112  	if port == net.Port(0) { // unix
   113  		listener = &Listener{
   114  			handler: handler,
   115  			local: &net.UnixAddr{
   116  				Name: address.Domain(),
   117  				Net:  "unix",
   118  			},
   119  			config: httpSettings,
   120  		}
   121  	} else { // tcp
   122  		listener = &Listener{
   123  			handler: handler,
   124  			local: &net.TCPAddr{
   125  				IP:   address.IP(),
   126  				Port: int(port),
   127  			},
   128  			config: httpSettings,
   129  		}
   130  	}
   131  
   132  	var server *http.Server
   133  	config := tls.ConfigFromStreamSettings(streamSettings)
   134  	if config == nil {
   135  		h2s := &http2.Server{}
   136  
   137  		server = &http.Server{
   138  			Addr:              serial.Concat(address, ":", port),
   139  			Handler:           h2c.NewHandler(listener, h2s),
   140  			ReadHeaderTimeout: time.Second * 4,
   141  		}
   142  	} else {
   143  		server = &http.Server{
   144  			Addr:              serial.Concat(address, ":", port),
   145  			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
   146  			Handler:           listener,
   147  			ReadHeaderTimeout: time.Second * 4,
   148  		}
   149  	}
   150  
   151  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
   152  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
   153  	}
   154  
   155  	listener.server = server
   156  	go func() {
   157  		var streamListener net.Listener
   158  		var err error
   159  		if port == net.Port(0) { // unix
   160  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
   161  				Name: address.Domain(),
   162  				Net:  "unix",
   163  			}, streamSettings.SocketSettings)
   164  			if err != nil {
   165  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   166  				return
   167  			}
   168  			locker := ctx.Value(address.Domain())
   169  			if locker != nil {
   170  				listener.locker = locker.(*internet.FileLocker)
   171  			}
   172  		} else { // tcp
   173  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   174  				IP:   address.IP(),
   175  				Port: int(port),
   176  			}, streamSettings.SocketSettings)
   177  			if err != nil {
   178  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   179  				return
   180  			}
   181  		}
   182  
   183  		if config == nil {
   184  			err = server.Serve(streamListener)
   185  			if err != nil {
   186  				newError("stopping serving H2C").Base(err).WriteToLog(session.ExportIDToError(ctx))
   187  			}
   188  		} else {
   189  			err = server.ServeTLS(streamListener, "", "")
   190  			if err != nil {
   191  				newError("stopping serving TLS").Base(err).WriteToLog(session.ExportIDToError(ctx))
   192  			}
   193  		}
   194  	}()
   195  
   196  	return listener, nil
   197  }
   198  
   199  func init() {
   200  	common.Must(internet.RegisterTransportListener(protocolName, Listen))
   201  }