github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/lb/backend.go (about)

     1  package lb
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"runtime/debug"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/AntonOrnatskyi/goproxy/utils/dnsx"
    13  )
    14  
    15  // BackendConfig it's the configuration loaded
    16  type BackendConfig struct {
    17  	Address string
    18  
    19  	ActiveAfter   int
    20  	InactiveAfter int
    21  	Weight        int
    22  
    23  	Timeout   time.Duration
    24  	RetryTime time.Duration
    25  
    26  	IsMuxCheck  bool
    27  	ConnFactory func(address string, timeout time.Duration) (net.Conn, error)
    28  }
    29  type BackendsConfig []*BackendConfig
    30  
    31  // BackendControl keep the control data
    32  type BackendControl struct {
    33  	Failed bool // The last request failed
    34  	Active bool
    35  
    36  	InactiveTries int
    37  	ActiveTries   int
    38  
    39  	Connections int
    40  
    41  	ConnectUsedMillisecond int
    42  
    43  	isStop bool
    44  }
    45  
    46  // Backend structure
    47  type Backend struct {
    48  	BackendConfig
    49  	BackendControl
    50  	sync.RWMutex
    51  	log *log.Logger
    52  	dr  *dnsx.DomainResolver
    53  }
    54  
    55  type Backends []*Backend
    56  
    57  func NewBackend(backendConfig BackendConfig, dr *dnsx.DomainResolver, log *log.Logger) (*Backend, error) {
    58  
    59  	if backendConfig.Address == "" {
    60  		return nil, errors.New("Address rquired")
    61  	}
    62  	if backendConfig.ActiveAfter == 0 {
    63  		backendConfig.ActiveAfter = 2
    64  	}
    65  	if backendConfig.InactiveAfter == 0 {
    66  		backendConfig.InactiveAfter = 3
    67  	}
    68  	if backendConfig.Weight == 0 {
    69  		backendConfig.Weight = 1
    70  	}
    71  	if backendConfig.Timeout == 0 {
    72  		backendConfig.Timeout = time.Millisecond * 1500
    73  	}
    74  	if backendConfig.RetryTime == 0 {
    75  		backendConfig.RetryTime = time.Millisecond * 2000
    76  	}
    77  	return &Backend{
    78  		dr:            dr,
    79  		log:           log,
    80  		BackendConfig: backendConfig,
    81  		BackendControl: BackendControl{
    82  			Failed:                 true,
    83  			Active:                 false,
    84  			InactiveTries:          0,
    85  			ActiveTries:            0,
    86  			Connections:            0,
    87  			ConnectUsedMillisecond: 0,
    88  			isStop:                 false,
    89  		},
    90  	}, nil
    91  }
    92  func (b *Backend) StopHeartCheck() {
    93  	b.isStop = true
    94  }
    95  
    96  func (b *Backend) IncreasConns() {
    97  	b.RWMutex.Lock()
    98  	b.Connections++
    99  	b.RWMutex.Unlock()
   100  }
   101  
   102  func (b *Backend) DecreaseConns() {
   103  	b.RWMutex.Lock()
   104  	b.Connections--
   105  	b.RWMutex.Unlock()
   106  }
   107  
   108  func (b *Backend) StartHeartCheck() {
   109  	if b.IsMuxCheck {
   110  		b.startMuxHeartCheck()
   111  	} else {
   112  		b.startTCPHeartCheck()
   113  	}
   114  }
   115  func (b *Backend) startMuxHeartCheck() {
   116  	go func() {
   117  		defer func() {
   118  			if e := recover(); e != nil {
   119  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   120  			}
   121  		}()
   122  		for {
   123  			if b.isStop {
   124  				return
   125  			}
   126  			var c net.Conn
   127  			var err error
   128  			start := time.Now().UnixNano() / int64(time.Microsecond)
   129  			c, err = b.getConn()
   130  			b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
   131  			if err != nil {
   132  				b.Active = false
   133  				time.Sleep(time.Second * 2)
   134  				continue
   135  			} else {
   136  				b.Active = true
   137  			}
   138  			for {
   139  				buf := make([]byte, 1)
   140  				c.Read(buf)
   141  				buf = nil
   142  				break
   143  			}
   144  			b.Active = false
   145  		}
   146  	}()
   147  }
   148  
   149  // Monitoring the backend
   150  func (b *Backend) startTCPHeartCheck() {
   151  	go func() {
   152  		defer func() {
   153  			if e := recover(); e != nil {
   154  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   155  			}
   156  		}()
   157  		for {
   158  			if b.isStop {
   159  				return
   160  			}
   161  			var c net.Conn
   162  			var err error
   163  			start := time.Now().UnixNano() / int64(time.Microsecond)
   164  			c, err = b.getConn()
   165  			b.ConnectUsedMillisecond = int(time.Now().UnixNano()/int64(time.Microsecond) - start)
   166  			if err == nil {
   167  				c.Close()
   168  			}
   169  			if err != nil {
   170  				b.RWMutex.Lock()
   171  				// Max tries before consider inactive
   172  				if b.InactiveTries >= b.InactiveAfter {
   173  					//b.log.Printf("Backend inactive [%s]", b.Address)
   174  					b.Active = false
   175  					b.ActiveTries = 0
   176  				} else {
   177  					// Ok that guy it's out of the game
   178  					b.Failed = true
   179  					b.InactiveTries++
   180  					//b.log.Printf("Error to check address [%s] tries [%d]", b.Address, b.InactiveTries)
   181  				}
   182  				b.RWMutex.Unlock()
   183  			} else {
   184  
   185  				// Ok, let's keep working boys
   186  				b.RWMutex.Lock()
   187  				if b.ActiveTries >= b.ActiveAfter {
   188  					if b.Failed {
   189  						//log.Printf("Backend active [%s]", b.Address)
   190  					}
   191  					b.Failed = false
   192  					b.Active = true
   193  					b.InactiveTries = 0
   194  				} else {
   195  					b.ActiveTries++
   196  				}
   197  				b.RWMutex.Unlock()
   198  			}
   199  			time.Sleep(b.RetryTime)
   200  		}
   201  	}()
   202  }
   203  func (b *Backend) getConn() (conn net.Conn, err error) {
   204  	address := b.Address
   205  	if b.dr != nil && b.dr.DnsAddress() != "" {
   206  		address, err = b.dr.Resolve(b.Address)
   207  		if err != nil {
   208  			b.log.Printf("dns error %s , ERR:%s", b.Address, err)
   209  		}
   210  	}
   211  	if b.ConnFactory != nil {
   212  		return b.ConnFactory(address, b.Timeout)
   213  	}
   214  	return net.DialTimeout("tcp", address, b.Timeout)
   215  }