github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/protomux/protomux.go (about) 1 package protomux 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "sync" 8 "time" 9 10 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common" 11 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn" 12 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry" 13 "go.uber.org/zap" 14 ) 15 16 // ProtoListener is 17 type ProtoListener struct { 18 net.Listener 19 connection chan net.Conn 20 mark int 21 } 22 23 // NewProtoListener creates a listener for a particular protocol. 24 func NewProtoListener(mark int) *ProtoListener { 25 return &ProtoListener{ 26 connection: make(chan net.Conn), 27 mark: mark, 28 } 29 } 30 31 // Accept accepts new connections over the channel. 32 func (p *ProtoListener) Accept() (net.Conn, error) { 33 c, ok := <-p.connection 34 if !ok { 35 return nil, fmt.Errorf("mux: listener closed") 36 } 37 return c, nil 38 } 39 40 // MultiplexedListener is the root listener that will split 41 // connections to different protocols. 42 type MultiplexedListener struct { 43 root net.Listener 44 done chan struct{} 45 shutdown chan struct{} 46 wg sync.WaitGroup 47 protomap map[common.ListenerType]*ProtoListener 48 puID string 49 50 defaultListener *ProtoListener 51 localIPs map[string]struct{} 52 mark int 53 sync.RWMutex 54 } 55 56 // NewMultiplexedListener returns a new multiplexed listener. Caller 57 // must register protocols outside of the new object creation. 58 func NewMultiplexedListener(l net.Listener, mark int, puID string) *MultiplexedListener { 59 60 return &MultiplexedListener{ 61 root: l, 62 done: make(chan struct{}), 63 shutdown: make(chan struct{}), 64 wg: sync.WaitGroup{}, 65 protomap: map[common.ListenerType]*ProtoListener{}, 66 localIPs: markedconn.GetInterfaces(), 67 mark: mark, 68 puID: puID, 69 } 70 } 71 72 // RegisterListener registers a new listener. It returns the listener that the various 73 // protocol servers should use. If defaultListener is set, this will become 74 // the default listener if no match is found. Obviously, there cannot be more 75 // than one default. 76 func (m *MultiplexedListener) RegisterListener(ltype common.ListenerType) (*ProtoListener, error) { 77 m.Lock() 78 defer m.Unlock() 79 80 if _, ok := m.protomap[ltype]; ok { 81 return nil, fmt.Errorf("Cannot register same listener type multiple times") 82 } 83 84 p := &ProtoListener{ 85 Listener: m.root, 86 connection: make(chan net.Conn), 87 mark: m.mark, 88 } 89 m.protomap[ltype] = p 90 91 return p, nil 92 } 93 94 // UnregisterListener unregisters a listener. It returns an error if there are services 95 // associated with this listener. 96 func (m *MultiplexedListener) UnregisterListener(ltype common.ListenerType) error { 97 m.Lock() 98 defer m.Unlock() 99 100 delete(m.protomap, ltype) 101 102 return nil 103 } 104 105 // RegisterDefaultListener registers a default listener. 106 func (m *MultiplexedListener) RegisterDefaultListener(p *ProtoListener) error { 107 m.Lock() 108 defer m.Unlock() 109 110 if m.defaultListener != nil { 111 return fmt.Errorf("Default listener already registered") 112 } 113 114 m.defaultListener = p 115 return nil 116 } 117 118 // UnregisterDefaultListener unregisters the default listener. 119 func (m *MultiplexedListener) UnregisterDefaultListener() error { 120 m.Lock() 121 defer m.Unlock() 122 123 if m.defaultListener == nil { 124 return fmt.Errorf("No default listener registered") 125 } 126 127 m.defaultListener = nil 128 129 return nil 130 } 131 132 // Close terminates the server without the context. 133 func (m *MultiplexedListener) Close() { 134 close(m.shutdown) 135 } 136 137 // Serve will demux the connections 138 func (m *MultiplexedListener) Serve(ctx context.Context) error { 139 140 defer func() { 141 close(m.done) 142 m.wg.Wait() 143 144 m.RLock() 145 defer m.RUnlock() 146 147 for _, l := range m.protomap { 148 close(l.connection) 149 // Drain the connections enqueued for the listener. 150 for c := range l.connection { 151 c.Close() // nolint 152 } 153 } 154 }() 155 156 go func() { 157 for { 158 select { 159 case <-time.After(5 * time.Second): 160 m.Lock() 161 m.localIPs = markedconn.GetInterfaces() 162 m.Unlock() 163 case <-ctx.Done(): 164 return 165 } 166 } 167 }() 168 169 for { 170 select { 171 case <-ctx.Done(): 172 return nil 173 case <-m.shutdown: 174 return nil 175 default: 176 177 c, err := m.root.Accept() 178 if err != nil { 179 // check if the error is due to shutdown in progress 180 select { 181 case <-ctx.Done(): 182 return nil 183 case <-m.shutdown: 184 return nil 185 default: 186 } 187 // if it is an actual error (which can happen in Windows we can't get origin ip/port from our driver), 188 // then log an error and continue accepting connections. 189 zap.L().Error("error from Accept", zap.Error(err)) 190 break 191 } 192 m.wg.Add(1) 193 go m.serve(c) 194 } 195 } 196 } 197 198 func (m *MultiplexedListener) serve(conn net.Conn) { 199 defer m.wg.Done() 200 201 c, ok := conn.(*markedconn.ProxiedConnection) 202 if !ok { 203 zap.L().Error("Wrong connection type") 204 return 205 } 206 207 ip, port := c.GetOriginalDestination() 208 remoteAddr := c.RemoteAddr() 209 if remoteAddr == nil { 210 zap.L().Error("Connection remote address cannot be found. Abort") 211 return 212 } 213 214 local := false 215 m.Lock() 216 localIPs := m.localIPs 217 m.Unlock() 218 if _, ok = localIPs[networkOfAddress(remoteAddr.String())]; ok { 219 local = true 220 } 221 222 var listenerType common.ListenerType 223 if local { 224 _, serviceData, err := serviceregistry.Instance().RetrieveDependentServiceDataByIDAndNetwork(m.puID, ip, port, "") 225 if err != nil { 226 zap.L().Error("Cannot discover target service", 227 zap.String("ContextID", m.puID), 228 zap.String("ip", ip.String()), 229 zap.Int("port", port), 230 zap.String("Remote IP", remoteAddr.String()), 231 zap.Error(err), 232 ) 233 return 234 } 235 listenerType = serviceData.ServiceType 236 } else { 237 pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "") 238 if err != nil { 239 zap.L().Error("Cannot discover target service", 240 zap.String("ip", ip.String()), 241 zap.Int("port", port), 242 zap.String("Remote IP", remoteAddr.String()), 243 ) 244 return 245 } 246 247 listenerType = pctx.Type 248 } 249 250 m.RLock() 251 target, ok := m.protomap[listenerType] 252 m.RUnlock() 253 if !ok { 254 c.Close() // nolint 255 return 256 } 257 258 select { 259 case target.connection <- c: 260 case <-m.done: 261 c.Close() // nolint 262 } 263 } 264 265 func networkOfAddress(addr string) string { 266 ip, _, err := net.SplitHostPort(addr) 267 if err != nil { 268 return addr 269 } 270 271 return ip 272 }