github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/load_balancer.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"hash/crc32"
    19  	"net"
    20  
    21  	"github.com/pawelgaczynski/gain/pkg/errors"
    22  )
    23  
    24  type LoadBalancing int
    25  
    26  const (
    27  	// RoundRobin forwards accepted connections to dedicated workers sequentially.
    28  	RoundRobin LoadBalancing = iota
    29  	// LeastConnections forwards the next accepted connection to the worker with the least number of active connections.
    30  	LeastConnections
    31  	// SourceAddrHash forwards the next accepted connection to the worker by hashing the remote peer address.
    32  	SourceIPHash
    33  )
    34  
    35  type loadBalancer interface {
    36  	register(consumer)
    37  	next(net.Addr) consumer
    38  	forEach(func(consumer) error) error
    39  }
    40  
    41  type genericLoadBalancer struct {
    42  	workers []consumer
    43  	size    int
    44  }
    45  
    46  func (b *genericLoadBalancer) register(worker consumer) {
    47  	worker.setIndex(b.size)
    48  	b.workers = append(b.workers, worker)
    49  	b.size++
    50  }
    51  
    52  type roundRobinLoadBalancer struct {
    53  	*genericLoadBalancer
    54  	nextWorkerIndex int
    55  }
    56  
    57  func (b *roundRobinLoadBalancer) next(_ net.Addr) consumer {
    58  	worker := b.workers[b.nextWorkerIndex]
    59  
    60  	if b.nextWorkerIndex++; b.nextWorkerIndex >= b.size {
    61  		b.nextWorkerIndex = 0
    62  	}
    63  
    64  	return worker
    65  }
    66  
    67  func (b *roundRobinLoadBalancer) forEach(callback func(consumer) error) error {
    68  	for _, c := range b.workers {
    69  		err := callback(c)
    70  		if err != nil {
    71  			return err
    72  		}
    73  	}
    74  
    75  	return nil
    76  }
    77  
    78  func newRoundRobinLoadBalancer() loadBalancer {
    79  	return &roundRobinLoadBalancer{
    80  		genericLoadBalancer: &genericLoadBalancer{},
    81  	}
    82  }
    83  
    84  type leastConnectionsLoadBalancer struct {
    85  	*genericLoadBalancer
    86  }
    87  
    88  func (b *leastConnectionsLoadBalancer) next(_ net.Addr) consumer {
    89  	worker := b.workers[0]
    90  	minN := worker.activeConnections()
    91  
    92  	for _, v := range b.workers[1:] {
    93  		if n := v.activeConnections(); n < minN {
    94  			minN = n
    95  			worker = v
    96  		}
    97  	}
    98  
    99  	return worker
   100  }
   101  
   102  func (b *leastConnectionsLoadBalancer) forEach(callback func(consumer) error) error {
   103  	for _, c := range b.workers {
   104  		err := callback(c)
   105  		if err != nil {
   106  			return err
   107  		}
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func newLeastConnectionsLoadBalancer() loadBalancer {
   114  	return &leastConnectionsLoadBalancer{
   115  		genericLoadBalancer: &genericLoadBalancer{},
   116  	}
   117  }
   118  
   119  type sourceIPHashLoadBalancer struct {
   120  	*genericLoadBalancer
   121  }
   122  
   123  func (b *sourceIPHashLoadBalancer) hash(s string) int {
   124  	hash := int(crc32.ChecksumIEEE([]byte(s)))
   125  	if hash < 0 {
   126  		return -hash
   127  	}
   128  
   129  	return hash
   130  }
   131  
   132  func (b *sourceIPHashLoadBalancer) next(addr net.Addr) consumer {
   133  	return b.workers[b.hash(addr.String())%b.size]
   134  }
   135  
   136  func (b *sourceIPHashLoadBalancer) forEach(callback func(consumer) error) error {
   137  	for _, c := range b.workers {
   138  		err := callback(c)
   139  		if err != nil {
   140  			return err
   141  		}
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func newSourceIPHashLoadBalancer() loadBalancer {
   148  	return &sourceIPHashLoadBalancer{
   149  		genericLoadBalancer: &genericLoadBalancer{},
   150  	}
   151  }
   152  
   153  func createLoadBalancer(loadBalancing LoadBalancing) (loadBalancer, error) {
   154  	switch loadBalancing {
   155  	case RoundRobin:
   156  		return newRoundRobinLoadBalancer(), nil
   157  	case LeastConnections:
   158  		return newLeastConnectionsLoadBalancer(), nil
   159  	case SourceIPHash:
   160  		return newSourceIPHashLoadBalancer(), nil
   161  	default:
   162  		return nil, errors.ErrNotSupported
   163  	}
   164  }