github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/lanz/client.go (about)

     1  // Copyright (c) 2016 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  // Package lanz implements a LANZ client that listens for notifications from a LANZ streaming
     6  // server, decodes them, and sends them as protobufs over a channel to a receiver.
     7  package lanz
     8  
     9  import (
    10  	"bufio"
    11  	"encoding/binary"
    12  	"io"
    13  	"net"
    14  	"sync"
    15  	"time"
    16  
    17  	pb "github.com/aristanetworks/goarista/lanz/proto"
    18  	"github.com/aristanetworks/goarista/logger"
    19  
    20  	"google.golang.org/protobuf/proto"
    21  )
    22  
    23  const (
    24  	defaultConnectTimeout = 10 * time.Second
    25  	defaultConnectBackoff = 30 * time.Second
    26  )
    27  
    28  // Client is the LANZ client interface.
    29  type Client interface {
    30  	// Run is the main loop of the client.
    31  	// It connects to the LANZ server and reads the notifications, decodes them
    32  	// and sends them to the channel.
    33  	// In case of disconnect, it will reconnect automatically.
    34  	Run(ch chan<- *pb.LanzRecord)
    35  	// Stops the client.
    36  	Stop()
    37  }
    38  
    39  // ConnectReadCloser extends the io.ReadCloser interface with a Connect method.
    40  type ConnectReadCloser interface {
    41  	io.ReadCloser
    42  	// Connect connects to the address, returning an error if it fails.
    43  	Connect() error
    44  }
    45  
    46  type client struct {
    47  	sync.Mutex
    48  	addr      string
    49  	stop      chan struct{}
    50  	connected bool
    51  	timeout   time.Duration
    52  	backoff   time.Duration
    53  	conn      ConnectReadCloser
    54  	log       logger.Logger
    55  }
    56  
    57  // New creates a new client with default TCP connection to the LANZ server.
    58  func New(opts ...Option) Client {
    59  	c := &client{
    60  		stop:    make(chan struct{}),
    61  		timeout: defaultConnectTimeout,
    62  		backoff: defaultConnectBackoff,
    63  		log:     logger.Std,
    64  	}
    65  
    66  	for _, opt := range opts {
    67  		opt(c)
    68  	}
    69  
    70  	if c.conn == nil {
    71  		if c.addr == "" {
    72  			panic("Neither address, nor connector specified")
    73  		}
    74  		c.conn = &netConnector{
    75  			addr:    c.addr,
    76  			timeout: c.timeout,
    77  			backoff: c.backoff,
    78  		}
    79  	}
    80  
    81  	return c
    82  }
    83  
    84  func (c *client) setConnected(connected bool) {
    85  	c.Lock()
    86  	defer c.Unlock()
    87  	if c.connected && !connected {
    88  		c.conn.Close()
    89  	}
    90  	c.connected = connected
    91  }
    92  
    93  func (c *client) Run(ch chan<- *pb.LanzRecord) {
    94  	go func() {
    95  		<-c.stop
    96  		c.setConnected(false)
    97  	}()
    98  
    99  	defer func() {
   100  		close(ch)
   101  		// This is to handle a race when the connection is
   102  		// established, but not marked as connected yet and is then
   103  		// preempted with c.stop closing.
   104  		c.setConnected(false)
   105  	}()
   106  
   107  	for {
   108  		select {
   109  		case <-c.stop:
   110  			return
   111  		default:
   112  			if err := c.conn.Connect(); err != nil {
   113  				select {
   114  				case <-c.stop:
   115  					return
   116  				default:
   117  					time.Sleep(c.backoff)
   118  					continue
   119  				}
   120  			}
   121  			c.setConnected(true)
   122  			if err := c.read(bufio.NewReader(c.conn), ch); err != nil {
   123  				select {
   124  				case <-c.stop:
   125  					return
   126  				default:
   127  					if err != io.EOF && err != io.ErrUnexpectedEOF {
   128  						c.log.Errorf("Error receiving LANZ events: %v", err)
   129  					}
   130  					c.setConnected(false)
   131  					time.Sleep(c.backoff)
   132  				}
   133  			}
   134  		}
   135  	}
   136  
   137  }
   138  
   139  func (c *client) read(r *bufio.Reader, ch chan<- *pb.LanzRecord) error {
   140  	for {
   141  		select {
   142  		case <-c.stop:
   143  			return nil
   144  		default:
   145  			len, err := binary.ReadUvarint(r)
   146  			if err != nil {
   147  				return err
   148  			}
   149  
   150  			buf := make([]byte, len)
   151  			if _, err = io.ReadFull(r, buf); err != nil {
   152  				return err
   153  			}
   154  
   155  			rec := &pb.LanzRecord{}
   156  			if err = proto.Unmarshal(buf, rec); err != nil {
   157  				return err
   158  			}
   159  
   160  			ch <- rec
   161  		}
   162  	}
   163  }
   164  
   165  func (c *client) Stop() {
   166  	close(c.stop)
   167  }
   168  
   169  type netConnector struct {
   170  	net.Conn
   171  	addr    string
   172  	timeout time.Duration
   173  	backoff time.Duration
   174  }
   175  
   176  func (c *netConnector) Connect() (err error) {
   177  	c.Conn, err = net.DialTimeout("tcp", c.addr, c.timeout)
   178  	if err != nil {
   179  	}
   180  	return
   181  }