github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/client/utils/servicepusher.go (about) 1 package utils 2 3 import ( 4 "context" 5 "os" 6 "strings" 7 "sync" 8 "time" 9 10 "github.com/alphadose/haxmap" 11 "github.com/go-ping/ping" 12 "github.com/projecteru2/core/log" 13 "github.com/projecteru2/core/types" 14 "golang.org/x/exp/slices" 15 ) 16 17 // EndpointPusher pushes endpoints to registered channels if the ep is L3 reachable 18 type EndpointPusher struct { 19 sync.Mutex 20 chans []chan []string 21 pendingEndpoints *haxmap.Map[string, context.CancelFunc] 22 availableEndpoints *haxmap.Map[string, struct{}] 23 } 24 25 // NewEndpointPusher . 26 func NewEndpointPusher() *EndpointPusher { 27 return &EndpointPusher{ 28 pendingEndpoints: haxmap.New[string, context.CancelFunc](), 29 availableEndpoints: haxmap.New[string, struct{}](), 30 } 31 } 32 33 // Register registers a channel that will receive the endpoints later 34 func (p *EndpointPusher) Register(ch chan []string) { 35 p.chans = append(p.chans, ch) 36 } 37 38 // Push pushes endpoint candicates 39 func (p *EndpointPusher) Push(ctx context.Context, endpoints []string) { 40 p.delOutdated(ctx, endpoints) 41 p.addCheck(ctx, endpoints) 42 } 43 44 func (p *EndpointPusher) delOutdated(ctx context.Context, endpoints []string) { 45 p.Lock() 46 defer p.Unlock() 47 logger := log.WithFunc("utils.EndpointPusher.delOutdated") 48 p.pendingEndpoints.ForEach(func(endpoint string, cancel context.CancelFunc) bool { 49 if !slices.Contains(endpoints, endpoint) { 50 cancel() 51 p.pendingEndpoints.Del(endpoint) 52 logger.Debugf(ctx, "pending endpoint deleted: %s", endpoint) 53 } 54 return true 55 }) 56 57 p.availableEndpoints.ForEach(func(endpoint string, _ struct{}) bool { 58 if !slices.Contains(endpoints, endpoint) { 59 p.availableEndpoints.Del(endpoint) 60 logger.Debugf(ctx, "available endpoint deleted: %s", endpoint) 61 } 62 return true 63 }) 64 } 65 66 func (p *EndpointPusher) addCheck(ctx context.Context, endpoints []string) { 67 for _, endpoint := range endpoints { 68 if _, ok := p.pendingEndpoints.Get(endpoint); ok { 69 continue 70 } 71 if _, ok := p.availableEndpoints.Get(endpoint); ok { 72 continue 73 } 74 75 ctx, cancel := context.WithCancel(ctx) 76 p.pendingEndpoints.Set(endpoint, cancel) 77 go p.pollReachability(ctx, endpoint) 78 log.WithFunc("utils.EndpointPusher.addCheck").Debugf(ctx, "pending endpoint added: %s", endpoint) 79 } 80 } 81 82 func (p *EndpointPusher) pollReachability(ctx context.Context, endpoint string) { 83 logger := log.WithFunc("utils.EndpointPusher.pollReachability") 84 parts := strings.Split(endpoint, ":") 85 if len(parts) != 2 { 86 logger.Errorf(ctx, types.ErrInvaildCoreEndpointType, "wrong format of endpoint: %s", endpoint) 87 return 88 } 89 90 ticker := time.NewTicker(time.Second) // TODO config from outside? 91 defer ticker.Stop() 92 for { 93 select { 94 case <-ctx.Done(): 95 logger.Debugf(ctx, "reachability goroutine ends: %s", endpoint) 96 return 97 case <-ticker.C: 98 p.Lock() 99 defer p.Unlock() 100 if err := p.checkReachability(ctx, parts[0]); err != nil { 101 continue 102 } 103 p.pendingEndpoints.Del(endpoint) 104 p.availableEndpoints.Set(endpoint, struct{}{}) 105 p.pushEndpoints() 106 logger.Debugf(ctx, "available endpoint added: %s", endpoint) 107 return 108 } 109 } 110 } 111 112 func (p *EndpointPusher) checkReachability(ctx context.Context, host string) (err error) { 113 pinger, err := ping.NewPinger(host) 114 if err != nil { 115 log.WithFunc("utils.EndpointPusher.checkReachability").Error(ctx, err, "failed to create pinger") 116 return 117 } 118 pinger.SetPrivileged(os.Getuid() == 0) 119 defer pinger.Stop() 120 121 pinger.Count = 1 122 pinger.Timeout = time.Second 123 if err = pinger.Run(); err != nil { 124 return 125 } 126 if pinger.Statistics().PacketsRecv != 1 { 127 return types.ErrICMPLost 128 } 129 return 130 } 131 132 func (p *EndpointPusher) pushEndpoints() { 133 endpoints := []string{} 134 p.availableEndpoints.ForEach(func(endpoint string, _ struct{}) bool { 135 endpoints = append(endpoints, endpoint) 136 return true 137 }) 138 for _, ch := range p.chans { 139 ch <- endpoints 140 } 141 }