github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/server.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "errors" 7 "log/slog" 8 "net" 9 "net/http" 10 11 "github.com/Asutorufa/yuhaiin/pkg/log" 12 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 13 websocket "github.com/Asutorufa/yuhaiin/pkg/net/proxy/websocket/x" 14 "github.com/Asutorufa/yuhaiin/pkg/protos/config/listener" 15 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 16 ) 17 18 type Server struct { 19 net.Listener 20 server *http.Server 21 connChan chan net.Conn 22 closeCtx context.Context 23 close context.CancelFunc 24 } 25 26 func init() { 27 listener.RegisterTransport(NewServer) 28 } 29 30 func NewServer(c *listener.Transport_Websocket) func(netapi.Listener) (netapi.Listener, error) { 31 return func(ii netapi.Listener) (netapi.Listener, error) { 32 lis, err := ii.Stream(context.TODO()) 33 if err != nil { 34 return nil, err 35 } 36 return netapi.PatchStream(newServer(lis), ii), nil 37 } 38 } 39 40 func newServer(lis net.Listener) *Server { 41 ctx, cancel := context.WithCancel(context.Background()) 42 s := &Server{ 43 Listener: lis, 44 connChan: make(chan net.Conn, 20), 45 closeCtx: ctx, 46 close: cancel, 47 } 48 s.server = &http.Server{Handler: s} 49 50 go func() { 51 defer s.Close() 52 log.IfErr("websocket serve", func() error { return s.server.Serve(lis) }) 53 }() 54 55 return s 56 } 57 58 func (s *Server) Close() error { 59 var err error 60 s.close() 61 err = s.server.Close() 62 if er := s.Listener.Close(); er != nil { 63 err = errors.Join(err, er) 64 } 65 66 return err 67 } 68 69 func (s *Server) Accept() (net.Conn, error) { 70 select { 71 case conn := <-s.connChan: 72 return conn, nil 73 case <-s.closeCtx.Done(): 74 return nil, net.ErrClosed 75 } 76 } 77 78 func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { 79 var earlyData []*pool.Bytes 80 wsconn, err := websocket.NewServerConn(w, req, func(r *websocket.Request) error { 81 if r.Request.Header.Get("early_data") == "base64" { 82 83 buf := pool.GetBytesBuffer(base64.RawStdEncoding.DecodedLen(len(r.SecWebSocketKey))) 84 n, err := base64.RawStdEncoding.Decode(buf.Bytes(), []byte(r.SecWebSocketKey)) 85 if err != nil { 86 return err 87 } 88 89 buf.Refactor(0, n) 90 91 earlyData = append(earlyData, buf) 92 93 r.Header = http.Header{} 94 r.Header.Add("early_data", "true") 95 } 96 97 return nil 98 }) 99 if err != nil { 100 log.Error("new websocket server conn failed", slog.Any("from", req.RemoteAddr), slog.Any("err", err)) 101 return 102 } 103 104 select { 105 case <-s.closeCtx.Done(): 106 _ = wsconn.Close() 107 case s.connChan <- netapi.NewPrefixBytesConn(wsconn, earlyData...): 108 } 109 }