github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbhttp/srv.go (about)

     1  // Copyright 2018 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kbhttp
     5  
     6  import (
     7  	"crypto/rand"
     8  	"errors"
     9  	"fmt"
    10  	"math/big"
    11  	"net"
    12  	"net/http"
    13  	"sync"
    14  
    15  	"github.com/keybase/client/go/logger"
    16  )
    17  
    18  // ListenerSource represents where an HTTP server should listen.
    19  type ListenerSource interface {
    20  	GetListener() (net.Listener, string, error)
    21  }
    22  
    23  // AutoPortListenerSource means listen on a port that's picked automatically by
    24  // the kernel.
    25  type AutoPortListenerSource struct{}
    26  
    27  // GetListener implements ListenerSource.
    28  func (r AutoPortListenerSource) GetListener() (net.Listener, string, error) {
    29  	localhost := "127.0.0.1"
    30  	listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", localhost))
    31  	if err != nil {
    32  		return nil, "", err
    33  	}
    34  	port := listener.Addr().(*net.TCPAddr).Port
    35  	address := fmt.Sprintf("%s:%d", localhost, port)
    36  	return listener, address, nil
    37  }
    38  
    39  // NewAutoPortListenerSource creates a new AutoPortListenerSource.
    40  func NewAutoPortListenerSource() *AutoPortListenerSource {
    41  	return &AutoPortListenerSource{}
    42  }
    43  
    44  var ErrPinnedPortInUse = errors.New("unable to bind to pinned port")
    45  
    46  // PortRangeListenerSource means listen on the given range.
    47  type PortRangeListenerSource struct {
    48  	sync.Mutex
    49  	pinnedPort int
    50  	low, high  int
    51  }
    52  
    53  // NewPortRangeListenerSource creates a new PortListenerSource
    54  // listening on low to high (inclusive).
    55  func NewPortRangeListenerSource(low, high int) *PortRangeListenerSource {
    56  	return &PortRangeListenerSource{
    57  		low:  low,
    58  		high: high,
    59  	}
    60  }
    61  
    62  // NewFixedPortListenerSource creates a new PortListenerSource
    63  // listening on the given port.
    64  func NewFixedPortListenerSource(port int) *PortRangeListenerSource {
    65  	return NewPortRangeListenerSource(port, port)
    66  }
    67  
    68  // GetListener implements ListenerSource.
    69  func (p *PortRangeListenerSource) GetListener() (listener net.Listener, address string, err error) {
    70  	p.Lock()
    71  	defer p.Unlock()
    72  	localhost := "127.0.0.1"
    73  	if p.pinnedPort > 0 {
    74  		address = fmt.Sprintf("%s:%d", localhost, p.pinnedPort)
    75  		if listener, err = net.Listen("tcp", address); err != nil {
    76  			return listener, address, ErrPinnedPortInUse
    77  		}
    78  		return listener, address, nil
    79  	}
    80  	for port := p.low; port <= p.high; port++ {
    81  		address = fmt.Sprintf("%s:%d", localhost, port)
    82  		listener, err = net.Listen("tcp", address)
    83  		if err == nil {
    84  			p.pinnedPort = port
    85  			return listener, address, nil
    86  		}
    87  	}
    88  	return listener, address, errors.New("failed to bind to port in range")
    89  }
    90  
    91  // RandomPortRangeListenerSource listens on a port randomly chosen within a
    92  // given range.
    93  type RandomPortRangeListenerSource struct {
    94  	sync.Mutex
    95  	pinnedPort int
    96  	low, high  int
    97  }
    98  
    99  // NewRandomPortRangeListenerSource creates a new RadomPortListenerSource
   100  // listening on low to high (exclusive).
   101  func NewRandomPortRangeListenerSource(low, high int) *RandomPortRangeListenerSource {
   102  	return &RandomPortRangeListenerSource{
   103  		low:  low,
   104  		high: high,
   105  	}
   106  }
   107  
   108  const maxRandomTries = 10
   109  
   110  // GetListener implements ListenerSource.
   111  func (p *RandomPortRangeListenerSource) GetListener() (listener net.Listener, address string, err error) {
   112  	p.Lock()
   113  	defer p.Unlock()
   114  	localhost := "127.0.0.1"
   115  	for i := 0; i < maxRandomTries; i++ {
   116  		if p.pinnedPort > 0 {
   117  			address = fmt.Sprintf("%s:%d", localhost, p.pinnedPort)
   118  			if listener, err = net.Listen("tcp", address); err != nil {
   119  				return listener, address, ErrPinnedPortInUse
   120  			}
   121  			return listener, address, nil
   122  		}
   123  
   124  		n, err := rand.Int(rand.Reader, big.NewInt(int64(p.high-p.low)))
   125  		if err != nil {
   126  			return nil, "", err
   127  		}
   128  		port := p.low + int(n.Int64())
   129  		address = fmt.Sprintf("%s:%d", localhost, port)
   130  		listener, err = net.Listen("tcp", address)
   131  		if err == nil {
   132  			p.pinnedPort = port
   133  			return listener, address, nil
   134  		}
   135  	}
   136  	return listener, address, errors.New("failed to bind to port in range")
   137  }
   138  
   139  var errAlreadyRunning = errors.New("http server already running")
   140  
   141  // Srv starts a simple HTTP server with a parameter for a module to provide a listener source
   142  type Srv struct {
   143  	sync.Mutex
   144  	*http.ServeMux
   145  	log logger.Logger
   146  
   147  	listenerSource ListenerSource
   148  	server         *http.Server
   149  	doneCh         chan struct{}
   150  }
   151  
   152  // NewSrv creates a new HTTP server with the given listener
   153  // source.
   154  func NewSrv(log logger.Logger, listenerSource ListenerSource) *Srv {
   155  	return &Srv{
   156  		log:            log,
   157  		listenerSource: listenerSource,
   158  	}
   159  }
   160  
   161  // Start starts listening on the server's listener source.
   162  func (h *Srv) Start() (err error) {
   163  	h.Lock()
   164  	defer h.Unlock()
   165  	if h.server != nil {
   166  		h.log.Debug("kbhttp.Srv: already running, not starting again")
   167  		// Just bail out of this if we are already running
   168  		return errAlreadyRunning
   169  	}
   170  	h.ServeMux = http.NewServeMux()
   171  	listener, address, err := h.listenerSource.GetListener()
   172  	if err != nil {
   173  		h.log.Debug("kbhttp.Srv: failed to get a listener: %s", err)
   174  		return err
   175  	}
   176  	h.server = &http.Server{
   177  		Addr:    address,
   178  		Handler: h.ServeMux,
   179  	}
   180  	h.doneCh = make(chan struct{})
   181  	go func(server *http.Server, doneCh chan struct{}) {
   182  		h.log.Debug("kbhttp.Srv: server starting on: %s", address)
   183  		if err := server.Serve(listener); err != nil {
   184  			h.log.Debug("kbhttp.Srv: server died: %s", err)
   185  		}
   186  		close(doneCh)
   187  	}(h.server, h.doneCh)
   188  	return nil
   189  }
   190  
   191  // Active returns true if the server is active.
   192  func (h *Srv) Active() bool {
   193  	h.Lock()
   194  	defer h.Unlock()
   195  	return h.server != nil
   196  }
   197  
   198  // Addr returns the server's address, if it's running.
   199  func (h *Srv) Addr() (string, error) {
   200  	h.Lock()
   201  	defer h.Unlock()
   202  	if h.server != nil {
   203  		return h.server.Addr, nil
   204  	}
   205  	return "", errors.New("server not running")
   206  }
   207  
   208  // Stop stops listening on the server's listener source.
   209  func (h *Srv) Stop() <-chan struct{} {
   210  	h.Lock()
   211  	defer h.Unlock()
   212  	if h.server != nil {
   213  		h.server.Close()
   214  		h.server = nil
   215  		return h.doneCh
   216  	}
   217  	doneCh := make(chan struct{})
   218  	close(doneCh)
   219  	return doneCh
   220  }