github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raywebsocket/server.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"strings"
    10  
    11  	"github.com/sagernet/sing-box/adapter"
    12  	"github.com/sagernet/sing-box/common/tls"
    13  	C "github.com/sagernet/sing-box/constant"
    14  	"github.com/sagernet/sing-box/option"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  	aTLS "github.com/sagernet/sing/common/tls"
    22  	sHttp "github.com/sagernet/sing/protocol/http"
    23  	"github.com/sagernet/ws"
    24  )
    25  
    26  var _ adapter.V2RayServerTransport = (*Server)(nil)
    27  
    28  type Server struct {
    29  	ctx                 context.Context
    30  	tlsConfig           tls.ServerConfig
    31  	handler             adapter.V2RayServerTransportHandler
    32  	httpServer          *http.Server
    33  	path                string
    34  	maxEarlyData        uint32
    35  	earlyDataHeaderName string
    36  	upgrader            ws.HTTPUpgrader
    37  }
    38  
    39  func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    40  	server := &Server{
    41  		ctx:                 ctx,
    42  		tlsConfig:           tlsConfig,
    43  		handler:             handler,
    44  		path:                options.Path,
    45  		maxEarlyData:        options.MaxEarlyData,
    46  		earlyDataHeaderName: options.EarlyDataHeaderName,
    47  		upgrader: ws.HTTPUpgrader{
    48  			Timeout: C.TCPTimeout,
    49  			Header:  options.Headers.Build(),
    50  		},
    51  	}
    52  	if !strings.HasPrefix(server.path, "/") {
    53  		server.path = "/" + server.path
    54  	}
    55  	server.httpServer = &http.Server{
    56  		Handler:           server,
    57  		ReadHeaderTimeout: C.TCPTimeout,
    58  		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
    59  		BaseContext: func(net.Listener) context.Context {
    60  			return ctx
    61  		},
    62  	}
    63  	return server, nil
    64  }
    65  
    66  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    67  	if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" {
    68  		if request.URL.Path != s.path {
    69  			s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    70  			return
    71  		}
    72  	}
    73  	var (
    74  		earlyData []byte
    75  		err       error
    76  		conn      net.Conn
    77  	)
    78  	if s.earlyDataHeaderName == "" {
    79  		if strings.HasPrefix(request.URL.RequestURI(), s.path) {
    80  			earlyDataStr := request.URL.RequestURI()[len(s.path):]
    81  			earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
    82  		} else {
    83  			s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    84  			return
    85  		}
    86  	} else {
    87  		if request.URL.Path != s.path {
    88  			s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    89  			return
    90  		}
    91  		earlyDataStr := request.Header.Get(s.earlyDataHeaderName)
    92  		if earlyDataStr != "" {
    93  			earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr)
    94  		}
    95  	}
    96  	if err != nil {
    97  		s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
    98  		return
    99  	}
   100  	wsConn, _, _, err := ws.UpgradeHTTP(request, writer)
   101  	if err != nil {
   102  		s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection"))
   103  		return
   104  	}
   105  	var metadata M.Metadata
   106  	metadata.Source = sHttp.SourceAddress(request)
   107  	conn = NewConn(wsConn, metadata.Source.TCPAddr(), ws.StateServerSide)
   108  	if len(earlyData) > 0 {
   109  		conn = bufio.NewCachedConn(conn, buf.As(earlyData))
   110  	}
   111  	s.handler.NewConnection(request.Context(), conn, metadata)
   112  }
   113  
   114  func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
   115  	if statusCode > 0 {
   116  		writer.WriteHeader(statusCode)
   117  	}
   118  	s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr))
   119  }
   120  
   121  func (s *Server) Network() []string {
   122  	return []string{N.NetworkTCP}
   123  }
   124  
   125  func (s *Server) Serve(listener net.Listener) error {
   126  	if s.tlsConfig != nil {
   127  		listener = aTLS.NewListener(listener, s.tlsConfig)
   128  	}
   129  	return s.httpServer.Serve(listener)
   130  }
   131  
   132  func (s *Server) ServePacket(listener net.PacketConn) error {
   133  	return os.ErrInvalid
   134  }
   135  
   136  func (s *Server) Close() error {
   137  	return common.Close(common.PtrOrNil(s.httpServer))
   138  }