github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raywebsocket/server.go (about)

     1  package v2raywebsocket
     2  
import (
	"context"
	"encoding/base64"
	"errors"
	"net"
	"net/http"
	"os"
	"strings"

	"github.com/inazumav/sing-box/adapter"
	"github.com/inazumav/sing-box/common/tls"
	C "github.com/inazumav/sing-box/constant"
	"github.com/inazumav/sing-box/option"
	"github.com/inazumav/sing-box/transport/v2rayhttp"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	aTLS "github.com/sagernet/sing/common/tls"
	sHttp "github.com/sagernet/sing/protocol/http"
	"github.com/sagernet/websocket"
)
    26  
// Compile-time assertion that *Server implements adapter.V2RayServerTransport.
var _ adapter.V2RayServerTransport = (*Server)(nil)

// Server is the server side of the V2Ray websocket transport. It serves HTTP
// (see ServeHTTP), upgrades accepted requests to websocket connections and
// passes them, along with fallbacks and errors, to handler.
type Server struct {
	// ctx is the context passed to NewServer.
	ctx                 context.Context
	// tlsConfig is optional; when non-nil, Serve wraps its listener with it.
	tlsConfig           tls.ServerConfig
	// handler receives upgraded connections, fallback requests and errors.
	handler             adapter.V2RayServerTransportHandler
	// httpServer is built by NewServer and dispatches requests to this Server.
	httpServer          *http.Server
	// path is the configured request path, normalized to start with "/".
	path                string
	// maxEarlyData comes from the options; zero requires an exact path match.
	maxEarlyData        uint32
	// earlyDataHeaderName names the header carrying base64 early data; when
	// empty, early data (if any) is carried in the request URI after path.
	earlyDataHeaderName string
}
    38  
    39  func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    40  	server := &Server{
    41  		ctx:                 ctx,
    42  		tlsConfig:           tlsConfig,
    43  		handler:             handler,
    44  		path:                options.Path,
    45  		maxEarlyData:        options.MaxEarlyData,
    46  		earlyDataHeaderName: options.EarlyDataHeaderName,
    47  	}
    48  	if !strings.HasPrefix(server.path, "/") {
    49  		server.path = "/" + server.path
    50  	}
    51  	server.httpServer = &http.Server{
    52  		Handler:           server,
    53  		ReadHeaderTimeout: C.TCPTimeout,
    54  		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
    55  		BaseContext: func(net.Listener) context.Context {
    56  			return ctx
    57  		},
    58  	}
    59  	return server, nil
    60  }
    61  
    62  var upgrader = websocket.Upgrader{
    63  	HandshakeTimeout: C.TCPTimeout,
    64  	CheckOrigin: func(r *http.Request) bool {
    65  		return true
    66  	},
    67  }
    68  
    69  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    70  	if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" {
    71  		if request.URL.Path != s.path {
    72  			s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    73  			return
    74  		}
    75  	}
    76  	var (
    77  		earlyData []byte
    78  		err       error
    79  		conn      net.Conn
    80  	)
    81  	if s.earlyDataHeaderName == "" {
    82  		if strings.HasPrefix(request.URL.RequestURI(), s.path) {
    83  			earlyDataStr := request.URL.RequestURI()[len(s.path):]
    84  			earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
    85  		} else {
    86  			s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    87  			return
    88  		}
    89  	} else {
    90  		earlyDataStr := request.Header.Get(s.earlyDataHeaderName)
    91  		if earlyDataStr != "" {
    92  			earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
    93  		}
    94  	}
    95  	if err != nil {
    96  		s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
    97  		return
    98  	}
    99  	wsConn, err := upgrader.Upgrade(writer, request, nil)
   100  	if err != nil {
   101  		s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "upgrade websocket connection"))
   102  		return
   103  	}
   104  	var metadata M.Metadata
   105  	metadata.Source = sHttp.SourceAddress(request)
   106  	conn = NewServerConn(wsConn, metadata.Source.TCPAddr())
   107  	if len(earlyData) > 0 {
   108  		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
   109  	}
   110  	s.handler.NewConnection(request.Context(), conn, metadata)
   111  }
   112  
   113  func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
   114  	conn := v2rayhttp.NewHTTPConn(request.Body, writer)
   115  	fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{})
   116  	if fErr == nil {
   117  		return
   118  	} else if fErr == os.ErrInvalid {
   119  		fErr = nil
   120  	}
   121  	if statusCode > 0 {
   122  		writer.WriteHeader(statusCode)
   123  	}
   124  	s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr))
   125  }
   126  
   127  func (s *Server) Network() []string {
   128  	return []string{N.NetworkTCP}
   129  }
   130  
   131  func (s *Server) Serve(listener net.Listener) error {
   132  	if s.tlsConfig != nil {
   133  		listener = aTLS.NewListener(listener, s.tlsConfig)
   134  	}
   135  	return s.httpServer.Serve(listener)
   136  }
   137  
   138  func (s *Server) ServePacket(listener net.PacketConn) error {
   139  	return os.ErrInvalid
   140  }
   141  
   142  func (s *Server) Close() error {
   143  	return common.Close(common.PtrOrNil(s.httpServer))
   144  }