github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/ipc/server.go (about)

     1  package ipc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/ActiveState/cli/internal/errs"
    11  	"github.com/ActiveState/cli/internal/ipc/internal/flisten"
    12  	"github.com/ActiveState/cli/internal/ipc/sockpath"
    13  	"github.com/ActiveState/cli/internal/logging"
    14  )
    15  
    16  var (
    17  	msgWidth = 640
    18  	network  = "unix"
    19  )
    20  
    21  type SockPath = sockpath.SockPath
    22  
    23  // RequestHandler describes a function that receives a key which is used to
    24  // verify if the handler is useful for a given request. If it is useful, the
    25  // remainder of the function is used for some special behavior (usually, to
    26  // simply return some value). This enables dynamic construction of IPC Server
    27  // handlers/endpoints.
    28  type RequestHandler func(key string) (resp string, isMatched bool)
    29  
    30  type Server struct {
    31  	spath       *SockPath
    32  	reqHandlers []RequestHandler
    33  	ctx         context.Context
    34  	cancel      context.CancelFunc
    35  	errsc       chan error
    36  	donec       chan struct{}
    37  }
    38  
    39  // NewServer constructs a reference to a Server instance which can be populated
    40  // with called-defined handlers, and is preconfigured with ping and stop
    41  // handlers as a low-level flexibility.
    42  func NewServer(topCtx context.Context, spath *SockPath, reqHandlers ...RequestHandler) *Server {
    43  	ctx, cancel := context.WithCancel(topCtx)
    44  
    45  	ipc := Server{
    46  		spath:       spath,
    47  		reqHandlers: make([]RequestHandler, 0, len(reqHandlers)+2),
    48  		ctx:         ctx,
    49  		cancel:      cancel,
    50  		donec:       make(chan struct{}),
    51  	}
    52  
    53  	ipc.reqHandlers = append(ipc.reqHandlers, pingHandler())
    54  	ipc.reqHandlers = append(ipc.reqHandlers, reqHandlers...)
    55  	ipc.reqHandlers = append(ipc.reqHandlers, stopHandler(ipc.Shutdown))
    56  
    57  	return &ipc
    58  }
    59  
    60  func (ipc *Server) Start() error {
    61  	listener, err := flisten.New(ipc.ctx, ipc.spath, network)
    62  	if err != nil {
    63  		// if sock listener construction error is "in use", ensure
    64  		// current owner can be contacted
    65  		if !errors.Is(err, flisten.ErrInUse) {
    66  			return errs.Wrap(err, "Cannot construct file listener")
    67  		}
    68  
    69  		ctx, cancel := context.WithTimeout(ipc.ctx, time.Second*3)
    70  		defer cancel()
    71  
    72  		_, pingErr := NewClient(ipc.spath).PingServer(ctx)
    73  		if pingErr == nil {
    74  			return ErrInUse
    75  		}
    76  
    77  		// if client comm error is "refused", we can reasonably clobber
    78  		// existing sock file
    79  		if !errors.Is(pingErr, flisten.ErrConnRefused) {
    80  			return errs.Wrap(err, "Cannot connect to existing socket file")
    81  		}
    82  
    83  		listener, err = flisten.NewWithCleanup(ctx, ipc.spath, network)
    84  		if err != nil {
    85  			return errs.Wrap(err, "Cannot construct file listener after file cleanup")
    86  		}
    87  	}
    88  
    89  	ipc.errsc = make(chan error) // errsc setup here so wait fn can know if start call was ok
    90  
    91  	go func() {
    92  		var wg sync.WaitGroup
    93  		defer close(ipc.errsc)
    94  
    95  		wg.Add(1)
    96  		go func() {
    97  			defer wg.Done()
    98  
    99  			logging.Debug("waiting for done channel closure")
   100  			<-ipc.donec
   101  			logging.Debug("closing listener")
   102  			listener.Close()
   103  		}()
   104  
   105  		go func() {
   106  			// Continually accept connections and route them to the correct handler.
   107  			for {
   108  				// At this time, the context.Context that is
   109  				// passed into the flisten construction func
   110  				// does not halt the listener. Close() must be
   111  				// called to halt and "doneness" managed.
   112  				err := accept(&wg, listener, ipc.reqHandlers)
   113  				select {
   114  				case <-ipc.donec:
   115  					return
   116  				default:
   117  				}
   118  				if err != nil {
   119  					ipc.errsc <- errs.Wrap(err, "Unexpected accept error")
   120  					return
   121  				}
   122  			}
   123  		}()
   124  
   125  		wg.Wait()
   126  	}()
   127  
   128  	return nil
   129  }
   130  
   131  func (ipc *Server) Shutdown() {
   132  	select {
   133  	case <-ipc.donec:
   134  	default:
   135  		close(ipc.donec)
   136  		ipc.cancel()
   137  	}
   138  }
   139  
   140  func (ipc *Server) Wait() error {
   141  	if ipc.errsc == nil {
   142  		return nil
   143  	}
   144  
   145  	var retErr error
   146  	for err := range ipc.errsc {
   147  		if err != nil && retErr == nil {
   148  			retErr = err
   149  		}
   150  	}
   151  	return retErr
   152  }
   153  
   154  func accept(wg *sync.WaitGroup, l net.Listener, reqHandlers []RequestHandler) error {
   155  	conn, err := l.Accept()
   156  	if err != nil {
   157  		logging.Debug("accept error (closed network expected when closing): %v", err)
   158  		return err
   159  	}
   160  
   161  	wg.Add(1)
   162  	go func() {
   163  		defer wg.Done()
   164  		defer conn.Close()
   165  
   166  		if err := handleMatching(conn, reqHandlers); err != nil {
   167  			logging.Debug(err.Error())
   168  			logging.Error("Unexpected IPC request handling error: %v", err)
   169  			return
   170  		}
   171  	}()
   172  
   173  	return nil
   174  }
   175  
   176  func handleMatching(conn net.Conn, reqHandlers []RequestHandler) error {
   177  	buf := make([]byte, msgWidth)
   178  	n, err := conn.Read(buf)
   179  	if err != nil {
   180  		return errs.Wrap(err, "Failed to read from client connection")
   181  	}
   182  
   183  	key := string(buf[:n])
   184  	output := "not found"
   185  
   186  	for _, reqHandler := range reqHandlers {
   187  		if resp, ok := reqHandler(key); ok {
   188  			output = resp
   189  			break
   190  		}
   191  	}
   192  
   193  	if _, err := conn.Write([]byte(output)); err != nil {
   194  		return errs.Wrap(err, "Failed to write to client connection")
   195  	}
   196  
   197  	return nil
   198  }