github.com/yunabe/lgo@v0.0.0-20190709125917-42c42d410fdf/jupyter/gojupyterscaffold/shelsocket.go (about) 1 package gojupyterscaffold 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync" 8 9 zmq "github.com/pebbe/zmq4" 10 ) 11 12 var ( 13 errLoopEnd = errors.New("end of loop") 14 // emptyMetadata is set to Metadata in sendDisplayData if metadata is missing. 15 emptyMetadata = make(map[string]interface{}) 16 ) 17 18 type contextAndCancel struct { 19 ctx context.Context 20 cancel func() 21 } 22 23 type iopubSocket struct { 24 socket *zmq.Socket 25 mutex *sync.Mutex 26 hmacKey []byte 27 serverCtx context.Context 28 ongoing map[*contextAndCancel]bool 29 } 30 31 func newIOPubSocket(serverCtx context.Context, zmqCtx *zmq.Context, cinfo *connectionInfo) (*iopubSocket, error) { 32 iopub, err := zmqCtx.NewSocket(zmq.PUB) 33 if err != nil { 34 return nil, fmt.Errorf("Failed to open iopub socket: %v", err) 35 } 36 if err := iopub.Bind(cinfo.getAddr(cinfo.IOPubPort)); err != nil { 37 return nil, fmt.Errorf("Failed to bind shell socket: %v", err) 38 } 39 return &iopubSocket{ 40 socket: iopub, 41 mutex: &sync.Mutex{}, 42 hmacKey: []byte(cinfo.Key), 43 serverCtx: serverCtx, 44 ongoing: make(map[*contextAndCancel]bool), 45 }, nil 46 } 47 48 func (s *iopubSocket) close() error { 49 return s.socket.Close() 50 } 51 52 func (s *iopubSocket) addOngoingContext() *contextAndCancel { 53 ctx, cancel := context.WithCancel(s.serverCtx) 54 ctxCancel := &contextAndCancel{ctx: ctx, cancel: cancel} 55 s.mutex.Lock() 56 defer s.mutex.Unlock() 57 s.ongoing[ctxCancel] = true 58 return ctxCancel 59 } 60 61 func (s *iopubSocket) removeOngoingContext(ctx *contextAndCancel) { 62 s.mutex.Lock() 63 defer s.mutex.Unlock() 64 delete(s.ongoing, ctx) 65 } 66 67 func (s *iopubSocket) cancelOngoings() { 68 s.mutex.Lock() 69 defer s.mutex.Unlock() 70 for e := range s.ongoing { 71 e.cancel() 72 } 73 } 74 75 func (s *iopubSocket) WithOngoingContext(f func(ctx context.Context) error, parent *message) (err error) { 76 // We may want to call addOngoingContext in the goroutine for zmq loop. 77 // TODO: Reconsider this deeply. 78 ctxCancel := s.addOngoingContext() 79 if err := s.publishStatus("busy", parent); err != nil { 80 return err 81 } 82 defer func() { 83 s.removeOngoingContext(ctxCancel) 84 if ierr := s.publishStatus("idle", parent); ierr != nil && err == nil { 85 err = ierr 86 } 87 }() 88 return f(ctxCancel.ctx) 89 } 90 91 func (s *iopubSocket) sendMessage(msg *message) error { 92 s.mutex.Lock() 93 defer s.mutex.Unlock() 94 return msg.Send(s.socket, s.hmacKey) 95 } 96 97 func (s *iopubSocket) publishStatus(status string, parent *message) error { 98 var msg message 99 // TODO: Change the format of Identity to kernel.<uuid>.MsgType. 100 // http://jupyter-client.readthedocs.io/en/latest/messaging.html#the-wire-protocol 101 msg.Identity = [][]byte{[]byte("status")} 102 msg.Header.MsgType = "status" 103 msg.Header.Version = "5.2" 104 msg.Header.Username = "username" 105 msg.Header.MsgID = genMsgID() 106 msg.ParentHeader = parent.Header 107 msg.Content = &struct { 108 ExecutionState string `json:"execution_state"` 109 }{ 110 ExecutionState: status, 111 } 112 return s.sendMessage(&msg) 113 } 114 115 // http://jupyter-client.readthedocs.io/en/latest/messaging.html#streams-stdout-stderr-etc 116 func (s *iopubSocket) sendStream(name, text string, parent *message) { 117 var msg message 118 msg.Identity = [][]byte{[]byte("stream")} 119 msg.Header.MsgType = "stream" 120 msg.Header.Version = "5.2" 121 msg.Header.Username = "username" 122 msg.Header.MsgID = genMsgID() 123 124 msg.ParentHeader = parent.Header 125 msg.Content = &struct { 126 Name string `json:"name"` 127 Text string `json:"text"` 128 }{ 129 Name: name, 130 Text: text, 131 } 132 if err := s.sendMessage(&msg); err != nil { 133 logger.Errorf("Failed to send stream: %v", err) 134 } 135 } 136 137 func (s *iopubSocket) sendDisplayData(data *DisplayData, parent *message, update bool) { 138 var msg message 139 msgType := "display_data" 140 if update { 141 if data.Transient["display_id"] == nil { 142 logger.Warning("update_display_data with no display_id") 143 } 144 msgType = "update_display_data" 145 } 146 if data.Metadata == nil { 147 var copy DisplayData 148 copy = *data 149 copy.Metadata = emptyMetadata 150 data = © 151 } 152 msg.Identity = [][]byte{[]byte(msgType)} 153 msg.Header.MsgType = msgType 154 msg.Header.Version = "5.2" 155 msg.Header.Username = "username" 156 msg.Header.MsgID = genMsgID() 157 msg.ParentHeader = parent.Header 158 msg.Content = data 159 if err := s.sendMessage(&msg); err != nil { 160 logger.Errorf("Failed to send stream: %v", err) 161 } 162 } 163 164 type shellSocket struct { 165 name string 166 hmacKey []byte 167 socket *zmq.Socket 168 resultPush *zmq.Socket 169 resultPushMux sync.Mutex 170 resultPull *zmq.Socket 171 iopub *iopubSocket 172 173 handlers RequestHandlers 174 ctx context.Context 175 cancelCtx func() 176 177 execQueue *executeQueue 178 } 179 180 func newShellSocket(serverCtx context.Context, zmqCtx *zmq.Context, name string, cinfo *connectionInfo, iopub *iopubSocket, handlers RequestHandlers, cancelCtx func(), execQueue *executeQueue) (*shellSocket, error) { 181 var routerAddr string 182 if name == "shell" { 183 routerAddr = cinfo.getAddr(cinfo.ShellPort) 184 } else if name == "control" { 185 routerAddr = cinfo.getAddr(cinfo.ControlPort) 186 } else { 187 return nil, fmt.Errorf("Unknown shell socket name: %q", name) 188 } 189 190 sock, err := zmqCtx.NewSocket(zmq.ROUTER) 191 if err != nil { 192 return nil, fmt.Errorf("Failed to open %s socket: %v", name, err) 193 } 194 if err := sock.Bind(routerAddr); err != nil { 195 return nil, fmt.Errorf("Failed to bind %s socket: %v", name, err) 196 } 197 resultPush, err := zmqCtx.NewSocket(zmq.PUSH) 198 if err != nil { 199 return nil, err 200 } 201 inprocAddr := fmt.Sprintf("inproc://result-for-%s-socket", name) 202 if err := resultPush.Bind(inprocAddr); err != nil { 203 return nil, err 204 } 205 resultPull, err := zmqCtx.NewSocket(zmq.PULL) 206 if err != nil { 207 return nil, err 208 } 209 if err := resultPull.Connect(inprocAddr); err != nil { 210 return nil, err 211 } 212 return &shellSocket{ 213 name: name, 214 hmacKey: []byte(cinfo.Key), 215 socket: sock, 216 resultPush: resultPush, 217 resultPull: resultPull, 218 iopub: iopub, 219 handlers: handlers, 220 ctx: serverCtx, 221 cancelCtx: cancelCtx, 222 execQueue: execQueue, 223 }, nil 224 } 225 226 func (s *shellSocket) close() (err error) { 227 if cerr := s.socket.Close(); cerr != nil { 228 err = cerr 229 } 230 if cerr := s.resultPush.Close(); cerr != nil { 231 err = cerr 232 } 233 if cerr := s.resultPull.Close(); cerr != nil { 234 err = cerr 235 } 236 return 237 } 238 239 // pushResult sends a message to shellSocket so that it will be sent to the client. 240 // This method is goroutine-safe. 241 func (s *shellSocket) pushResult(msg *message) error { 242 s.resultPushMux.Lock() 243 defer s.resultPushMux.Unlock() 244 return msg.Send(s.resultPush, s.hmacKey) 245 } 246 247 // notifyLoopEnd notifies the end of the loop to the goroutine in loop(). 248 func (s *shellSocket) notifyLoopEnd() error { 249 s.resultPushMux.Lock() 250 defer s.resultPushMux.Unlock() 251 // Notes: 252 // You need to send at least one message with SendMessage. 253 // You can not use a zero-length messsages to notify the end of loop because 254 // the zero-length messages are not sent to the receiver. 255 _, err := s.resultPush.SendMessage("END_OF_LOOP") 256 return err 257 } 258 259 func (s *shellSocket) loop() { 260 poller := zmq.NewPoller() 261 poller.Add(s.socket, zmq.POLLIN) 262 poller.Add(s.resultPull, zmq.POLLIN) 263 loop: 264 for { 265 polled, err := poller.Poll(-1) 266 if isEINTR(err) { 267 // It seems like poller.Poll sometimes return EINTR when a signal is sent 268 // even if a signal handler for SIGINT is registered. 269 logger.Info("zmq.Poll was interrupted") 270 continue 271 } 272 if err != nil { 273 logger.Errorf("Poll on %s socket failed: %v", s.name, err) 274 continue 275 } 276 for _, p := range polled { 277 switch p.Socket { 278 case s.socket: 279 if err := s.handleMessages(); err != nil { 280 logger.Errorf("Failed to handle a message on %s socket: %v", s.name, err) 281 } 282 case s.resultPull: 283 err := s.handleResultPull() 284 if err == errLoopEnd { 285 logger.Infof("Exiting polling loop for %s", s.name) 286 break loop 287 } 288 if err != nil { 289 logger.Infof("Failed to handle a message on the result socket of %s: %v", s.name, err) 290 } 291 default: 292 panic(errors.New("zmq.Poll returned an unexpected socket")) 293 } 294 } 295 } 296 } 297 298 func (s *shellSocket) sendKernelInfo(req *message) error { 299 return s.iopub.WithOngoingContext(func(_ context.Context) error { 300 var info KernelInfo 301 info = s.handlers.HandleKernelInfo() 302 res := newMessageWithParent(req) 303 304 // https://github.com/jupyter/notebook/blob/master/notebook/services/kernels/handlers.py#L174 305 res.Header.MsgType = "kernel_info_reply" 306 res.Content = &info 307 return res.Send(s.socket, s.hmacKey) 308 }, req) 309 } 310 311 func (s *shellSocket) handleMessages() error { 312 msgs, err := s.socket.RecvMessageBytes(0) 313 if err != nil { 314 return fmt.Errorf("Failed to receive data from %s: %v", s.name, err) 315 } 316 var msg message 317 err = msg.Unmarshal(msgs, s.hmacKey) 318 if err != nil { 319 return fmt.Errorf("Failed to unmarshal messages from %s: %v", s.name, err) 320 } 321 logger.Warningf("MsgType in %s: %q", s.name, msg.Header.MsgType) 322 switch typ := msg.Header.MsgType; typ { 323 case "kernel_info_request": 324 if err := s.sendKernelInfo(&msg); err != nil { 325 logger.Errorf("Failed to handle kernel_info_request: %v", err) 326 } 327 case "shutdown_request": 328 logger.Info("received shutdown_request.") 329 s.cancelCtx() 330 // TODO: Send shutdown_reply 331 case "execute_request": 332 s.execQueue.push(&msg, s) 333 case "complete_request": 334 go func() { 335 reply := s.handlers.HandleComplete(msg.Content.(*CompleteRequest)) 336 if reply == nil { 337 reply = &CompleteReply{ 338 Status: "ok", 339 } 340 } 341 if reply.Status == "ok" && reply.Matches == nil { 342 // matches must not be null because `jupyter console` can not accept null for matches as of 2017/12. 343 // https://goo.gl/QRd5rG 344 reply.Matches = make([]string, 0) 345 } 346 res := newMessageWithParent(&msg) 347 res.Header.MsgType = "complete_reply" 348 res.Content = reply 349 s.pushResult(res) 350 }() 351 case "inspect_request": 352 go func() { 353 reply := s.handlers.HandleInspect(msg.Content.(*InspectRequest)) 354 if reply == nil { 355 reply = &InspectReply{ 356 Status: "ok", 357 Found: false, 358 } 359 } 360 res := newMessageWithParent(&msg) 361 res.Header.MsgType = "inspect_reply" 362 res.Content = reply 363 s.pushResult(res) 364 }() 365 case "is_complete_request": 366 go func() { 367 reply := s.handlers.HandleIsComplete(msg.Content.(*IsCompleteRequest)) 368 if reply == nil { 369 reply = &IsCompleteReply{Status: "unknown"} 370 } 371 res := newMessageWithParent(&msg) 372 res.Header.MsgType = "is_complete_reply" 373 res.Content = reply 374 s.pushResult(res) 375 }() 376 case "gofmt_request": 377 go func() { 378 res := newMessageWithParent(&msg) 379 res.Header.MsgType = "gofmt_reply" 380 reply, err := s.handlers.HandleGoFmt(msg.Content.(*GoFmtRequest)) 381 if err != nil { 382 res.Content = &errorReply{ 383 Status: "error", 384 Ename: "error", 385 Evalue: err.Error(), 386 } 387 } else { 388 res.Content = reply 389 } 390 s.pushResult(res) 391 }() 392 default: 393 logger.Warningf("Unsupported MsgType in %s: %q", s.name, typ) 394 } 395 return nil 396 } 397 398 // Forward a message on result_pull to socket. 399 func (s *shellSocket) handleResultPull() error { 400 msgs, err := s.resultPull.RecvMessageBytes(0) 401 if err != nil { 402 return err 403 } 404 if len(msgs) == 1 && string(msgs[0]) == "END_OF_LOOP" { 405 return errLoopEnd 406 } 407 // For some reasons, execute_reply is not handled correctly 408 // unless we unmarshal and marshal msgs rather than just forwarding them. 409 var msg message 410 if err := msg.Unmarshal(msgs, s.hmacKey); err != nil { 411 return err 412 } 413 return msg.Send(s.socket, s.hmacKey) 414 }