github.com/xmplusdev/xray-core@v1.8.10/transport/internet/http/hub.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	goreality "github.com/xtls/reality"
    11  	"github.com/xmplusdev/xray-core/common"
    12  	"github.com/xmplusdev/xray-core/common/net"
    13  	"github.com/xmplusdev/xray-core/common/net/cnc"
    14  	http_proto "github.com/xmplusdev/xray-core/common/protocol/http"
    15  	"github.com/xmplusdev/xray-core/common/serial"
    16  	"github.com/xmplusdev/xray-core/common/session"
    17  	"github.com/xmplusdev/xray-core/common/signal/done"
    18  	"github.com/xmplusdev/xray-core/transport/internet"
    19  	"github.com/xmplusdev/xray-core/transport/internet/reality"
    20  	"github.com/xmplusdev/xray-core/transport/internet/tls"
    21  	"golang.org/x/net/http2"
    22  	"golang.org/x/net/http2/h2c"
    23  )
    24  
// Listener is the server side of the HTTP transport. It acts as the
// http.Handler of its own http.Server (see ServeHTTP) and passes every
// accepted request to handler as a connection.
type Listener struct {
	// server is the HTTP server driving this listener; Close shuts it down.
	server *http.Server
	// handler receives the connection built for each accepted request.
	handler internet.ConnHandler
	// local is the address returned by Addr.
	local net.Addr
	// config holds the transport settings consulted by ServeHTTP
	// (host validation, path prefix, extra response headers).
	config *Config
}
    31  
// Addr returns the local address stored on the listener when it was created.
func (l *Listener) Addr() net.Addr {
	return l.local
}
    35  
// Close closes the underlying http.Server and returns the error it reports.
func (l *Listener) Close() error {
	return l.server.Close()
}
    39  
// flushWriter is an io.Writer wrapper used for the response side of a
// tunnelled connection. See Write for the flushing and closing behaviour.
type flushWriter struct {
	// w is the destination writer (an http.ResponseWriter in ServeHTTP).
	w io.Writer
	// d is checked before every write and closed if a write panics.
	d *done.Instance
}
    44  
    45  func (fw flushWriter) Write(p []byte) (n int, err error) {
    46  	if fw.d.Done() {
    47  		return 0, io.ErrClosedPipe
    48  	}
    49  
    50  	defer func() {
    51  		if recover() != nil {
    52  			fw.d.Close()
    53  			err = io.ErrClosedPipe
    54  		}
    55  	}()
    56  
    57  	n, err = fw.w.Write(p)
    58  	if f, ok := fw.w.(http.Flusher); ok && err == nil {
    59  		f.Flush()
    60  	}
    61  	return
    62  }
    63  
    64  func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    65  	host := request.Host
    66  	if !l.config.isValidHost(host) {
    67  		writer.WriteHeader(404)
    68  		return
    69  	}
    70  	path := l.config.getNormalizedPath()
    71  	if !strings.HasPrefix(request.URL.Path, path) {
    72  		writer.WriteHeader(404)
    73  		return
    74  	}
    75  
    76  	writer.Header().Set("Cache-Control", "no-store")
    77  
    78  	for _, httpHeader := range l.config.Header {
    79  		for _, httpHeaderValue := range httpHeader.Value {
    80  			writer.Header().Set(httpHeader.Name, httpHeaderValue)
    81  		}
    82  	}
    83  
    84  	writer.WriteHeader(200)
    85  	if f, ok := writer.(http.Flusher); ok {
    86  		f.Flush()
    87  	}
    88  
    89  	remoteAddr := l.Addr()
    90  	dest, err := net.ParseDestination(request.RemoteAddr)
    91  	if err != nil {
    92  		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
    93  	} else {
    94  		remoteAddr = &net.TCPAddr{
    95  			IP:   dest.Address.IP(),
    96  			Port: int(dest.Port),
    97  		}
    98  	}
    99  
   100  	forwardedAddress := http_proto.ParseXForwardedFor(request.Header)
   101  	if len(forwardedAddress) > 0 && forwardedAddress[0].Family().IsIP() {
   102  		remoteAddr = &net.TCPAddr{
   103  			IP:   forwardedAddress[0].IP(),
   104  			Port: 0,
   105  		}
   106  	}
   107  
   108  	done := done.New()
   109  	conn := cnc.NewConnection(
   110  		cnc.ConnectionOutput(request.Body),
   111  		cnc.ConnectionInput(flushWriter{w: writer, d: done}),
   112  		cnc.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
   113  		cnc.ConnectionLocalAddr(l.Addr()),
   114  		cnc.ConnectionRemoteAddr(remoteAddr),
   115  	)
   116  	l.handler(conn)
   117  	<-done.Wait()
   118  }
   119  
   120  func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
   121  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   122  	var listener *Listener
   123  	if port == net.Port(0) { // unix
   124  		listener = &Listener{
   125  			handler: handler,
   126  			local: &net.UnixAddr{
   127  				Name: address.Domain(),
   128  				Net:  "unix",
   129  			},
   130  			config: httpSettings,
   131  		}
   132  	} else { // tcp
   133  		listener = &Listener{
   134  			handler: handler,
   135  			local: &net.TCPAddr{
   136  				IP:   address.IP(),
   137  				Port: int(port),
   138  			},
   139  			config: httpSettings,
   140  		}
   141  	}
   142  
   143  	var server *http.Server
   144  	config := tls.ConfigFromStreamSettings(streamSettings)
   145  	if config == nil {
   146  		h2s := &http2.Server{}
   147  
   148  		server = &http.Server{
   149  			Addr:              serial.Concat(address, ":", port),
   150  			Handler:           h2c.NewHandler(listener, h2s),
   151  			ReadHeaderTimeout: time.Second * 4,
   152  		}
   153  	} else {
   154  		server = &http.Server{
   155  			Addr:              serial.Concat(address, ":", port),
   156  			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
   157  			Handler:           listener,
   158  			ReadHeaderTimeout: time.Second * 4,
   159  		}
   160  	}
   161  
   162  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
   163  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
   164  	}
   165  
   166  	listener.server = server
   167  	go func() {
   168  		var streamListener net.Listener
   169  		var err error
   170  		if port == net.Port(0) { // unix
   171  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
   172  				Name: address.Domain(),
   173  				Net:  "unix",
   174  			}, streamSettings.SocketSettings)
   175  			if err != nil {
   176  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   177  				return
   178  			}
   179  		} else { // tcp
   180  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   181  				IP:   address.IP(),
   182  				Port: int(port),
   183  			}, streamSettings.SocketSettings)
   184  			if err != nil {
   185  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   186  				return
   187  			}
   188  		}
   189  
   190  		if config == nil {
   191  			if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
   192  				streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())
   193  			}
   194  			err = server.Serve(streamListener)
   195  			if err != nil {
   196  				newError("stopping serving H2C or REALITY H2").Base(err).WriteToLog(session.ExportIDToError(ctx))
   197  			}
   198  		} else {
   199  			err = server.ServeTLS(streamListener, "", "")
   200  			if err != nil {
   201  				newError("stopping serving TLS H2").Base(err).WriteToLog(session.ExportIDToError(ctx))
   202  			}
   203  		}
   204  	}()
   205  
   206  	return listener, nil
   207  }
   208  
// init registers Listen as the transport listener for protocolName. The
// registration result is handed to common.Must (presumably aborting startup
// when registration fails — confirm against common.Must).
func init() {
	common.Must(internet.RegisterTransportListener(protocolName, Listen))
}