github.com/xraypb/xray-core@v1.6.6/transport/internet/http/hub.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/xraypb/xray-core/common"
    11  	"github.com/xraypb/xray-core/common/net"
    12  	"github.com/xraypb/xray-core/common/net/cnc"
    13  	http_proto "github.com/xraypb/xray-core/common/protocol/http"
    14  	"github.com/xraypb/xray-core/common/serial"
    15  	"github.com/xraypb/xray-core/common/session"
    16  	"github.com/xraypb/xray-core/common/signal/done"
    17  	"github.com/xraypb/xray-core/transport/internet"
    18  	"github.com/xraypb/xray-core/transport/internet/tls"
    19  	"golang.org/x/net/http2"
    20  	"golang.org/x/net/http2/h2c"
    21  )
    22  
    23  type Listener struct {
    24  	server  *http.Server
    25  	handler internet.ConnHandler
    26  	local   net.Addr
    27  	config  *Config
    28  	locker  *internet.FileLocker // for unix domain socket
    29  }
    30  
    31  func (l *Listener) Addr() net.Addr {
    32  	return l.local
    33  }
    34  
    35  func (l *Listener) Close() error {
    36  	if l.locker != nil {
    37  		l.locker.Release()
    38  	}
    39  	return l.server.Close()
    40  }
    41  
    42  type flushWriter struct {
    43  	w io.Writer
    44  	d *done.Instance
    45  }
    46  
    47  func (fw flushWriter) Write(p []byte) (n int, err error) {
    48  	if fw.d.Done() {
    49  		return 0, io.ErrClosedPipe
    50  	}
    51  
    52  	n, err = fw.w.Write(p)
    53  	if f, ok := fw.w.(http.Flusher); ok && err == nil {
    54  		f.Flush()
    55  	}
    56  	return
    57  }
    58  
    59  func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    60  	host := request.Host
    61  	if !l.config.isValidHost(host) {
    62  		writer.WriteHeader(404)
    63  		return
    64  	}
    65  	path := l.config.getNormalizedPath()
    66  	if !strings.HasPrefix(request.URL.Path, path) {
    67  		writer.WriteHeader(404)
    68  		return
    69  	}
    70  
    71  	writer.Header().Set("Cache-Control", "no-store")
    72  
    73  	for _, httpHeader := range l.config.Header {
    74  		for _, httpHeaderValue := range httpHeader.Value {
    75  			writer.Header().Set(httpHeader.Name, httpHeaderValue)
    76  		}
    77  	}
    78  
    79  	writer.WriteHeader(200)
    80  	if f, ok := writer.(http.Flusher); ok {
    81  		f.Flush()
    82  	}
    83  
    84  	remoteAddr := l.Addr()
    85  	dest, err := net.ParseDestination(request.RemoteAddr)
    86  	if err != nil {
    87  		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
    88  	} else {
    89  		remoteAddr = &net.TCPAddr{
    90  			IP:   dest.Address.IP(),
    91  			Port: int(dest.Port),
    92  		}
    93  	}
    94  
    95  	forwardedAddress := http_proto.ParseXForwardedFor(request.Header)
    96  	if len(forwardedAddress) > 0 && forwardedAddress[0].Family().IsIP() {
    97  		remoteAddr = &net.TCPAddr{
    98  			IP:   forwardedAddress[0].IP(),
    99  			Port: 0,
   100  		}
   101  	}
   102  
   103  	done := done.New()
   104  	conn := cnc.NewConnection(
   105  		cnc.ConnectionOutput(request.Body),
   106  		cnc.ConnectionInput(flushWriter{w: writer, d: done}),
   107  		cnc.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
   108  		cnc.ConnectionLocalAddr(l.Addr()),
   109  		cnc.ConnectionRemoteAddr(remoteAddr),
   110  	)
   111  	l.handler(conn)
   112  	<-done.Wait()
   113  }
   114  
   115  func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
   116  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   117  	var listener *Listener
   118  	if port == net.Port(0) { // unix
   119  		listener = &Listener{
   120  			handler: handler,
   121  			local: &net.UnixAddr{
   122  				Name: address.Domain(),
   123  				Net:  "unix",
   124  			},
   125  			config: httpSettings,
   126  		}
   127  	} else { // tcp
   128  		listener = &Listener{
   129  			handler: handler,
   130  			local: &net.TCPAddr{
   131  				IP:   address.IP(),
   132  				Port: int(port),
   133  			},
   134  			config: httpSettings,
   135  		}
   136  	}
   137  
   138  	var server *http.Server
   139  	config := tls.ConfigFromStreamSettings(streamSettings)
   140  	if config == nil {
   141  		h2s := &http2.Server{}
   142  
   143  		server = &http.Server{
   144  			Addr:              serial.Concat(address, ":", port),
   145  			Handler:           h2c.NewHandler(listener, h2s),
   146  			ReadHeaderTimeout: time.Second * 4,
   147  		}
   148  	} else {
   149  		server = &http.Server{
   150  			Addr:              serial.Concat(address, ":", port),
   151  			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
   152  			Handler:           listener,
   153  			ReadHeaderTimeout: time.Second * 4,
   154  		}
   155  	}
   156  
   157  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
   158  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
   159  	}
   160  
   161  	listener.server = server
   162  	go func() {
   163  		var streamListener net.Listener
   164  		var err error
   165  		if port == net.Port(0) { // unix
   166  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
   167  				Name: address.Domain(),
   168  				Net:  "unix",
   169  			}, streamSettings.SocketSettings)
   170  			if err != nil {
   171  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   172  				return
   173  			}
   174  			locker := ctx.Value(address.Domain())
   175  			if locker != nil {
   176  				listener.locker = locker.(*internet.FileLocker)
   177  			}
   178  		} else { // tcp
   179  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   180  				IP:   address.IP(),
   181  				Port: int(port),
   182  			}, streamSettings.SocketSettings)
   183  			if err != nil {
   184  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   185  				return
   186  			}
   187  		}
   188  
   189  		if config == nil {
   190  			err = server.Serve(streamListener)
   191  			if err != nil {
   192  				newError("stopping serving H2C").Base(err).WriteToLog(session.ExportIDToError(ctx))
   193  			}
   194  		} else {
   195  			err = server.ServeTLS(streamListener, "", "")
   196  			if err != nil {
   197  				newError("stopping serving TLS").Base(err).WriteToLog(session.ExportIDToError(ctx))
   198  			}
   199  		}
   200  	}()
   201  
   202  	return listener, nil
   203  }
   204  
   205  func init() {
   206  	common.Must(internet.RegisterTransportListener(protocolName, Listen))
   207  }