github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/http/hub.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"golang.org/x/net/http2"
    11  	"golang.org/x/net/http2/h2c"
    12  
    13  	"github.com/v2fly/v2ray-core/v5/common"
    14  	"github.com/v2fly/v2ray-core/v5/common/net"
    15  	http_proto "github.com/v2fly/v2ray-core/v5/common/protocol/http"
    16  	"github.com/v2fly/v2ray-core/v5/common/serial"
    17  	"github.com/v2fly/v2ray-core/v5/common/session"
    18  	"github.com/v2fly/v2ray-core/v5/common/signal/done"
    19  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    20  	"github.com/v2fly/v2ray-core/v5/transport/internet/tls"
    21  )
    22  
    23  type Listener struct {
    24  	server  *http.Server
    25  	handler internet.ConnHandler
    26  	local   net.Addr
    27  	config  *Config
    28  }
    29  
    30  func (l *Listener) Addr() net.Addr {
    31  	return l.local
    32  }
    33  
    34  func (l *Listener) Close() error {
    35  	return l.server.Close()
    36  }
    37  
    38  type flushWriter struct {
    39  	w io.Writer
    40  	d *done.Instance
    41  }
    42  
    43  func (fw flushWriter) Write(p []byte) (n int, err error) {
    44  	if fw.d.Done() {
    45  		return 0, io.ErrClosedPipe
    46  	}
    47  
    48  	n, err = fw.w.Write(p)
    49  	if f, ok := fw.w.(http.Flusher); ok {
    50  		f.Flush()
    51  	}
    52  	return
    53  }
    54  
    55  func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    56  	host := request.Host
    57  	if len(l.config.Host) != 0 && !l.config.isValidHost(host) {
    58  		writer.WriteHeader(404)
    59  		return
    60  	}
    61  	path := l.config.getNormalizedPath()
    62  	if !strings.HasPrefix(request.URL.Path, path) {
    63  		writer.WriteHeader(404)
    64  		return
    65  	}
    66  
    67  	writer.Header().Set("Cache-Control", "no-store")
    68  
    69  	for _, httpHeader := range l.config.Header {
    70  		for _, httpHeaderValue := range httpHeader.Value {
    71  			writer.Header().Set(httpHeader.Name, httpHeaderValue)
    72  		}
    73  	}
    74  
    75  	writer.WriteHeader(200)
    76  	if f, ok := writer.(http.Flusher); ok {
    77  		f.Flush()
    78  	}
    79  
    80  	remoteAddr := l.Addr()
    81  	dest, err := net.ParseDestination(request.RemoteAddr)
    82  	if err != nil {
    83  		newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog()
    84  	} else {
    85  		remoteAddr = &net.TCPAddr{
    86  			IP:   dest.Address.IP(),
    87  			Port: int(dest.Port),
    88  		}
    89  	}
    90  
    91  	forwardedAddress := http_proto.ParseXForwardedFor(request.Header)
    92  	if len(forwardedAddress) > 0 && forwardedAddress[0].Family().IsIP() {
    93  		remoteAddr = &net.TCPAddr{
    94  			IP:   forwardedAddress[0].IP(),
    95  			Port: 0,
    96  		}
    97  	}
    98  
    99  	done := done.New()
   100  	conn := net.NewConnection(
   101  		net.ConnectionOutput(request.Body),
   102  		net.ConnectionInput(flushWriter{w: writer, d: done}),
   103  		net.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
   104  		net.ConnectionLocalAddr(l.Addr()),
   105  		net.ConnectionRemoteAddr(remoteAddr),
   106  	)
   107  	l.handler(conn)
   108  	<-done.Wait()
   109  }
   110  
   111  func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
   112  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   113  	var listener *Listener
   114  	if port == net.Port(0) { // unix
   115  		listener = &Listener{
   116  			handler: handler,
   117  			local: &net.UnixAddr{
   118  				Name: address.Domain(),
   119  				Net:  "unix",
   120  			},
   121  			config: httpSettings,
   122  		}
   123  	} else { // tcp
   124  		listener = &Listener{
   125  			handler: handler,
   126  			local: &net.TCPAddr{
   127  				IP:   address.IP(),
   128  				Port: int(port),
   129  			},
   130  			config: httpSettings,
   131  		}
   132  	}
   133  
   134  	var server *http.Server
   135  	config := tls.ConfigFromStreamSettings(streamSettings)
   136  	if config == nil {
   137  		h2s := &http2.Server{}
   138  
   139  		server = &http.Server{
   140  			Addr:              serial.Concat(address, ":", port),
   141  			Handler:           h2c.NewHandler(listener, h2s),
   142  			ReadHeaderTimeout: time.Second * 4,
   143  		}
   144  	} else {
   145  		server = &http.Server{
   146  			Addr:              serial.Concat(address, ":", port),
   147  			TLSConfig:         config.GetTLSConfig(tls.WithNextProto("h2")),
   148  			Handler:           listener,
   149  			ReadHeaderTimeout: time.Second * 4,
   150  		}
   151  	}
   152  
   153  	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
   154  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
   155  	}
   156  
   157  	listener.server = server
   158  	go func() {
   159  		var streamListener net.Listener
   160  		var err error
   161  		if port == net.Port(0) { // unix
   162  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
   163  				Name: address.Domain(),
   164  				Net:  "unix",
   165  			}, streamSettings.SocketSettings)
   166  			if err != nil {
   167  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   168  				return
   169  			}
   170  		} else { // tcp
   171  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   172  				IP:   address.IP(),
   173  				Port: int(port),
   174  			}, streamSettings.SocketSettings)
   175  			if err != nil {
   176  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   177  				return
   178  			}
   179  		}
   180  
   181  		if config == nil {
   182  			err = server.Serve(streamListener)
   183  			if err != nil {
   184  				newError("stopping serving H2C").Base(err).WriteToLog(session.ExportIDToError(ctx))
   185  			}
   186  		} else {
   187  			err = server.ServeTLS(streamListener, "", "")
   188  			if err != nil {
   189  				newError("stopping serving TLS").Base(err).WriteToLog(session.ExportIDToError(ctx))
   190  			}
   191  		}
   192  	}()
   193  
   194  	return listener, nil
   195  }
   196  
   197  func init() {
   198  	common.Must(internet.RegisterTransportListener(protocolName, Listen))
   199  }