github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/common/transport/nbio_websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"errors"
     5  	"github.com/lesismal/nbio/nbhttp"
     6  	"github.com/lesismal/nbio/nbhttp/websocket"
     7  	"net/http"
     8  	"net/url"
     9  	"sync/atomic"
    10  	"time"
    11  	"unsafe"
    12  )
    13  
    14  const (
    15  	wsUrl = "/LittleRpc-WebSocket"
    16  )
    17  
    18  // NBioWebSocketEngine 不设置错误处理回调函数则采用默认回调
    19  // 默认函数中遇到错误就会panic,所以不期望panic的话一定要设置错误处理回调
    20  type NBioWebSocketEngine struct {
    21  	started  int32
    22  	closed   int32
    23  	wsEngine *nbhttp.Engine
    24  	onMsg    func(conn ConnAdapter, bytes []byte)
    25  	onClose  func(conn ConnAdapter, err error)
    26  	onOpen   func(conn ConnAdapter)
    27  	onErr    func(err error)
    28  }
    29  
    30  func NewNBioWebsocketClient() ClientBuilder {
    31  	return &NBioWebSocketEngine{
    32  		wsEngine: nbhttp.NewEngine(nbhttp.Config{
    33  			NPoller: 1,
    34  			NParser: 1,
    35  		}),
    36  		onMsg: func(conn ConnAdapter, bytes []byte) {
    37  			return
    38  		},
    39  		onOpen: func(conn ConnAdapter) {
    40  			return
    41  		},
    42  		onClose: func(conn ConnAdapter, err error) {
    43  			return
    44  		},
    45  		onErr: func(err error) {
    46  			return
    47  		},
    48  	}
    49  }
    50  
    51  func NewNBioWebsocketServer(config NetworkServerConfig) ServerBuilder {
    52  	nConfig := nbhttp.Config{}
    53  	nConfig.Name = "LittleRpc-Server-WebSocket"
    54  	nConfig.Network = "tcp"
    55  	nConfig.ReleaseWebsocketPayload = true
    56  	nConfig.ReadBufferSize = ReadBufferSize
    57  	nConfig.MaxWriteBufferSize = MaxWriteBufferSize
    58  	nConfig.Addrs = config.Addrs
    59  	server := &NBioWebSocketEngine{wsEngine: nbhttp.NewEngine(nConfig)}
    60  	// set default function
    61  	server.onErr = func(err error) {
    62  		panic(interface{}(err))
    63  	}
    64  	server.onOpen = func(conn ConnAdapter) {
    65  		return
    66  	}
    67  	server.onMsg = func(conn ConnAdapter, bytes []byte) {
    68  		return
    69  	}
    70  	server.onClose = func(conn ConnAdapter, err error) {
    71  		return
    72  	}
    73  	return server
    74  }
    75  
    76  func (engine *NBioWebSocketEngine) NewConn(config NetworkClientConfig) (ConnAdapter, error) {
    77  	dialer := &websocket.Dialer{
    78  		Engine: engine.wsEngine,
    79  		Upgrader: func() *websocket.Upgrader {
    80  			u := websocket.NewUpgrader()
    81  			u.OnMessage(func(conn *websocket.Conn, messageType websocket.MessageType, bytes []byte) {
    82  				engine.onMsg((*WsConnAdapter)(unsafe.Pointer(conn)), bytes)
    83  			})
    84  			u.OnOpen(func(conn *websocket.Conn) {
    85  				engine.onOpen((*WsConnAdapter)(unsafe.Pointer(conn)))
    86  			})
    87  			u.OnClose(func(conn *websocket.Conn, err error) {
    88  				engine.onClose((*WsConnAdapter)(unsafe.Pointer(conn)), err)
    89  			})
    90  			return u
    91  		}(),
    92  		DialTimeout: time.Second * 5,
    93  	}
    94  	u := url.URL{
    95  		Scheme: "wss",
    96  		Host:   config.ServerAddr,
    97  		Path:   wsUrl,
    98  	}
    99  	if config.TLSPriPem == nil {
   100  		u.Scheme = "ws"
   101  	}
   102  	conn, _, err := dialer.Dial(u.String(), nil)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	return (*WsConnAdapter)(unsafe.Pointer(conn)), nil
   107  }
   108  
   109  func (engine *NBioWebSocketEngine) Server() ServerEngine {
   110  	return engine
   111  }
   112  
   113  func (engine *NBioWebSocketEngine) Client() ClientEngine {
   114  	return engine
   115  }
   116  
   117  func (engine *NBioWebSocketEngine) EventDriveInter() EventDriveInter {
   118  	return engine
   119  }
   120  
   121  func (engine *NBioWebSocketEngine) OnRead(f func(conn ConnAdapter)) {
   122  	return
   123  }
   124  
   125  func (engine *NBioWebSocketEngine) OnMessage(f func(conn ConnAdapter, data []byte)) {
   126  	engine.onMsg = f
   127  }
   128  
   129  func (engine *NBioWebSocketEngine) OnOpen(f func(conn ConnAdapter)) {
   130  	engine.onOpen = f
   131  }
   132  
   133  func (engine *NBioWebSocketEngine) OnClose(f func(conn ConnAdapter, err error)) {
   134  	engine.onClose = f
   135  }
   136  
   137  func (engine *NBioWebSocketEngine) Start() error {
   138  	if !atomic.CompareAndSwapInt32(&engine.started, 0, 1) {
   139  		return errors.New("wsEngine already started")
   140  	}
   141  	mux := &http.ServeMux{}
   142  	mux.HandleFunc(wsUrl, func(writer http.ResponseWriter, request *http.Request) {
   143  		ws := websocket.NewUpgrader()
   144  		ws.OnMessage(func(conn *websocket.Conn, messageType websocket.MessageType, bytes []byte) {
   145  			engine.onMsg((*WsConnAdapter)(unsafe.Pointer(conn)), bytes)
   146  		})
   147  		ws.OnClose(func(conn *websocket.Conn, err error) {
   148  			engine.onClose((*WsConnAdapter)(unsafe.Pointer(conn)), err)
   149  		})
   150  		ws.OnOpen(func(conn *websocket.Conn) {
   151  			engine.onOpen((*WsConnAdapter)(unsafe.Pointer(conn)))
   152  		})
   153  		// 从Http升级到WebSocket
   154  		conn, err := ws.Upgrade(writer, request, nil)
   155  		if err != nil {
   156  			engine.onErr(err)
   157  		}
   158  		wsConn := conn.(*websocket.Conn)
   159  		_ = wsConn
   160  	})
   161  	engine.wsEngine.Handler = mux
   162  	return engine.wsEngine.Start()
   163  }
   164  
   165  func (engine *NBioWebSocketEngine) Stop() error {
   166  	if !atomic.CompareAndSwapInt32(&engine.closed, 0, 1) {
   167  		return errors.New("wsEngine already closed")
   168  	}
   169  	engine.wsEngine.Stop()
   170  	return nil
   171  }
   172  
   173  type WsConnAdapter struct {
   174  	websocket.Conn
   175  }
   176  
   177  func (w *WsConnAdapter) Write(b []byte) (n int, err error) {
   178  	err = w.WriteMessage(websocket.BinaryMessage, b)
   179  	if err != nil {
   180  		return -1, err
   181  	}
   182  	return len(b), nil
   183  }
   184  
   185  func (w *WsConnAdapter) SetSource(s interface{}) {
   186  	w.SetSession(s)
   187  }
   188  
   189  func (w *WsConnAdapter) Source() interface{} {
   190  	return w.Session()
   191  }