github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/server.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"errors"
     7  	"log/slog"
     8  	"net"
     9  	"net/http"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/log"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    13  	websocket "github.com/Asutorufa/yuhaiin/pkg/net/proxy/websocket/x"
    14  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    15  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    16  )
    17  
    18  type Server struct {
    19  	net.Listener
    20  	server   *http.Server
    21  	connChan chan net.Conn
    22  	closeCtx context.Context
    23  	close    context.CancelFunc
    24  }
    25  
    26  func init() {
    27  	listener.RegisterTransport(NewServer)
    28  }
    29  
    30  func NewServer(c *listener.Transport_Websocket) func(netapi.Listener) (netapi.Listener, error) {
    31  	return func(ii netapi.Listener) (netapi.Listener, error) {
    32  		lis, err := ii.Stream(context.TODO())
    33  		if err != nil {
    34  			return nil, err
    35  		}
    36  		return netapi.PatchStream(newServer(lis), ii), nil
    37  	}
    38  }
    39  
    40  func newServer(lis net.Listener) *Server {
    41  	ctx, cancel := context.WithCancel(context.Background())
    42  	s := &Server{
    43  		Listener: lis,
    44  		connChan: make(chan net.Conn, 20),
    45  		closeCtx: ctx,
    46  		close:    cancel,
    47  	}
    48  	s.server = &http.Server{Handler: s}
    49  
    50  	go func() {
    51  		defer s.Close()
    52  		log.IfErr("websocket serve", func() error { return s.server.Serve(lis) })
    53  	}()
    54  
    55  	return s
    56  }
    57  
    58  func (s *Server) Close() error {
    59  	var err error
    60  	s.close()
    61  	err = s.server.Close()
    62  	if er := s.Listener.Close(); er != nil {
    63  		err = errors.Join(err, er)
    64  	}
    65  
    66  	return err
    67  }
    68  
    69  func (s *Server) Accept() (net.Conn, error) {
    70  	select {
    71  	case conn := <-s.connChan:
    72  		return conn, nil
    73  	case <-s.closeCtx.Done():
    74  		return nil, net.ErrClosed
    75  	}
    76  }
    77  
    78  func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    79  	var earlyData []*pool.Bytes
    80  	wsconn, err := websocket.NewServerConn(w, req, func(r *websocket.Request) error {
    81  		if r.Request.Header.Get("early_data") == "base64" {
    82  
    83  			buf := pool.GetBytesBuffer(base64.RawStdEncoding.DecodedLen(len(r.SecWebSocketKey)))
    84  			n, err := base64.RawStdEncoding.Decode(buf.Bytes(), []byte(r.SecWebSocketKey))
    85  			if err != nil {
    86  				return err
    87  			}
    88  
    89  			buf.Refactor(0, n)
    90  
    91  			earlyData = append(earlyData, buf)
    92  
    93  			r.Header = http.Header{}
    94  			r.Header.Add("early_data", "true")
    95  		}
    96  
    97  		return nil
    98  	})
    99  	if err != nil {
   100  		log.Error("new websocket server conn failed", slog.Any("from", req.RemoteAddr), slog.Any("err", err))
   101  		return
   102  	}
   103  
   104  	select {
   105  	case <-s.closeCtx.Done():
   106  		_ = wsconn.Close()
   107  	case s.connChan <- netapi.NewPrefixBytesConn(wsconn, earlyData...):
   108  	}
   109  }