github.com/codingeasygo/util@v0.0.0-20231206062002-1ce2f004b7d9/proxy/ws/ws.go (about) 1 package ws 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "net/http" 8 "net/url" 9 "strings" 10 "sync" 11 12 "github.com/codingeasygo/util/xio" 13 "github.com/codingeasygo/util/xnet" 14 "golang.org/x/net/websocket" 15 ) 16 17 type ContextKey string 18 19 type Server struct { 20 *websocket.Server 21 BufferSize int 22 Dialer xio.PiperDialer 23 waiter sync.WaitGroup 24 listners map[net.Listener]string 25 } 26 27 func NewServer() (server *Server) { 28 server = &Server{ 29 BufferSize: 32 * 1024, 30 Dialer: xio.PiperDialerF(xio.DialNetPiper), 31 waiter: sync.WaitGroup{}, 32 listners: map[net.Listener]string{}, 33 } 34 server.Server = &websocket.Server{Handler: server.handler} 35 return 36 } 37 38 // Run will listen tcp on address and accept to ProcConn 39 func (s *Server) loopAccept(l net.Listener) (err error) { 40 defer s.waiter.Done() 41 http.Serve(l, s) 42 return 43 } 44 45 // Run will listen tcp on address and sync accept to ProcConn 46 func (s *Server) Run(addr string) (err error) { 47 listener, err := net.Listen("tcp", addr) 48 if err == nil { 49 s.listners[listener] = addr 50 InfoLog("Server listen http proxy on %v", addr) 51 s.waiter.Add(1) 52 err = s.loopAccept(listener) 53 } 54 return 55 } 56 57 // Start will listen tcp on address and async accept to ProcConn 58 func (s *Server) Start(network, addr string) (listener net.Listener, err error) { 59 listener, err = net.Listen(network, addr) 60 if err == nil { 61 s.listners[listener] = addr 62 InfoLog("Server listen http proxy on %v", addr) 63 s.waiter.Add(1) 64 go s.loopAccept(listener) 65 } 66 return 67 } 68 69 // Stop will stop listener and wait loop stop 70 func (s *Server) Stop() (err error) { 71 for listener, addr := range s.listners { 72 err = listener.Close() 73 delete(s.listners, listener) 74 InfoLog("Server http proxy listener on %v is stopped by %v", addr, err) 75 } 76 s.waiter.Wait() 77 return 78 } 79 80 func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { 81 req.ParseForm() 82 uri := req.Form.Get("_uri") 83 if len(uri) < 1 { 84 w.WriteHeader(http.StatusBadRequest) 85 fmt.Fprintf(w, "_uri is required") 86 return 87 } 88 raw, err := s.Dialer.DialPiper(uri, s.BufferSize) 89 if err != nil { 90 InfoLog("Server dial to %v fail with %v", uri, err) 91 w.WriteHeader(http.StatusBadGateway) 92 fmt.Fprintf(w, "%v", err) 93 return 94 } 95 newReq := context.WithValue(req.Context(), ContextKey("upstream"), []interface{}{raw, uri}) 96 s.Server.ServeHTTP(w, req.WithContext(newReq)) 97 } 98 99 func (s *Server) handler(conn *websocket.Conn) { 100 defer conn.Close() 101 req := conn.Request() 102 upstream := req.Context().Value(ContextKey("upstream")).([]interface{}) 103 raw, uri := upstream[0].(xio.Piper), upstream[1].(string) 104 DebugLog("Server start forward %v to %v", req.RemoteAddr, uri) 105 err := raw.PipeConn(conn, uri) 106 DebugLog("Server forward %v to %v is done with %v", req.RemoteAddr, uri, err) 107 } 108 109 // Dial will dial connection by proxy server 110 func Dial(proxy, uri string) (conn net.Conn, err error) { 111 dialer := xnet.NewWebsocketDialer() 112 targetURI := proxy 113 if strings.Contains(proxy, "?") { 114 targetURI += fmt.Sprintf("&_uri=%v", url.QueryEscape(uri)) 115 } else { 116 targetURI += fmt.Sprintf("?_uri=%v", url.QueryEscape(uri)) 117 } 118 raw, err := dialer.Dial(targetURI) 119 if err == nil { 120 conn = raw.(net.Conn) 121 } 122 return 123 }