github.com/ipfans/trojan-go@v0.11.0/tunnel/websocket/server.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"math/rand"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"golang.org/x/net/websocket"
    13  
    14  	"github.com/ipfans/trojan-go/common"
    15  	"github.com/ipfans/trojan-go/config"
    16  	"github.com/ipfans/trojan-go/log"
    17  	"github.com/ipfans/trojan-go/redirector"
    18  	"github.com/ipfans/trojan-go/tunnel"
    19  )
    20  
    21  // Fake response writer
    22  // Websocket ServeHTTP method uses Hijack method to get the ReadWriter
    23  type fakeHTTPResponseWriter struct {
    24  	http.Hijacker
    25  	http.ResponseWriter
    26  
    27  	ReadWriter *bufio.ReadWriter
    28  	Conn       net.Conn
    29  }
    30  
    31  func (w *fakeHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
    32  	return w.Conn, w.ReadWriter, nil
    33  }
    34  
    35  type Server struct {
    36  	underlay  tunnel.Server
    37  	hostname  string
    38  	path      string
    39  	enabled   bool
    40  	redirAddr net.Addr
    41  	redir     *redirector.Redirector
    42  	ctx       context.Context
    43  	cancel    context.CancelFunc
    44  	timeout   time.Duration
    45  }
    46  
    47  func (s *Server) Close() error {
    48  	s.cancel()
    49  	return s.underlay.Close()
    50  }
    51  
    52  func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) {
    53  	conn, err := s.underlay.AcceptConn(&Tunnel{})
    54  	if err != nil {
    55  		return nil, common.NewError("websocket failed to accept connection from underlying server")
    56  	}
    57  	if !s.enabled {
    58  		s.redir.Redirect(&redirector.Redirection{
    59  			InboundConn: conn,
    60  			RedirectTo:  s.redirAddr,
    61  		})
    62  		return nil, common.NewError("websocket is disabled. redirecting http request from " + conn.RemoteAddr().String())
    63  	}
    64  	rewindConn := common.NewRewindConn(conn)
    65  	rewindConn.SetBufferSize(512)
    66  	defer rewindConn.StopBuffering()
    67  	rw := bufio.NewReadWriter(bufio.NewReader(rewindConn), bufio.NewWriter(rewindConn))
    68  	req, err := http.ReadRequest(rw.Reader)
    69  	if err != nil {
    70  		log.Debug("invalid http request")
    71  		rewindConn.Rewind()
    72  		rewindConn.StopBuffering()
    73  		s.redir.Redirect(&redirector.Redirection{
    74  			InboundConn: rewindConn,
    75  			RedirectTo:  s.redirAddr,
    76  		})
    77  		return nil, common.NewError("not a valid http request: " + conn.RemoteAddr().String()).Base(err)
    78  	}
    79  	if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || req.URL.Path != s.path {
    80  		log.Debug("invalid http websocket handshake request")
    81  		rewindConn.Rewind()
    82  		rewindConn.StopBuffering()
    83  		s.redir.Redirect(&redirector.Redirection{
    84  			InboundConn: rewindConn,
    85  			RedirectTo:  s.redirAddr,
    86  		})
    87  		return nil, common.NewError("not a valid websocket handshake request: " + conn.RemoteAddr().String()).Base(err)
    88  	}
    89  
    90  	handshake := make(chan struct{})
    91  
    92  	url := "wss://" + s.hostname + s.path
    93  	origin := "https://" + s.hostname
    94  	wsConfig, err := websocket.NewConfig(url, origin)
    95  	if err != nil {
    96  		return nil, common.NewError("failed to create websocket config").Base(err)
    97  	}
    98  	var wsConn *websocket.Conn
    99  	ctx, cancel := context.WithCancel(s.ctx)
   100  
   101  	wsServer := websocket.Server{
   102  		Config: *wsConfig,
   103  		Handler: func(conn *websocket.Conn) {
   104  			wsConn = conn                              // store the websocket after handshaking
   105  			wsConn.PayloadType = websocket.BinaryFrame // treat it as a binary websocket
   106  
   107  			log.Debug("websocket obtained")
   108  			handshake <- struct{}{}
   109  			// this function SHOULD NOT return unless the connection is ended
   110  			// or the websocket will be closed by ServeHTTP method
   111  			<-ctx.Done()
   112  			log.Debug("websocket closed")
   113  		},
   114  		Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error {
   115  			log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin"))
   116  			return nil
   117  		},
   118  	}
   119  
   120  	respWriter := &fakeHTTPResponseWriter{
   121  		Conn:       conn,
   122  		ReadWriter: rw,
   123  	}
   124  	go wsServer.ServeHTTP(respWriter, req)
   125  
   126  	select {
   127  	case <-handshake:
   128  	case <-time.After(s.timeout):
   129  	}
   130  
   131  	if wsConn == nil {
   132  		cancel()
   133  		return nil, common.NewError("websocket failed to handshake")
   134  	}
   135  
   136  	return &InboundConn{
   137  		OutboundConn: OutboundConn{
   138  			tcpConn: conn,
   139  			Conn:    wsConn,
   140  		},
   141  		ctx:    ctx,
   142  		cancel: cancel,
   143  	}, nil
   144  }
   145  
   146  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   147  	return nil, common.NewError("not supported")
   148  }
   149  
   150  func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) {
   151  	cfg := config.FromContext(ctx, Name).(*Config)
   152  	if cfg.Websocket.Enabled {
   153  		if !strings.HasPrefix(cfg.Websocket.Path, "/") {
   154  			return nil, common.NewError("websocket path must start with \"/\"")
   155  		}
   156  	}
   157  	if cfg.RemoteHost == "" {
   158  		log.Warn("empty websocket redirection hostname")
   159  		cfg.RemoteHost = cfg.Websocket.Host
   160  	}
   161  	if cfg.RemotePort == 0 {
   162  		log.Warn("empty websocket redirection port")
   163  		cfg.RemotePort = 80
   164  	}
   165  	ctx, cancel := context.WithCancel(ctx)
   166  	log.Debug("websocket server created")
   167  	return &Server{
   168  		enabled:   cfg.Websocket.Enabled,
   169  		hostname:  cfg.Websocket.Host,
   170  		path:      cfg.Websocket.Path,
   171  		ctx:       ctx,
   172  		cancel:    cancel,
   173  		underlay:  underlay,
   174  		timeout:   time.Second * time.Duration(rand.Intn(10)+5),
   175  		redir:     redirector.NewRedirector(ctx),
   176  		redirAddr: tunnel.NewAddressFromHostPort("tcp", cfg.RemoteHost, cfg.RemotePort),
   177  	}, nil
   178  }