github.com/ipfans/trojan-go@v0.11.0/tunnel/websocket/server.go (about) 1 package websocket 2 3 import ( 4 "bufio" 5 "context" 6 "math/rand" 7 "net" 8 "net/http" 9 "strings" 10 "time" 11 12 "golang.org/x/net/websocket" 13 14 "github.com/ipfans/trojan-go/common" 15 "github.com/ipfans/trojan-go/config" 16 "github.com/ipfans/trojan-go/log" 17 "github.com/ipfans/trojan-go/redirector" 18 "github.com/ipfans/trojan-go/tunnel" 19 ) 20 21 // Fake response writer 22 // Websocket ServeHTTP method uses Hijack method to get the ReadWriter 23 type fakeHTTPResponseWriter struct { 24 http.Hijacker 25 http.ResponseWriter 26 27 ReadWriter *bufio.ReadWriter 28 Conn net.Conn 29 } 30 31 func (w *fakeHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 32 return w.Conn, w.ReadWriter, nil 33 } 34 35 type Server struct { 36 underlay tunnel.Server 37 hostname string 38 path string 39 enabled bool 40 redirAddr net.Addr 41 redir *redirector.Redirector 42 ctx context.Context 43 cancel context.CancelFunc 44 timeout time.Duration 45 } 46 47 func (s *Server) Close() error { 48 s.cancel() 49 return s.underlay.Close() 50 } 51 52 func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) { 53 conn, err := s.underlay.AcceptConn(&Tunnel{}) 54 if err != nil { 55 return nil, common.NewError("websocket failed to accept connection from underlying server") 56 } 57 if !s.enabled { 58 s.redir.Redirect(&redirector.Redirection{ 59 InboundConn: conn, 60 RedirectTo: s.redirAddr, 61 }) 62 return nil, common.NewError("websocket is disabled. redirecting http request from " + conn.RemoteAddr().String()) 63 } 64 rewindConn := common.NewRewindConn(conn) 65 rewindConn.SetBufferSize(512) 66 defer rewindConn.StopBuffering() 67 rw := bufio.NewReadWriter(bufio.NewReader(rewindConn), bufio.NewWriter(rewindConn)) 68 req, err := http.ReadRequest(rw.Reader) 69 if err != nil { 70 log.Debug("invalid http request") 71 rewindConn.Rewind() 72 rewindConn.StopBuffering() 73 s.redir.Redirect(&redirector.Redirection{ 74 InboundConn: rewindConn, 75 RedirectTo: s.redirAddr, 76 }) 77 return nil, common.NewError("not a valid http request: " + conn.RemoteAddr().String()).Base(err) 78 } 79 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || req.URL.Path != s.path { 80 log.Debug("invalid http websocket handshake request") 81 rewindConn.Rewind() 82 rewindConn.StopBuffering() 83 s.redir.Redirect(&redirector.Redirection{ 84 InboundConn: rewindConn, 85 RedirectTo: s.redirAddr, 86 }) 87 return nil, common.NewError("not a valid websocket handshake request: " + conn.RemoteAddr().String()).Base(err) 88 } 89 90 handshake := make(chan struct{}) 91 92 url := "wss://" + s.hostname + s.path 93 origin := "https://" + s.hostname 94 wsConfig, err := websocket.NewConfig(url, origin) 95 if err != nil { 96 return nil, common.NewError("failed to create websocket config").Base(err) 97 } 98 var wsConn *websocket.Conn 99 ctx, cancel := context.WithCancel(s.ctx) 100 101 wsServer := websocket.Server{ 102 Config: *wsConfig, 103 Handler: func(conn *websocket.Conn) { 104 wsConn = conn // store the websocket after handshaking 105 wsConn.PayloadType = websocket.BinaryFrame // treat it as a binary websocket 106 107 log.Debug("websocket obtained") 108 handshake <- struct{}{} 109 // this function SHOULD NOT return unless the connection is ended 110 // or the websocket will be closed by ServeHTTP method 111 <-ctx.Done() 112 log.Debug("websocket closed") 113 }, 114 Handshake: func(wsConfig *websocket.Config, httpRequest *http.Request) error { 115 log.Debug("websocket url", httpRequest.URL, "origin", httpRequest.Header.Get("Origin")) 116 return nil 117 }, 118 } 119 120 respWriter := &fakeHTTPResponseWriter{ 121 Conn: conn, 122 ReadWriter: rw, 123 } 124 go wsServer.ServeHTTP(respWriter, req) 125 126 select { 127 case <-handshake: 128 case <-time.After(s.timeout): 129 } 130 131 if wsConn == nil { 132 cancel() 133 return nil, common.NewError("websocket failed to handshake") 134 } 135 136 return &InboundConn{ 137 OutboundConn: OutboundConn{ 138 tcpConn: conn, 139 Conn: wsConn, 140 }, 141 ctx: ctx, 142 cancel: cancel, 143 }, nil 144 } 145 146 func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { 147 return nil, common.NewError("not supported") 148 } 149 150 func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) { 151 cfg := config.FromContext(ctx, Name).(*Config) 152 if cfg.Websocket.Enabled { 153 if !strings.HasPrefix(cfg.Websocket.Path, "/") { 154 return nil, common.NewError("websocket path must start with \"/\"") 155 } 156 } 157 if cfg.RemoteHost == "" { 158 log.Warn("empty websocket redirection hostname") 159 cfg.RemoteHost = cfg.Websocket.Host 160 } 161 if cfg.RemotePort == 0 { 162 log.Warn("empty websocket redirection port") 163 cfg.RemotePort = 80 164 } 165 ctx, cancel := context.WithCancel(ctx) 166 log.Debug("websocket server created") 167 return &Server{ 168 enabled: cfg.Websocket.Enabled, 169 hostname: cfg.Websocket.Host, 170 path: cfg.Websocket.Path, 171 ctx: ctx, 172 cancel: cancel, 173 underlay: underlay, 174 timeout: time.Second * time.Duration(rand.Intn(10)+5), 175 redir: redirector.NewRedirector(ctx), 176 redirAddr: tunnel.NewAddressFromHostPort("tcp", cfg.RemoteHost, cfg.RemotePort), 177 }, nil 178 }