github.com/binbinly/pkg@v0.0.11-0.20240321014439-f4fbf666eb0f/transport/ws/ws.go (about) 1 package ws 2 3 import ( 4 "context" 5 "errors" 6 "log" 7 "net" 8 "net/http" 9 "net/url" 10 "strconv" 11 12 "github.com/binbinly/pkg/logger" 13 "github.com/binbinly/pkg/util" 14 "github.com/gorilla/websocket" 15 "github.com/rs/xid" 16 "github.com/zhenjl/cityhash" 17 ) 18 19 var ( 20 // ErrConnNotFound 连接未找到 21 ErrConnNotFound = errors.New("connection not found") 22 // ErrConnNotFinish 连接未完成,不可以发送消息 23 ErrConnNotFinish = errors.New("connection not finish when send msg") 24 ) 25 26 type ConnHandlerFunc func(cid uint64, conn Connection) error 27 28 // Server is a simple micro server abstraction 29 type Server interface { 30 // Init Initialise options 31 Init(...Option) 32 // Options Retrieve the options 33 Options() *Options 34 // Start the server 35 Start(ctx context.Context) error 36 // Stop the server 37 Stop(ctx context.Context) error 38 // Endpoint return a real address to registry endpoint. 39 Endpoint() (*url.URL, error) 40 // GetManager 所有连接管理 41 GetManager(cid uint64) *Manager 42 // Range 遍历所有连接 43 Range(f ConnHandlerFunc) 44 // Total 服务器连接总数 45 Total() int 46 } 47 48 // wsServer 基础服务 49 type wsServer struct { 50 managers []*Manager 51 handler *Handler 52 opts *Options 53 lis net.Listener 54 endpoint *url.URL 55 upgrader *websocket.Upgrader 56 } 57 58 // NewServer 实例化websocket服务器 59 func NewServer() Server { 60 return &wsServer{ 61 opts: defOptions, 62 } 63 } 64 65 // Options 服务选项 66 func (s *wsServer) Options() *Options { 67 return s.opts 68 } 69 70 // Init 初始化 71 func (s *wsServer) Init(opts ...Option) { 72 for _, o := range opts { 73 o(s.opts) 74 } 75 if s.opts.ID == "" { 76 s.opts.ID = xid.New().String() 77 } 78 //初始化连接管理器 79 s.managers = make([]*Manager, s.opts.ManagerSize) 80 for i := 0; i < s.opts.ManagerSize; i++ { 81 s.managers[i] = NewManager() 82 } 83 //初始化消息处理器 84 s.handler = NewHandler(s.opts.WorkerPoolSize, s.opts.Router) 85 s.upgrader = &websocket.Upgrader{ 86 ReadBufferSize: s.opts.ReadBufferSize, 87 WriteBufferSize: s.opts.WriteBufferSize, 88 CheckOrigin: func(r *http.Request) bool { return true }, 89 } 90 } 91 92 // Start 启动服务器 93 func (s *wsServer) Start(ctx context.Context) error { 94 // 启动worker工作池机制 95 s.handler.Init(s.opts.MaxWorkerTaskLen) 96 return s.Listen() 97 } 98 99 // Stop 关闭服务器 100 func (s *wsServer) Stop(ctx context.Context) error { 101 log.Print("[Websocket] server is stopping") 102 103 // 先关闭监听新连接,再关闭当前所有连接 104 err := s.lis.Close() 105 for _, manager := range s.managers { 106 manager.Clear() 107 } 108 109 return err 110 } 111 112 // Listen websocket连接监听 113 func (s *wsServer) Listen() error { 114 var cid uint64 = 1 115 lis, err := net.Listen("tcp", s.opts.Addr) 116 if err != nil { 117 return err 118 } 119 s.lis = lis 120 121 if _, err = s.Endpoint(); err != nil { 122 return err 123 } 124 http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 125 //设置服务器最大连接控制,如果超过最大连接,则拒绝 126 if s.Total() >= s.Options().MaxConn { 127 logger.Warn("[ws.start] connection size limit") 128 return 129 } 130 // 如果需要 websocket 认证请设置认证信息 131 uid := 0 132 if s.Options().OnConnAuth != nil { 133 var ok bool 134 if uid, ok = s.Options().OnConnAuth(r, s.opts.ID, cid); !ok { 135 w.WriteHeader(401) 136 return 137 } 138 } 139 // 判断 header 里面是有子协议 140 if len(r.Header.Get("Sec-Websocket-Protocol")) > 0 { 141 s.upgrader.Subprotocols = websocket.Subprotocols(r) 142 } 143 // 升级成 websocket 连接 144 c, err := s.upgrader.Upgrade(w, r, nil) 145 if err != nil { 146 w.WriteHeader(500) 147 return 148 } 149 150 conn := NewConnect(s, c, cid, uid) 151 // 添加连接至管理器 152 s.GetManager(cid).Add(conn) 153 conn.Start() 154 cid++ 155 }) 156 log.Printf("[Websocket] server is listening on: %s", lis.Addr().String()) 157 if err = http.Serve(lis, nil); !errors.Is(err, http.ErrServerClosed) { 158 return err 159 } 160 161 return nil 162 } 163 164 // Endpoint return a real address to registry endpoint. 165 // examples: http://127.0.0.1:8080 166 func (s *wsServer) Endpoint() (*url.URL, error) { 167 addr, err := util.Extract(s.opts.Addr, s.lis) 168 if err != nil { 169 return nil, err 170 } 171 s.endpoint = &url.URL{Scheme: "http", Host: addr} 172 return s.endpoint, nil 173 } 174 175 // GetManager 获取当前连接的管理器 176 func (s *wsServer) GetManager(cid uint64) *Manager { 177 str := strconv.FormatUint(cid, 10) 178 idx := cityhash.CityHash32([]byte(str), uint32(len(str))) % uint32(s.opts.ManagerSize) 179 return s.managers[idx] 180 } 181 182 // Range 遍历所有连接 183 func (s *wsServer) Range(f ConnHandlerFunc) { 184 for _, manager := range s.managers { 185 _ = manager.Range(f) 186 } 187 } 188 189 // Total 当前服务器的总连接数 190 func (s *wsServer) Total() int { 191 var c int 192 for _, manager := range s.managers { 193 c += manager.Len() 194 } 195 return c 196 }