github.com/codingeasygo/util@v0.0.0-20231206062002-1ce2f004b7d9/proxy/ws/ws.go (about)

     1  package ws
     2  
import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"

	"github.com/codingeasygo/util/xio"
	"github.com/codingeasygo/util/xnet"
	"golang.org/x/net/websocket"
)
    16  
// ContextKey is the type of the request-context keys used by this package.
// ServeHTTP stores the dialed upstream and its target uri under
// ContextKey("upstream") for the websocket handler to pick up.
type ContextKey string
    18  
// Server is an http.Handler that upgrades requests to websocket and forwards
// each upgraded connection to the upstream named by the "_uri" form value.
// Use NewServer to obtain a ready-to-use instance.
type Server struct {
	// Server performs the websocket handshake; NewServer sets its Handler to
	// (*Server).handler.
	*websocket.Server
	// BufferSize is passed to Dialer.DialPiper for every upstream
	// (NewServer sets 32 KiB).
	BufferSize int
	// Dialer opens the upstream piper for a _uri value.
	Dialer xio.PiperDialer
	// waiter counts running accept loops so Stop can wait for them to end.
	waiter sync.WaitGroup
	// listners maps each listener opened by Run/Start to its address.
	// NOTE(review): accessed by Run/Start/Stop without a lock, so those calls
	// must not run concurrently — confirm callers serialize them.
	listners map[net.Listener]string
}
    26  
    27  func NewServer() (server *Server) {
    28  	server = &Server{
    29  		BufferSize: 32 * 1024,
    30  		Dialer:     xio.PiperDialerF(xio.DialNetPiper),
    31  		waiter:     sync.WaitGroup{},
    32  		listners:   map[net.Listener]string{},
    33  	}
    34  	server.Server = &websocket.Server{Handler: server.handler}
    35  	return
    36  }
    37  
    38  // Run will listen tcp on address and accept to ProcConn
    39  func (s *Server) loopAccept(l net.Listener) (err error) {
    40  	defer s.waiter.Done()
    41  	http.Serve(l, s)
    42  	return
    43  }
    44  
    45  // Run will listen tcp on address and sync accept to ProcConn
    46  func (s *Server) Run(addr string) (err error) {
    47  	listener, err := net.Listen("tcp", addr)
    48  	if err == nil {
    49  		s.listners[listener] = addr
    50  		InfoLog("Server listen http proxy on %v", addr)
    51  		s.waiter.Add(1)
    52  		err = s.loopAccept(listener)
    53  	}
    54  	return
    55  }
    56  
    57  // Start will listen tcp on address and async accept to ProcConn
    58  func (s *Server) Start(network, addr string) (listener net.Listener, err error) {
    59  	listener, err = net.Listen(network, addr)
    60  	if err == nil {
    61  		s.listners[listener] = addr
    62  		InfoLog("Server listen http proxy on %v", addr)
    63  		s.waiter.Add(1)
    64  		go s.loopAccept(listener)
    65  	}
    66  	return
    67  }
    68  
    69  // Stop will stop listener and wait loop stop
    70  func (s *Server) Stop() (err error) {
    71  	for listener, addr := range s.listners {
    72  		err = listener.Close()
    73  		delete(s.listners, listener)
    74  		InfoLog("Server http proxy listener on %v is stopped by %v", addr, err)
    75  	}
    76  	s.waiter.Wait()
    77  	return
    78  }
    79  
    80  func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    81  	req.ParseForm()
    82  	uri := req.Form.Get("_uri")
    83  	if len(uri) < 1 {
    84  		w.WriteHeader(http.StatusBadRequest)
    85  		fmt.Fprintf(w, "_uri is required")
    86  		return
    87  	}
    88  	raw, err := s.Dialer.DialPiper(uri, s.BufferSize)
    89  	if err != nil {
    90  		InfoLog("Server dial to %v fail with %v", uri, err)
    91  		w.WriteHeader(http.StatusBadGateway)
    92  		fmt.Fprintf(w, "%v", err)
    93  		return
    94  	}
    95  	newReq := context.WithValue(req.Context(), ContextKey("upstream"), []interface{}{raw, uri})
    96  	s.Server.ServeHTTP(w, req.WithContext(newReq))
    97  }
    98  
    99  func (s *Server) handler(conn *websocket.Conn) {
   100  	defer conn.Close()
   101  	req := conn.Request()
   102  	upstream := req.Context().Value(ContextKey("upstream")).([]interface{})
   103  	raw, uri := upstream[0].(xio.Piper), upstream[1].(string)
   104  	DebugLog("Server start forward %v to %v", req.RemoteAddr, uri)
   105  	err := raw.PipeConn(conn, uri)
   106  	DebugLog("Server forward %v to %v is done with %v", req.RemoteAddr, uri, err)
   107  }
   108  
   109  // Dial will dial connection by proxy server
   110  func Dial(proxy, uri string) (conn net.Conn, err error) {
   111  	dialer := xnet.NewWebsocketDialer()
   112  	targetURI := proxy
   113  	if strings.Contains(proxy, "?") {
   114  		targetURI += fmt.Sprintf("&_uri=%v", url.QueryEscape(uri))
   115  	} else {
   116  		targetURI += fmt.Sprintf("?_uri=%v", url.QueryEscape(uri))
   117  	}
   118  	raw, err := dialer.Dial(targetURI)
   119  	if err == nil {
   120  		conn = raw.(net.Conn)
   121  	}
   122  	return
   123  }