github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbhttp/srv.go (about) 1 // Copyright 2018 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kbhttp 5 6 import ( 7 "crypto/rand" 8 "errors" 9 "fmt" 10 "math/big" 11 "net" 12 "net/http" 13 "sync" 14 15 "github.com/keybase/client/go/logger" 16 ) 17 18 // ListenerSource represents where an HTTP server should listen. 19 type ListenerSource interface { 20 GetListener() (net.Listener, string, error) 21 } 22 23 // AutoPortListenerSource means listen on a port that's picked automatically by 24 // the kernel. 25 type AutoPortListenerSource struct{} 26 27 // GetListener implements ListenerSource. 28 func (r AutoPortListenerSource) GetListener() (net.Listener, string, error) { 29 localhost := "127.0.0.1" 30 listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", localhost)) 31 if err != nil { 32 return nil, "", err 33 } 34 port := listener.Addr().(*net.TCPAddr).Port 35 address := fmt.Sprintf("%s:%d", localhost, port) 36 return listener, address, nil 37 } 38 39 // NewAutoPortListenerSource creates a new AutoPortListenerSource. 40 func NewAutoPortListenerSource() *AutoPortListenerSource { 41 return &AutoPortListenerSource{} 42 } 43 44 var ErrPinnedPortInUse = errors.New("unable to bind to pinned port") 45 46 // PortRangeListenerSource means listen on the given range. 47 type PortRangeListenerSource struct { 48 sync.Mutex 49 pinnedPort int 50 low, high int 51 } 52 53 // NewPortRangeListenerSource creates a new PortListenerSource 54 // listening on low to high (inclusive). 55 func NewPortRangeListenerSource(low, high int) *PortRangeListenerSource { 56 return &PortRangeListenerSource{ 57 low: low, 58 high: high, 59 } 60 } 61 62 // NewFixedPortListenerSource creates a new PortListenerSource 63 // listening on the given port. 64 func NewFixedPortListenerSource(port int) *PortRangeListenerSource { 65 return NewPortRangeListenerSource(port, port) 66 } 67 68 // GetListener implements ListenerSource. 69 func (p *PortRangeListenerSource) GetListener() (listener net.Listener, address string, err error) { 70 p.Lock() 71 defer p.Unlock() 72 localhost := "127.0.0.1" 73 if p.pinnedPort > 0 { 74 address = fmt.Sprintf("%s:%d", localhost, p.pinnedPort) 75 if listener, err = net.Listen("tcp", address); err != nil { 76 return listener, address, ErrPinnedPortInUse 77 } 78 return listener, address, nil 79 } 80 for port := p.low; port <= p.high; port++ { 81 address = fmt.Sprintf("%s:%d", localhost, port) 82 listener, err = net.Listen("tcp", address) 83 if err == nil { 84 p.pinnedPort = port 85 return listener, address, nil 86 } 87 } 88 return listener, address, errors.New("failed to bind to port in range") 89 } 90 91 // RandomPortRangeListenerSource listens on a port randomly chosen within a 92 // given range. 93 type RandomPortRangeListenerSource struct { 94 sync.Mutex 95 pinnedPort int 96 low, high int 97 } 98 99 // NewRandomPortRangeListenerSource creates a new RadomPortListenerSource 100 // listening on low to high (exclusive). 101 func NewRandomPortRangeListenerSource(low, high int) *RandomPortRangeListenerSource { 102 return &RandomPortRangeListenerSource{ 103 low: low, 104 high: high, 105 } 106 } 107 108 const maxRandomTries = 10 109 110 // GetListener implements ListenerSource. 111 func (p *RandomPortRangeListenerSource) GetListener() (listener net.Listener, address string, err error) { 112 p.Lock() 113 defer p.Unlock() 114 localhost := "127.0.0.1" 115 for i := 0; i < maxRandomTries; i++ { 116 if p.pinnedPort > 0 { 117 address = fmt.Sprintf("%s:%d", localhost, p.pinnedPort) 118 if listener, err = net.Listen("tcp", address); err != nil { 119 return listener, address, ErrPinnedPortInUse 120 } 121 return listener, address, nil 122 } 123 124 n, err := rand.Int(rand.Reader, big.NewInt(int64(p.high-p.low))) 125 if err != nil { 126 return nil, "", err 127 } 128 port := p.low + int(n.Int64()) 129 address = fmt.Sprintf("%s:%d", localhost, port) 130 listener, err = net.Listen("tcp", address) 131 if err == nil { 132 p.pinnedPort = port 133 return listener, address, nil 134 } 135 } 136 return listener, address, errors.New("failed to bind to port in range") 137 } 138 139 var errAlreadyRunning = errors.New("http server already running") 140 141 // Srv starts a simple HTTP server with a parameter for a module to provide a listener source 142 type Srv struct { 143 sync.Mutex 144 *http.ServeMux 145 log logger.Logger 146 147 listenerSource ListenerSource 148 server *http.Server 149 doneCh chan struct{} 150 } 151 152 // NewSrv creates a new HTTP server with the given listener 153 // source. 154 func NewSrv(log logger.Logger, listenerSource ListenerSource) *Srv { 155 return &Srv{ 156 log: log, 157 listenerSource: listenerSource, 158 } 159 } 160 161 // Start starts listening on the server's listener source. 162 func (h *Srv) Start() (err error) { 163 h.Lock() 164 defer h.Unlock() 165 if h.server != nil { 166 h.log.Debug("kbhttp.Srv: already running, not starting again") 167 // Just bail out of this if we are already running 168 return errAlreadyRunning 169 } 170 h.ServeMux = http.NewServeMux() 171 listener, address, err := h.listenerSource.GetListener() 172 if err != nil { 173 h.log.Debug("kbhttp.Srv: failed to get a listener: %s", err) 174 return err 175 } 176 h.server = &http.Server{ 177 Addr: address, 178 Handler: h.ServeMux, 179 } 180 h.doneCh = make(chan struct{}) 181 go func(server *http.Server, doneCh chan struct{}) { 182 h.log.Debug("kbhttp.Srv: server starting on: %s", address) 183 if err := server.Serve(listener); err != nil { 184 h.log.Debug("kbhttp.Srv: server died: %s", err) 185 } 186 close(doneCh) 187 }(h.server, h.doneCh) 188 return nil 189 } 190 191 // Active returns true if the server is active. 192 func (h *Srv) Active() bool { 193 h.Lock() 194 defer h.Unlock() 195 return h.server != nil 196 } 197 198 // Addr returns the server's address, if it's running. 199 func (h *Srv) Addr() (string, error) { 200 h.Lock() 201 defer h.Unlock() 202 if h.server != nil { 203 return h.server.Addr, nil 204 } 205 return "", errors.New("server not running") 206 } 207 208 // Stop stops listening on the server's listener source. 209 func (h *Srv) Stop() <-chan struct{} { 210 h.Lock() 211 defer h.Unlock() 212 if h.server != nil { 213 h.server.Close() 214 h.server = nil 215 return h.doneCh 216 } 217 doneCh := make(chan struct{}) 218 close(doneCh) 219 return doneCh 220 }