github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/protomux/protomux.go (about)

     1  package protomux
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    11  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn"
    12  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  // ProtoListener is
    17  type ProtoListener struct {
    18  	net.Listener
    19  	connection chan net.Conn
    20  	mark       int
    21  }
    22  
    23  // NewProtoListener creates a listener for a particular protocol.
    24  func NewProtoListener(mark int) *ProtoListener {
    25  	return &ProtoListener{
    26  		connection: make(chan net.Conn),
    27  		mark:       mark,
    28  	}
    29  }
    30  
    31  // Accept accepts new connections over the channel.
    32  func (p *ProtoListener) Accept() (net.Conn, error) {
    33  	c, ok := <-p.connection
    34  	if !ok {
    35  		return nil, fmt.Errorf("mux: listener closed")
    36  	}
    37  	return c, nil
    38  }
    39  
    40  // MultiplexedListener is the root listener that will split
    41  // connections to different protocols.
    42  type MultiplexedListener struct {
    43  	root     net.Listener
    44  	done     chan struct{}
    45  	shutdown chan struct{}
    46  	wg       sync.WaitGroup
    47  	protomap map[common.ListenerType]*ProtoListener
    48  	puID     string
    49  
    50  	defaultListener *ProtoListener
    51  	localIPs        map[string]struct{}
    52  	mark            int
    53  	sync.RWMutex
    54  }
    55  
    56  // NewMultiplexedListener returns a new multiplexed listener. Caller
    57  // must register protocols outside of the new object creation.
    58  func NewMultiplexedListener(l net.Listener, mark int, puID string) *MultiplexedListener {
    59  
    60  	return &MultiplexedListener{
    61  		root:     l,
    62  		done:     make(chan struct{}),
    63  		shutdown: make(chan struct{}),
    64  		wg:       sync.WaitGroup{},
    65  		protomap: map[common.ListenerType]*ProtoListener{},
    66  		localIPs: markedconn.GetInterfaces(),
    67  		mark:     mark,
    68  		puID:     puID,
    69  	}
    70  }
    71  
    72  // RegisterListener registers a new listener. It returns the listener that the various
    73  // protocol servers should use. If defaultListener is set, this will become
    74  // the default listener if no match is found. Obviously, there cannot be more
    75  // than one default.
    76  func (m *MultiplexedListener) RegisterListener(ltype common.ListenerType) (*ProtoListener, error) {
    77  	m.Lock()
    78  	defer m.Unlock()
    79  
    80  	if _, ok := m.protomap[ltype]; ok {
    81  		return nil, fmt.Errorf("Cannot register same listener type multiple times")
    82  	}
    83  
    84  	p := &ProtoListener{
    85  		Listener:   m.root,
    86  		connection: make(chan net.Conn),
    87  		mark:       m.mark,
    88  	}
    89  	m.protomap[ltype] = p
    90  
    91  	return p, nil
    92  }
    93  
    94  // UnregisterListener unregisters a listener. It returns an error if there are services
    95  // associated with this listener.
    96  func (m *MultiplexedListener) UnregisterListener(ltype common.ListenerType) error {
    97  	m.Lock()
    98  	defer m.Unlock()
    99  
   100  	delete(m.protomap, ltype)
   101  
   102  	return nil
   103  }
   104  
   105  // RegisterDefaultListener registers a default listener.
   106  func (m *MultiplexedListener) RegisterDefaultListener(p *ProtoListener) error {
   107  	m.Lock()
   108  	defer m.Unlock()
   109  
   110  	if m.defaultListener != nil {
   111  		return fmt.Errorf("Default listener already registered")
   112  	}
   113  
   114  	m.defaultListener = p
   115  	return nil
   116  }
   117  
   118  // UnregisterDefaultListener unregisters the default listener.
   119  func (m *MultiplexedListener) UnregisterDefaultListener() error {
   120  	m.Lock()
   121  	defer m.Unlock()
   122  
   123  	if m.defaultListener == nil {
   124  		return fmt.Errorf("No default listener registered")
   125  	}
   126  
   127  	m.defaultListener = nil
   128  
   129  	return nil
   130  }
   131  
   132  // Close terminates the server without the context.
   133  func (m *MultiplexedListener) Close() {
   134  	close(m.shutdown)
   135  }
   136  
   137  // Serve will demux the connections
   138  func (m *MultiplexedListener) Serve(ctx context.Context) error {
   139  
   140  	defer func() {
   141  		close(m.done)
   142  		m.wg.Wait()
   143  
   144  		m.RLock()
   145  		defer m.RUnlock()
   146  
   147  		for _, l := range m.protomap {
   148  			close(l.connection)
   149  			// Drain the connections enqueued for the listener.
   150  			for c := range l.connection {
   151  				c.Close() // nolint
   152  			}
   153  		}
   154  	}()
   155  
   156  	go func() {
   157  		for {
   158  			select {
   159  			case <-time.After(5 * time.Second):
   160  				m.Lock()
   161  				m.localIPs = markedconn.GetInterfaces()
   162  				m.Unlock()
   163  			case <-ctx.Done():
   164  				return
   165  			}
   166  		}
   167  	}()
   168  
   169  	for {
   170  		select {
   171  		case <-ctx.Done():
   172  			return nil
   173  		case <-m.shutdown:
   174  			return nil
   175  		default:
   176  
   177  			c, err := m.root.Accept()
   178  			if err != nil {
   179  				// check if the error is due to shutdown in progress
   180  				select {
   181  				case <-ctx.Done():
   182  					return nil
   183  				case <-m.shutdown:
   184  					return nil
   185  				default:
   186  				}
   187  				// if it is an actual error (which can happen in Windows we can't get origin ip/port from our driver),
   188  				// then log an error and continue accepting connections.
   189  				zap.L().Error("error from Accept", zap.Error(err))
   190  				break
   191  			}
   192  			m.wg.Add(1)
   193  			go m.serve(c)
   194  		}
   195  	}
   196  }
   197  
   198  func (m *MultiplexedListener) serve(conn net.Conn) {
   199  	defer m.wg.Done()
   200  
   201  	c, ok := conn.(*markedconn.ProxiedConnection)
   202  	if !ok {
   203  		zap.L().Error("Wrong connection type")
   204  		return
   205  	}
   206  
   207  	ip, port := c.GetOriginalDestination()
   208  	remoteAddr := c.RemoteAddr()
   209  	if remoteAddr == nil {
   210  		zap.L().Error("Connection remote address cannot be found. Abort")
   211  		return
   212  	}
   213  
   214  	local := false
   215  	m.Lock()
   216  	localIPs := m.localIPs
   217  	m.Unlock()
   218  	if _, ok = localIPs[networkOfAddress(remoteAddr.String())]; ok {
   219  		local = true
   220  	}
   221  
   222  	var listenerType common.ListenerType
   223  	if local {
   224  		_, serviceData, err := serviceregistry.Instance().RetrieveDependentServiceDataByIDAndNetwork(m.puID, ip, port, "")
   225  		if err != nil {
   226  			zap.L().Error("Cannot discover target service",
   227  				zap.String("ContextID", m.puID),
   228  				zap.String("ip", ip.String()),
   229  				zap.Int("port", port),
   230  				zap.String("Remote IP", remoteAddr.String()),
   231  				zap.Error(err),
   232  			)
   233  			return
   234  		}
   235  		listenerType = serviceData.ServiceType
   236  	} else {
   237  		pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "")
   238  		if err != nil {
   239  			zap.L().Error("Cannot discover target service",
   240  				zap.String("ip", ip.String()),
   241  				zap.Int("port", port),
   242  				zap.String("Remote IP", remoteAddr.String()),
   243  			)
   244  			return
   245  		}
   246  
   247  		listenerType = pctx.Type
   248  	}
   249  
   250  	m.RLock()
   251  	target, ok := m.protomap[listenerType]
   252  	m.RUnlock()
   253  	if !ok {
   254  		c.Close() // nolint
   255  		return
   256  	}
   257  
   258  	select {
   259  	case target.connection <- c:
   260  	case <-m.done:
   261  		c.Close() // nolint
   262  	}
   263  }
   264  
   265  func networkOfAddress(addr string) string {
   266  	ip, _, err := net.SplitHostPort(addr)
   267  	if err != nil {
   268  		return addr
   269  	}
   270  
   271  	return ip
   272  }