github.com/binbinly/pkg@v0.0.11-0.20240321014439-f4fbf666eb0f/transport/ws/ws.go (about)

     1  package ws
     2  
import (
	"context"
	"errors"
	"log"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"sync/atomic"

	"github.com/binbinly/pkg/logger"
	"github.com/binbinly/pkg/util"
	"github.com/gorilla/websocket"
	"github.com/rs/xid"
	"github.com/zhenjl/cityhash"
)
    18  
var (
	// ErrConnNotFound reports that the requested connection could not be found.
	ErrConnNotFound = errors.New("connection not found")
	// ErrConnNotFinish reports that the connection is not fully established
	// yet, so no message can be sent on it.
	ErrConnNotFinish = errors.New("connection not finish when send msg")
)
    25  
// ConnHandlerFunc is a callback that receives a connection id together with
// the connection itself. Server.Range takes one as its visitor argument.
type ConnHandlerFunc func(cid uint64, conn Connection) error
    27  
// Server is a simple micro server abstraction
type Server interface {
	// Init initialises the server with the given options.
	Init(...Option)
	// Options retrieves the server's options.
	Options() *Options
	// Start starts the server.
	Start(ctx context.Context) error
	// Stop stops the server.
	Stop(ctx context.Context) error
	// Endpoint return a real address to registry endpoint.
	Endpoint() (*url.URL, error)
	// GetManager returns the connection manager responsible for the given
	// connection id.
	GetManager(cid uint64) *Manager
	// Range visits the server's connections with f.
	Range(f ConnHandlerFunc)
	// Total returns the total number of connections held by the server.
	Total() int
}
    47  
// wsServer is the base websocket server implementation of Server.
type wsServer struct {
	// managers holds the connection managers created in Init; GetManager
	// selects one by hashing the connection id.
	managers []*Manager
	// handler is the message handler created in Init and initialised in Start.
	handler  *Handler
	// opts is the option set returned by Options and modified by Init.
	opts     *Options
	// lis is the TCP listener opened by Listen and closed by Stop.
	lis      net.Listener
	// endpoint stores the URL most recently computed by Endpoint.
	endpoint *url.URL
	// upgrader performs the HTTP-to-websocket upgrade; it is built in Init.
	upgrader *websocket.Upgrader
}
    57  
    58  // NewServer 实例化websocket服务器
    59  func NewServer() Server {
    60  	return &wsServer{
    61  		opts: defOptions,
    62  	}
    63  }
    64  
// Options returns the server's option set. The returned pointer is the one
// the server itself uses, not a copy.
func (s *wsServer) Options() *Options {
	return s.opts
}
    69  
    70  // Init 初始化
    71  func (s *wsServer) Init(opts ...Option) {
    72  	for _, o := range opts {
    73  		o(s.opts)
    74  	}
    75  	if s.opts.ID == "" {
    76  		s.opts.ID = xid.New().String()
    77  	}
    78  	//初始化连接管理器
    79  	s.managers = make([]*Manager, s.opts.ManagerSize)
    80  	for i := 0; i < s.opts.ManagerSize; i++ {
    81  		s.managers[i] = NewManager()
    82  	}
    83  	//初始化消息处理器
    84  	s.handler = NewHandler(s.opts.WorkerPoolSize, s.opts.Router)
    85  	s.upgrader = &websocket.Upgrader{
    86  		ReadBufferSize:  s.opts.ReadBufferSize,
    87  		WriteBufferSize: s.opts.WriteBufferSize,
    88  		CheckOrigin:     func(r *http.Request) bool { return true },
    89  	}
    90  }
    91  
    92  // Start 启动服务器
    93  func (s *wsServer) Start(ctx context.Context) error {
    94  	// 启动worker工作池机制
    95  	s.handler.Init(s.opts.MaxWorkerTaskLen)
    96  	return s.Listen()
    97  }
    98  
    99  // Stop 关闭服务器
   100  func (s *wsServer) Stop(ctx context.Context) error {
   101  	log.Print("[Websocket] server is stopping")
   102  
   103  	// 先关闭监听新连接,再关闭当前所有连接
   104  	err := s.lis.Close()
   105  	for _, manager := range s.managers {
   106  		manager.Clear()
   107  	}
   108  
   109  	return err
   110  }
   111  
   112  // Listen websocket连接监听
   113  func (s *wsServer) Listen() error {
   114  	var cid uint64 = 1
   115  	lis, err := net.Listen("tcp", s.opts.Addr)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	s.lis = lis
   120  
   121  	if _, err = s.Endpoint(); err != nil {
   122  		return err
   123  	}
   124  	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
   125  		//设置服务器最大连接控制,如果超过最大连接,则拒绝
   126  		if s.Total() >= s.Options().MaxConn {
   127  			logger.Warn("[ws.start] connection size limit")
   128  			return
   129  		}
   130  		// 如果需要 websocket 认证请设置认证信息
   131  		uid := 0
   132  		if s.Options().OnConnAuth != nil {
   133  			var ok bool
   134  			if uid, ok = s.Options().OnConnAuth(r, s.opts.ID, cid); !ok {
   135  				w.WriteHeader(401)
   136  				return
   137  			}
   138  		}
   139  		// 判断 header 里面是有子协议
   140  		if len(r.Header.Get("Sec-Websocket-Protocol")) > 0 {
   141  			s.upgrader.Subprotocols = websocket.Subprotocols(r)
   142  		}
   143  		// 升级成 websocket 连接
   144  		c, err := s.upgrader.Upgrade(w, r, nil)
   145  		if err != nil {
   146  			w.WriteHeader(500)
   147  			return
   148  		}
   149  
   150  		conn := NewConnect(s, c, cid, uid)
   151  		// 添加连接至管理器
   152  		s.GetManager(cid).Add(conn)
   153  		conn.Start()
   154  		cid++
   155  	})
   156  	log.Printf("[Websocket] server is listening on: %s", lis.Addr().String())
   157  	if err = http.Serve(lis, nil); !errors.Is(err, http.ErrServerClosed) {
   158  		return err
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  // Endpoint return a real address to registry endpoint.
   165  // examples: http://127.0.0.1:8080
   166  func (s *wsServer) Endpoint() (*url.URL, error) {
   167  	addr, err := util.Extract(s.opts.Addr, s.lis)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	s.endpoint = &url.URL{Scheme: "http", Host: addr}
   172  	return s.endpoint, nil
   173  }
   174  
   175  // GetManager 获取当前连接的管理器
   176  func (s *wsServer) GetManager(cid uint64) *Manager {
   177  	str := strconv.FormatUint(cid, 10)
   178  	idx := cityhash.CityHash32([]byte(str), uint32(len(str))) % uint32(s.opts.ManagerSize)
   179  	return s.managers[idx]
   180  }
   181  
   182  // Range 遍历所有连接
   183  func (s *wsServer) Range(f ConnHandlerFunc) {
   184  	for _, manager := range s.managers {
   185  		_ = manager.Range(f)
   186  	}
   187  }
   188  
   189  // Total 当前服务器的总连接数
   190  func (s *wsServer) Total() int {
   191  	var c int
   192  	for _, manager := range s.managers {
   193  		c += manager.Len()
   194  	}
   195  	return c
   196  }