github.com/yunabe/lgo@v0.0.0-20190709125917-42c42d410fdf/jupyter/gojupyterscaffold/shelsocket.go (about)

     1  package gojupyterscaffold
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  
     9  	zmq "github.com/pebbe/zmq4"
    10  )
    11  
    12  var (
    13  	errLoopEnd = errors.New("end of loop")
    14  	// emptyMetadata is set to Metadata in sendDisplayData if metadata is missing.
    15  	emptyMetadata = make(map[string]interface{})
    16  )
    17  
    18  type contextAndCancel struct {
    19  	ctx    context.Context
    20  	cancel func()
    21  }
    22  
    23  type iopubSocket struct {
    24  	socket    *zmq.Socket
    25  	mutex     *sync.Mutex
    26  	hmacKey   []byte
    27  	serverCtx context.Context
    28  	ongoing   map[*contextAndCancel]bool
    29  }
    30  
    31  func newIOPubSocket(serverCtx context.Context, zmqCtx *zmq.Context, cinfo *connectionInfo) (*iopubSocket, error) {
    32  	iopub, err := zmqCtx.NewSocket(zmq.PUB)
    33  	if err != nil {
    34  		return nil, fmt.Errorf("Failed to open iopub socket: %v", err)
    35  	}
    36  	if err := iopub.Bind(cinfo.getAddr(cinfo.IOPubPort)); err != nil {
    37  		return nil, fmt.Errorf("Failed to bind shell socket: %v", err)
    38  	}
    39  	return &iopubSocket{
    40  		socket:    iopub,
    41  		mutex:     &sync.Mutex{},
    42  		hmacKey:   []byte(cinfo.Key),
    43  		serverCtx: serverCtx,
    44  		ongoing:   make(map[*contextAndCancel]bool),
    45  	}, nil
    46  }
    47  
    48  func (s *iopubSocket) close() error {
    49  	return s.socket.Close()
    50  }
    51  
    52  func (s *iopubSocket) addOngoingContext() *contextAndCancel {
    53  	ctx, cancel := context.WithCancel(s.serverCtx)
    54  	ctxCancel := &contextAndCancel{ctx: ctx, cancel: cancel}
    55  	s.mutex.Lock()
    56  	defer s.mutex.Unlock()
    57  	s.ongoing[ctxCancel] = true
    58  	return ctxCancel
    59  }
    60  
    61  func (s *iopubSocket) removeOngoingContext(ctx *contextAndCancel) {
    62  	s.mutex.Lock()
    63  	defer s.mutex.Unlock()
    64  	delete(s.ongoing, ctx)
    65  }
    66  
    67  func (s *iopubSocket) cancelOngoings() {
    68  	s.mutex.Lock()
    69  	defer s.mutex.Unlock()
    70  	for e := range s.ongoing {
    71  		e.cancel()
    72  	}
    73  }
    74  
    75  func (s *iopubSocket) WithOngoingContext(f func(ctx context.Context) error, parent *message) (err error) {
    76  	// We may want to call addOngoingContext in the goroutine for zmq loop.
    77  	// TODO: Reconsider this deeply.
    78  	ctxCancel := s.addOngoingContext()
    79  	if err := s.publishStatus("busy", parent); err != nil {
    80  		return err
    81  	}
    82  	defer func() {
    83  		s.removeOngoingContext(ctxCancel)
    84  		if ierr := s.publishStatus("idle", parent); ierr != nil && err == nil {
    85  			err = ierr
    86  		}
    87  	}()
    88  	return f(ctxCancel.ctx)
    89  }
    90  
    91  func (s *iopubSocket) sendMessage(msg *message) error {
    92  	s.mutex.Lock()
    93  	defer s.mutex.Unlock()
    94  	return msg.Send(s.socket, s.hmacKey)
    95  }
    96  
    97  func (s *iopubSocket) publishStatus(status string, parent *message) error {
    98  	var msg message
    99  	// TODO: Change the format of Identity to kernel.<uuid>.MsgType.
   100  	// http://jupyter-client.readthedocs.io/en/latest/messaging.html#the-wire-protocol
   101  	msg.Identity = [][]byte{[]byte("status")}
   102  	msg.Header.MsgType = "status"
   103  	msg.Header.Version = "5.2"
   104  	msg.Header.Username = "username"
   105  	msg.Header.MsgID = genMsgID()
   106  	msg.ParentHeader = parent.Header
   107  	msg.Content = &struct {
   108  		ExecutionState string `json:"execution_state"`
   109  	}{
   110  		ExecutionState: status,
   111  	}
   112  	return s.sendMessage(&msg)
   113  }
   114  
   115  // http://jupyter-client.readthedocs.io/en/latest/messaging.html#streams-stdout-stderr-etc
   116  func (s *iopubSocket) sendStream(name, text string, parent *message) {
   117  	var msg message
   118  	msg.Identity = [][]byte{[]byte("stream")}
   119  	msg.Header.MsgType = "stream"
   120  	msg.Header.Version = "5.2"
   121  	msg.Header.Username = "username"
   122  	msg.Header.MsgID = genMsgID()
   123  
   124  	msg.ParentHeader = parent.Header
   125  	msg.Content = &struct {
   126  		Name string `json:"name"`
   127  		Text string `json:"text"`
   128  	}{
   129  		Name: name,
   130  		Text: text,
   131  	}
   132  	if err := s.sendMessage(&msg); err != nil {
   133  		logger.Errorf("Failed to send stream: %v", err)
   134  	}
   135  }
   136  
   137  func (s *iopubSocket) sendDisplayData(data *DisplayData, parent *message, update bool) {
   138  	var msg message
   139  	msgType := "display_data"
   140  	if update {
   141  		if data.Transient["display_id"] == nil {
   142  			logger.Warning("update_display_data with no display_id")
   143  		}
   144  		msgType = "update_display_data"
   145  	}
   146  	if data.Metadata == nil {
   147  		var copy DisplayData
   148  		copy = *data
   149  		copy.Metadata = emptyMetadata
   150  		data = &copy
   151  	}
   152  	msg.Identity = [][]byte{[]byte(msgType)}
   153  	msg.Header.MsgType = msgType
   154  	msg.Header.Version = "5.2"
   155  	msg.Header.Username = "username"
   156  	msg.Header.MsgID = genMsgID()
   157  	msg.ParentHeader = parent.Header
   158  	msg.Content = data
   159  	if err := s.sendMessage(&msg); err != nil {
   160  		logger.Errorf("Failed to send stream: %v", err)
   161  	}
   162  }
   163  
   164  type shellSocket struct {
   165  	name          string
   166  	hmacKey       []byte
   167  	socket        *zmq.Socket
   168  	resultPush    *zmq.Socket
   169  	resultPushMux sync.Mutex
   170  	resultPull    *zmq.Socket
   171  	iopub         *iopubSocket
   172  
   173  	handlers  RequestHandlers
   174  	ctx       context.Context
   175  	cancelCtx func()
   176  
   177  	execQueue *executeQueue
   178  }
   179  
   180  func newShellSocket(serverCtx context.Context, zmqCtx *zmq.Context, name string, cinfo *connectionInfo, iopub *iopubSocket, handlers RequestHandlers, cancelCtx func(), execQueue *executeQueue) (*shellSocket, error) {
   181  	var routerAddr string
   182  	if name == "shell" {
   183  		routerAddr = cinfo.getAddr(cinfo.ShellPort)
   184  	} else if name == "control" {
   185  		routerAddr = cinfo.getAddr(cinfo.ControlPort)
   186  	} else {
   187  		return nil, fmt.Errorf("Unknown shell socket name: %q", name)
   188  	}
   189  
   190  	sock, err := zmqCtx.NewSocket(zmq.ROUTER)
   191  	if err != nil {
   192  		return nil, fmt.Errorf("Failed to open %s socket: %v", name, err)
   193  	}
   194  	if err := sock.Bind(routerAddr); err != nil {
   195  		return nil, fmt.Errorf("Failed to bind %s socket: %v", name, err)
   196  	}
   197  	resultPush, err := zmqCtx.NewSocket(zmq.PUSH)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	inprocAddr := fmt.Sprintf("inproc://result-for-%s-socket", name)
   202  	if err := resultPush.Bind(inprocAddr); err != nil {
   203  		return nil, err
   204  	}
   205  	resultPull, err := zmqCtx.NewSocket(zmq.PULL)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	if err := resultPull.Connect(inprocAddr); err != nil {
   210  		return nil, err
   211  	}
   212  	return &shellSocket{
   213  		name:       name,
   214  		hmacKey:    []byte(cinfo.Key),
   215  		socket:     sock,
   216  		resultPush: resultPush,
   217  		resultPull: resultPull,
   218  		iopub:      iopub,
   219  		handlers:   handlers,
   220  		ctx:        serverCtx,
   221  		cancelCtx:  cancelCtx,
   222  		execQueue:  execQueue,
   223  	}, nil
   224  }
   225  
   226  func (s *shellSocket) close() (err error) {
   227  	if cerr := s.socket.Close(); cerr != nil {
   228  		err = cerr
   229  	}
   230  	if cerr := s.resultPush.Close(); cerr != nil {
   231  		err = cerr
   232  	}
   233  	if cerr := s.resultPull.Close(); cerr != nil {
   234  		err = cerr
   235  	}
   236  	return
   237  }
   238  
   239  // pushResult sends a message to shellSocket so that it will be sent to the client.
   240  // This method is goroutine-safe.
   241  func (s *shellSocket) pushResult(msg *message) error {
   242  	s.resultPushMux.Lock()
   243  	defer s.resultPushMux.Unlock()
   244  	return msg.Send(s.resultPush, s.hmacKey)
   245  }
   246  
   247  // notifyLoopEnd notifies the end of the loop to the goroutine in loop().
   248  func (s *shellSocket) notifyLoopEnd() error {
   249  	s.resultPushMux.Lock()
   250  	defer s.resultPushMux.Unlock()
   251  	// Notes:
   252  	// You need to send at least one message with SendMessage.
   253  	// You can not use a zero-length messsages to notify the end of loop because
   254  	// the zero-length messages are not sent to the receiver.
   255  	_, err := s.resultPush.SendMessage("END_OF_LOOP")
   256  	return err
   257  }
   258  
   259  func (s *shellSocket) loop() {
   260  	poller := zmq.NewPoller()
   261  	poller.Add(s.socket, zmq.POLLIN)
   262  	poller.Add(s.resultPull, zmq.POLLIN)
   263  loop:
   264  	for {
   265  		polled, err := poller.Poll(-1)
   266  		if isEINTR(err) {
   267  			// It seems like poller.Poll sometimes return EINTR when a signal is sent
   268  			// even if a signal handler for SIGINT is registered.
   269  			logger.Info("zmq.Poll was interrupted")
   270  			continue
   271  		}
   272  		if err != nil {
   273  			logger.Errorf("Poll on %s socket failed: %v", s.name, err)
   274  			continue
   275  		}
   276  		for _, p := range polled {
   277  			switch p.Socket {
   278  			case s.socket:
   279  				if err := s.handleMessages(); err != nil {
   280  					logger.Errorf("Failed to handle a message on %s socket: %v", s.name, err)
   281  				}
   282  			case s.resultPull:
   283  				err := s.handleResultPull()
   284  				if err == errLoopEnd {
   285  					logger.Infof("Exiting polling loop for %s", s.name)
   286  					break loop
   287  				}
   288  				if err != nil {
   289  					logger.Infof("Failed to handle a message on the result socket of %s: %v", s.name, err)
   290  				}
   291  			default:
   292  				panic(errors.New("zmq.Poll returned an unexpected socket"))
   293  			}
   294  		}
   295  	}
   296  }
   297  
   298  func (s *shellSocket) sendKernelInfo(req *message) error {
   299  	return s.iopub.WithOngoingContext(func(_ context.Context) error {
   300  		var info KernelInfo
   301  		info = s.handlers.HandleKernelInfo()
   302  		res := newMessageWithParent(req)
   303  
   304  		// https://github.com/jupyter/notebook/blob/master/notebook/services/kernels/handlers.py#L174
   305  		res.Header.MsgType = "kernel_info_reply"
   306  		res.Content = &info
   307  		return res.Send(s.socket, s.hmacKey)
   308  	}, req)
   309  }
   310  
   311  func (s *shellSocket) handleMessages() error {
   312  	msgs, err := s.socket.RecvMessageBytes(0)
   313  	if err != nil {
   314  		return fmt.Errorf("Failed to receive data from %s: %v", s.name, err)
   315  	}
   316  	var msg message
   317  	err = msg.Unmarshal(msgs, s.hmacKey)
   318  	if err != nil {
   319  		return fmt.Errorf("Failed to unmarshal messages from %s: %v", s.name, err)
   320  	}
   321  	logger.Warningf("MsgType in %s: %q", s.name, msg.Header.MsgType)
   322  	switch typ := msg.Header.MsgType; typ {
   323  	case "kernel_info_request":
   324  		if err := s.sendKernelInfo(&msg); err != nil {
   325  			logger.Errorf("Failed to handle kernel_info_request: %v", err)
   326  		}
   327  	case "shutdown_request":
   328  		logger.Info("received shutdown_request.")
   329  		s.cancelCtx()
   330  		// TODO: Send shutdown_reply
   331  	case "execute_request":
   332  		s.execQueue.push(&msg, s)
   333  	case "complete_request":
   334  		go func() {
   335  			reply := s.handlers.HandleComplete(msg.Content.(*CompleteRequest))
   336  			if reply == nil {
   337  				reply = &CompleteReply{
   338  					Status: "ok",
   339  				}
   340  			}
   341  			if reply.Status == "ok" && reply.Matches == nil {
   342  				// matches must not be null because `jupyter console` can not accept null for matches as of 2017/12.
   343  				// https://goo.gl/QRd5rG
   344  				reply.Matches = make([]string, 0)
   345  			}
   346  			res := newMessageWithParent(&msg)
   347  			res.Header.MsgType = "complete_reply"
   348  			res.Content = reply
   349  			s.pushResult(res)
   350  		}()
   351  	case "inspect_request":
   352  		go func() {
   353  			reply := s.handlers.HandleInspect(msg.Content.(*InspectRequest))
   354  			if reply == nil {
   355  				reply = &InspectReply{
   356  					Status: "ok",
   357  					Found:  false,
   358  				}
   359  			}
   360  			res := newMessageWithParent(&msg)
   361  			res.Header.MsgType = "inspect_reply"
   362  			res.Content = reply
   363  			s.pushResult(res)
   364  		}()
   365  	case "is_complete_request":
   366  		go func() {
   367  			reply := s.handlers.HandleIsComplete(msg.Content.(*IsCompleteRequest))
   368  			if reply == nil {
   369  				reply = &IsCompleteReply{Status: "unknown"}
   370  			}
   371  			res := newMessageWithParent(&msg)
   372  			res.Header.MsgType = "is_complete_reply"
   373  			res.Content = reply
   374  			s.pushResult(res)
   375  		}()
   376  	case "gofmt_request":
   377  		go func() {
   378  			res := newMessageWithParent(&msg)
   379  			res.Header.MsgType = "gofmt_reply"
   380  			reply, err := s.handlers.HandleGoFmt(msg.Content.(*GoFmtRequest))
   381  			if err != nil {
   382  				res.Content = &errorReply{
   383  					Status: "error",
   384  					Ename:  "error",
   385  					Evalue: err.Error(),
   386  				}
   387  			} else {
   388  				res.Content = reply
   389  			}
   390  			s.pushResult(res)
   391  		}()
   392  	default:
   393  		logger.Warningf("Unsupported MsgType in %s: %q", s.name, typ)
   394  	}
   395  	return nil
   396  }
   397  
   398  // Forward a message on result_pull to socket.
   399  func (s *shellSocket) handleResultPull() error {
   400  	msgs, err := s.resultPull.RecvMessageBytes(0)
   401  	if err != nil {
   402  		return err
   403  	}
   404  	if len(msgs) == 1 && string(msgs[0]) == "END_OF_LOOP" {
   405  		return errLoopEnd
   406  	}
   407  	// For some reasons, execute_reply is not handled correctly
   408  	// unless we unmarshal and marshal msgs rather than just forwarding them.
   409  	var msg message
   410  	if err := msg.Unmarshal(msgs, s.hmacKey); err != nil {
   411  		return err
   412  	}
   413  	return msg.Send(s.socket, s.hmacKey)
   414  }