
     1  package utils
     3  import (
     4  	"context"
     5  	"os"
     6  	"strings"
     7  	"sync"
     8  	"time"
    10  	""
    11  	""
    12  	""
    13  	""
    14  	""
    15  )
// EndpointPusher pushes endpoints to registered channels if the ep is L3 reachable.
//
// Reachability is checked with a single ICMP ping to the host part of a
// "host:port" endpoint. Every endpoint passed to Push that is not already
// tracked is put into pendingEndpoints, keyed to the cancel func of the
// goroutine probing it. Once a ping succeeds, the endpoint moves to
// availableEndpoints and the full list of available endpoints is sent to every
// channel in chans (see Register). Endpoints missing from a later Push are
// removed from both sets, and their pending probes are canceled; removing an
// available endpoint does not by itself send anything to the channels.
//
// The embedded sync.Mutex makes Lock/Unlock part of this type's method set. In
// this file it is held while delOutdated prunes the maps and while a probe
// goroutine promotes an endpoint and pushes the list.
type EndpointPusher struct {
	sync.Mutex
	chans              []chan []string
	pendingEndpoints   *haxmap.Map[string, context.CancelFunc]
	availableEndpoints *haxmap.Map[string, struct{}]
}
    25  // NewEndpointPusher .
    26  func NewEndpointPusher() *EndpointPusher {
    27  	return &EndpointPusher{
    28  		pendingEndpoints:   haxmap.New[string, context.CancelFunc](),
    29  		availableEndpoints: haxmap.New[string, struct{}](),
    30  	}
    31  }
    33  // Register registers a channel that will receive the endpoints later
    34  func (p *EndpointPusher) Register(ch chan []string) {
    35  	p.chans = append(p.chans, ch)
    36  }
    38  // Push pushes endpoint candicates
    39  func (p *EndpointPusher) Push(ctx context.Context, endpoints []string) {
    40  	p.delOutdated(ctx, endpoints)
    41  	p.addCheck(ctx, endpoints)
    42  }
    44  func (p *EndpointPusher) delOutdated(ctx context.Context, endpoints []string) {
    45  	p.Lock()
    46  	defer p.Unlock()
    47  	logger := log.WithFunc("utils.EndpointPusher.delOutdated")
    48  	p.pendingEndpoints.ForEach(func(endpoint string, cancel context.CancelFunc) bool {
    49  		if !slices.Contains(endpoints, endpoint) {
    50  			cancel()
    51  			p.pendingEndpoints.Del(endpoint)
    52  			logger.Debugf(ctx, "pending endpoint deleted: %s", endpoint)
    53  		}
    54  		return true
    55  	})
    57  	p.availableEndpoints.ForEach(func(endpoint string, _ struct{}) bool {
    58  		if !slices.Contains(endpoints, endpoint) {
    59  			p.availableEndpoints.Del(endpoint)
    60  			logger.Debugf(ctx, "available endpoint deleted: %s", endpoint)
    61  		}
    62  		return true
    63  	})
    64  }
    66  func (p *EndpointPusher) addCheck(ctx context.Context, endpoints []string) {
    67  	for _, endpoint := range endpoints {
    68  		if _, ok := p.pendingEndpoints.Get(endpoint); ok {
    69  			continue
    70  		}
    71  		if _, ok := p.availableEndpoints.Get(endpoint); ok {
    72  			continue
    73  		}
    75  		ctx, cancel := context.WithCancel(ctx)
    76  		p.pendingEndpoints.Set(endpoint, cancel)
    77  		go p.pollReachability(ctx, endpoint)
    78  		log.WithFunc("utils.EndpointPusher.addCheck").Debugf(ctx, "pending endpoint added: %s", endpoint)
    79  	}
    80  }
    82  func (p *EndpointPusher) pollReachability(ctx context.Context, endpoint string) {
    83  	logger := log.WithFunc("utils.EndpointPusher.pollReachability")
    84  	parts := strings.Split(endpoint, ":")
    85  	if len(parts) != 2 {
    86  		logger.Errorf(ctx, types.ErrInvaildCoreEndpointType, "wrong format of endpoint: %s", endpoint)
    87  		return
    88  	}
    90  	ticker := time.NewTicker(time.Second) // TODO config from outside?
    91  	defer ticker.Stop()
    92  	for {
    93  		select {
    94  		case <-ctx.Done():
    95  			logger.Debugf(ctx, "reachability goroutine ends: %s", endpoint)
    96  			return
    97  		case <-ticker.C:
    98  			p.Lock()
    99  			defer p.Unlock()
   100  			if err := p.checkReachability(ctx, parts[0]); err != nil {
   101  				continue
   102  			}
   103  			p.pendingEndpoints.Del(endpoint)
   104  			p.availableEndpoints.Set(endpoint, struct{}{})
   105  			p.pushEndpoints()
   106  			logger.Debugf(ctx, "available endpoint added: %s", endpoint)
   107  			return
   108  		}
   109  	}
   110  }
   112  func (p *EndpointPusher) checkReachability(ctx context.Context, host string) (err error) {
   113  	pinger, err := ping.NewPinger(host)
   114  	if err != nil {
   115  		log.WithFunc("utils.EndpointPusher.checkReachability").Error(ctx, err, "failed to create pinger")
   116  		return
   117  	}
   118  	pinger.SetPrivileged(os.Getuid() == 0)
   119  	defer pinger.Stop()
   121  	pinger.Count = 1
   122  	pinger.Timeout = time.Second
   123  	if err = pinger.Run(); err != nil {
   124  		return
   125  	}
   126  	if pinger.Statistics().PacketsRecv != 1 {
   127  		return types.ErrICMPLost
   128  	}
   129  	return
   130  }
   132  func (p *EndpointPusher) pushEndpoints() {
   133  	endpoints := []string{}
   134  	p.availableEndpoints.ForEach(func(endpoint string, _ struct{}) bool {
   135  		endpoints = append(endpoints, endpoint)
   136  		return true
   137  	})
   138  	for _, ch := range p.chans {
   139  		ch <- endpoints
   140  	}
   141  }