github.com/xraypb/Xray-core@v1.8.1/transport/internet/http/hub.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/xraypb/Xray-core/common"
    11  	"github.com/xraypb/Xray-core/common/net"
    12  	"github.com/xraypb/Xray-core/common/net/cnc"
    13  	http_proto "github.com/xraypb/Xray-core/common/protocol/http"
    14  	"github.com/xraypb/Xray-core/common/serial"
    15  	"github.com/xraypb/Xray-core/common/session"
    16  	"github.com/xraypb/Xray-core/common/signal/done"
    17  	"github.com/xraypb/Xray-core/transport/internet"
    18  	"github.com/xraypb/Xray-core/transport/internet/reality"
    19  	"github.com/xraypb/Xray-core/transport/internet/tls"
    20  	goreality "github.com/xtls/reality"
    21  	"golang.org/x/net/http2"
    22  	"golang.org/x/net/http2/h2c"
    23  )
    24  
    25  type Listener struct {
    26  	server  *http.Server
    27  	handler internet.ConnHandler
    28  	local   net.Addr
    29  	config  *Config
    30  	locker  *internet.FileLocker // for unix domain socket
    31  }
    32  
    33  func (l *Listener) Addr() net.Addr {
    34  	return l.local
    35  }
    36  
    37  func (l *Listener) Close() error {
    38  	if l.locker != nil {
    39  		l.locker.Release()
    40  	}
    41  	return l.server.Close()
    42  }
    43  
    44  type flushWriter struct {
    45  	w io.Writer
    46  	d *done.Instance
    47  }
    48  
    49  func (fw flushWriter) Write(p []byte) (n int, err error) {
    50  	if fw.d.Done() {
    51  		return 0, io.ErrClosedPipe
    52  	}
    53  
    54  	defer func() {
    55  		if recover() != nil {
    56  			fw.d.Close()
    57  			err = io.ErrClosedPipe
    58  		}
    59  	}()
    60  
    61  	n, err = fw.w.Write(p)
    62  	if f, ok := fw.w.(http.Flusher); ok && err == nil {
    63  		f.Flush()
    64  	}
    65  	return
    66  }
    67  
    68  func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    69  	host := request.Host
    70  	if !l.config.isValidHost(host) {
    71  		writer.WriteHeader(404)
    72  		return
    73  	}
    74  	path := l.config.getNormalizedPath()
    75  	if !strings.HasPrefix(request.URL.Path, path) {
    76  		writer.WriteHeader(404)
    77  		return
    78  	}
    79  
    80  	writer.Header().Set("Cache-Control", "no-store")
    81  
    82  	for _, httpHeader := range l.config.Header {
    83  		for _, httpHeaderValue := range httpHeader.Value {
    84  			writer.Header().Set(httpHeader.Name, httpHeaderValue)
    85  		}
    86  	}
    87  
    88  	writer.WriteHeader(200)
    89  	if f, ok := writer.(http.Flusher); ok {
    90  		f.Flush()
    91  	}
    92  
    93  	remoteAddr := l.Addr()
    94  	dest, err := net.ParseDestination(request.RemoteAddr)
    95  	if err != nil {
    96  		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
    97  	} else {
    98  		remoteAddr = &net.TCPAddr{
    99  			IP:   dest.Address.IP(),
   100  			Port: int(dest.Port),
   101  		}
   102  	}
   103  
   104  	forwardedAddress := http_proto.ParseXForwardedFor(request.Header)
   105  	if len(forwardedAddress) > 0 && forwardedAddress[0].Family().IsIP() {
   106  		remoteAddr = &net.TCPAddr{
   107  			IP:   forwardedAddress[0].IP(),
   108  			Port: 0,
   109  		}
   110  	}
   111  
   112  	done := done.New()
   113  	conn := cnc.NewConnection(
   114  		cnc.ConnectionOutput(request.Body),
   115  		cnc.ConnectionInput(flushWriter{w: writer, d: done}),
   116  		cnc.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
   117  		cnc.ConnectionLocalAddr(l.Addr()),
   118  		cnc.ConnectionRemoteAddr(remoteAddr),
   119  	)
   120  	l.handler(conn)
   121  	<-done.Wait()
   122  }
   123  
   124  func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
   125  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   126  	var listener *Listener
   127  	if port == net.Port(0) { // unix
   128  		listener = &Listener{
   129  			handler: handler,
   130  			local: &net.UnixAddr{
   131  				Name: address.Domain(),
   132  				Net:  "unix",
   133  			},
   134  			config: httpSettings,
   135  		}
   136  	} else { // tcp
   137  		listener = &Listener{
   138  			handler: handler,
   139  			local: &net.TCPAddr{
   140  				IP:   address.IP(),
   141  				Port: int(port),
   142  			},
   143  			config: httpSettings,
   144  		}
   145  	}
   146  
   147  	var server *http.Server
   148  	config := tls.ConfigFromStreamSettings(streamSettings)
   149  	if config == nil {
   150  		h2s := &http2.Server{}
   151  
   152  		server = &http.Server{
   153  			Addr:              serial.Concat(address, ":", port),
   154  			Handler:           h2c.NewHandler(listener, h2s),
   155  			ReadHeaderTimeout: time.Second * 4,
   156  		}
   157  	} else {
   158  		server = &http.Server{
   159  			Addr:              serial.Concat(address, ":", port),
   160  			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
   161  			Handler:           listener,
   162  			ReadHeaderTimeout: time.Second * 4,
   163  		}
   164  	}
   165  
   166  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
   167  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
   168  	}
   169  
   170  	listener.server = server
   171  	go func() {
   172  		var streamListener net.Listener
   173  		var err error
   174  		if port == net.Port(0) { // unix
   175  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
   176  				Name: address.Domain(),
   177  				Net:  "unix",
   178  			}, streamSettings.SocketSettings)
   179  			if err != nil {
   180  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   181  				return
   182  			}
   183  			locker := ctx.Value(address.Domain())
   184  			if locker != nil {
   185  				listener.locker = locker.(*internet.FileLocker)
   186  			}
   187  		} else { // tcp
   188  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   189  				IP:   address.IP(),
   190  				Port: int(port),
   191  			}, streamSettings.SocketSettings)
   192  			if err != nil {
   193  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   194  				return
   195  			}
   196  		}
   197  
   198  		if config == nil {
   199  			if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
   200  				streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig())
   201  			}
   202  			err = server.Serve(streamListener)
   203  			if err != nil {
   204  				newError("stopping serving H2C or REALITY H2").Base(err).WriteToLog(session.ExportIDToError(ctx))
   205  			}
   206  		} else {
   207  			err = server.ServeTLS(streamListener, "", "")
   208  			if err != nil {
   209  				newError("stopping serving TLS H2").Base(err).WriteToLog(session.ExportIDToError(ctx))
   210  			}
   211  		}
   212  	}()
   213  
   214  	return listener, nil
   215  }
   216  
   217  func init() {
   218  	common.Must(internet.RegisterTransportListener(protocolName, Listen))
   219  }