github.com/jtzjtz/kit@v1.0.2/conn/rabbitmq_pool/rabbitmq_pool.go (about)

     1  package rabbitmq_pool
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"github.com/streadway/amqp"
     8  	"log"
     9  	"math/rand"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  )
    14  
    15  var (
    16  	retryCount      = 5
    17  	waitConfirmTime = 10 * time.Second
    18  )
    19  
    20  var AmqpServer Service
    21  
    22  type Service struct {
    23  	AmqpUrl       string //amqp地址
    24  	ConnectionNum int    //连接数
    25  	ChannelNum    int    //每个连接的channel数量
    26  
    27  	connections    map[int]*amqp.Connection
    28  	channels       map[int]channel
    29  	idelChannels   []int
    30  	busyChannels   map[int]int
    31  	m              *sync.Mutex
    32  	wg             *sync.WaitGroup
    33  	ctx            context.Context
    34  	cancel         context.CancelFunc
    35  	connectIdChan  chan int
    36  	lockConnectIds map[int]bool
    37  }
    38  
    39  type channel struct {
    40  	ch            *amqp.Channel
    41  	notifyClose   chan *amqp.Error
    42  	notifyConfirm chan amqp.Confirmation
    43  }
    44  
    45  func InitAmqp() {
    46  	if AmqpServer.AmqpUrl == "" {
    47  		log.Fatal("rabbitmq's address can not be empty!")
    48  	}
    49  	if AmqpServer.ConnectionNum == 0 {
    50  		AmqpServer.ConnectionNum = 10
    51  	}
    52  	if AmqpServer.ChannelNum == 0 {
    53  		AmqpServer.ChannelNum = 10
    54  	}
    55  	AmqpServer.m = new(sync.Mutex)
    56  	AmqpServer.wg = new(sync.WaitGroup)
    57  	AmqpServer.ctx, AmqpServer.cancel = context.WithTimeout(context.Background(), waitConfirmTime)
    58  	AmqpServer.lockConnectIds = make(map[int]bool)
    59  	AmqpServer.connectIdChan = make(chan int)
    60  
    61  	AmqpServer.connectPool()
    62  	AmqpServer.channelPool()
    63  }
    64  
    65  func failOnError(err error, msg string) {
    66  	if err != nil {
    67  		log.Panicf("%s: %s", msg, err)
    68  	}
    69  }
    70  
    71  func (S *Service) connectPool() {
    72  	S.connections = make(map[int]*amqp.Connection)
    73  	for i := 0; i < S.ConnectionNum; i++ {
    74  		connection := S.connect()
    75  		S.connections[i] = connection
    76  	}
    77  }
    78  
    79  func (S *Service) channelPool() {
    80  	S.channels = make(map[int]channel)
    81  	//S.idelChannels = make(map[int]int)
    82  	for index, _ := range S.connections {
    83  		for j := 0; j < S.ChannelNum; j++ {
    84  			key := index*S.ChannelNum + j
    85  			S.channels[key] = S.createChannel(index)
    86  			//S.idelChannels[key] = key
    87  			S.idelChannels = append(S.idelChannels, key)
    88  		}
    89  	}
    90  }
    91  
    92  func (S *Service) connect() *amqp.Connection {
    93  	conn, err := amqp.Dial(S.AmqpUrl)
    94  	failOnError(err, "Failed to connect to RabbitMQ")
    95  	//defer conn.Close()
    96  	return conn
    97  }
    98  
    99  func (S *Service) recreateChannel(connectId int, err error) (ch *amqp.Channel) {
   100  	if strings.Index(err.Error(), "channel/connection is not open") >= 0 || strings.Index(err.Error(), "CHANNEL_ERROR - expected 'channel.open'") >= 0 {
   101  		//S.connections[connectId].Close()
   102  		if S.connections[connectId].IsClosed() {
   103  			S.lockWriteConnect(connectId)
   104  		}
   105  		ch, err = S.connections[connectId].Channel()
   106  		failOnError(err, "Failed to open a channel")
   107  	} else {
   108  		failOnError(err, "Failed to open a channel")
   109  	}
   110  	return
   111  }
   112  
   113  func (S *Service) lockWriteConnect(connectId int) {
   114  
   115  	S.m.Lock()
   116  	if !S.lockConnectIds[connectId] {
   117  		S.lockConnectIds[connectId] = true
   118  		S.m.Unlock()
   119  
   120  		go func(connectId int) {
   121  			S.wg.Add(1)
   122  			defer S.wg.Done()
   123  
   124  			S.connections[connectId] = S.connect()
   125  			S.connectIdChan <- connectId
   126  
   127  		}(connectId)
   128  	} else {
   129  		S.m.Unlock()
   130  	}
   131  
   132  	for {
   133  		select {
   134  		case cid := <-S.connectIdChan:
   135  
   136  			delete(S.lockConnectIds, cid)
   137  
   138  			if len(S.lockConnectIds) == 0 {
   139  				S.wg.Wait()
   140  				return
   141  			} else {
   142  				continue
   143  			}
   144  		case <-time.After(waitConfirmTime):
   145  			S.lockConnectIds = make(map[int]bool)
   146  			S.wg.Wait()
   147  			return
   148  		}
   149  	}
   150  }
   151  
   152  func (S *Service) createChannel(connectId int) channel {
   153  	var notifyClose = make(chan *amqp.Error)
   154  	var notifyConfirm = make(chan amqp.Confirmation)
   155  
   156  	cha := channel{
   157  		notifyClose:   notifyClose,
   158  		notifyConfirm: notifyConfirm,
   159  	}
   160  	if S.connections[connectId].IsClosed() {
   161  		S.lockWriteConnect(connectId)
   162  	}
   163  	ch, err := S.connections[connectId].Channel()
   164  	if err != nil {
   165  		ch = S.recreateChannel(connectId, err)
   166  	}
   167  
   168  	ch.Confirm(false)
   169  	ch.NotifyClose(cha.notifyClose)
   170  	ch.NotifyPublish(cha.notifyConfirm)
   171  
   172  	cha.ch = ch
   173  	//go func() {
   174  	//	select {
   175  	//	case <-cha.notifyClose:
   176  	//		fmt.Println("close channel")
   177  	//	}
   178  	//}()
   179  	return cha
   180  }
   181  
   182  func (S *Service) getChannel() (*amqp.Channel, int) {
   183  	S.m.Lock()
   184  	defer S.m.Unlock()
   185  	idelLength := len(S.idelChannels)
   186  	if idelLength > 0 {
   187  
   188  		rand.Seed(time.Now().Unix())
   189  		index := rand.Intn(idelLength)
   190  		channelId := S.idelChannels[index]
   191  		S.idelChannels = append(S.idelChannels[:index], S.idelChannels[index+1:]...)
   192  		S.busyChannels = make(map[int]int)
   193  		S.busyChannels[channelId] = channelId
   194  
   195  		ch := S.channels[channelId].ch
   196  		//fmt.Println("channels count: ",len(S.channels))
   197  		//fmt.Println("idel channels count: ",len(S.idelChannels))
   198  		//fmt.Println("busy channels count: ",len(S.busyChannels))
   199  		//fmt.Println("channel id: ",channelId)
   200  		return ch, channelId
   201  	} else {
   202  		//return S.createChannel(0,S.connections[0]),-1
   203  		return nil, -1
   204  	}
   205  }
   206  
   207  func (S *Service) declareExchange(ch *amqp.Channel, exchangeName string, channelId int) *amqp.Channel {
   208  	err := ch.ExchangeDeclare(
   209  		exchangeName, // name
   210  		"direct",     // type
   211  		true,         // durable
   212  		false,        // auto-deleted
   213  		false,        // internal
   214  		false,        // no-wait
   215  		nil,          // arguments
   216  	)
   217  	if err != nil {
   218  		ch = S.reDeclareExchange(channelId, exchangeName, err)
   219  	}
   220  	return ch
   221  }
   222  
   223  func (S *Service) reDeclareExchange(channelId int, exchangeName string, err error) (ch *amqp.Channel) {
   224  	//fmt.Println("reDeclareExchange")
   225  
   226  	var connectionId int
   227  	if strings.Index(err.Error(), "channel/connection is not open") >= 0 {
   228  
   229  		//S.channels[channelId].Close()
   230  		if channelId == -1 {
   231  			rand.Seed(time.Now().Unix())
   232  			index := rand.Intn(S.ConnectionNum)
   233  			connectionId = index
   234  		} else {
   235  			connectionId = int(channelId / S.ChannelNum)
   236  		}
   237  		cha := S.createChannel(connectionId)
   238  
   239  		S.lockWriteChannel(channelId, cha)
   240  		//S.channels[channelId] = cha
   241  		err := cha.ch.ExchangeDeclare(
   242  			exchangeName, // name
   243  			"direct",     // type
   244  			true,         // durable
   245  			false,        // auto-deleted
   246  			false,        // internal
   247  			false,        // no-wait
   248  			nil,          // arguments
   249  		)
   250  		if err != nil {
   251  			failOnError(err, "Failed to declare an exchange")
   252  		}
   253  		return cha.ch
   254  	} else {
   255  
   256  		failOnError(err, "Failed to declare an exchange")
   257  		return nil
   258  	}
   259  }
   260  
   261  func (S *Service) lockWriteChannel(channelId int, cha channel) {
   262  	S.m.Lock()
   263  	defer S.m.Unlock()
   264  	S.channels[channelId] = cha
   265  }
   266  
   267  func (S *Service) dataForm(notice interface{}) string {
   268  	body, err := json.Marshal(notice)
   269  	if err != nil {
   270  		log.Panic(err)
   271  	}
   272  	return string(body)
   273  }
   274  
   275  func (S *Service) publish(channelId int, ch *amqp.Channel, exchangeName string, routeKey string, data string, replyTo string, headers map[string]interface{}) (err error) {
   276  
   277  	err = ch.Publish(
   278  		exchangeName, // exchange
   279  		routeKey,     //severityFrom(os.Args), // routing key
   280  		false,        // mandatory
   281  		false,        // immediate
   282  		amqp.Publishing{
   283  			ReplyTo:         replyTo,
   284  			ContentEncoding: "utf-8",
   285  			DeliveryMode:    amqp.Persistent,
   286  			ContentType:     "application/json",
   287  			Body:            []byte(data),
   288  			Headers:         headers,
   289  		})
   290  
   291  	if err != nil {
   292  		if strings.Index(err.Error(), "channel/connection is not open") >= 0 {
   293  			err = S.rePublish(channelId, exchangeName, err, routeKey, data, replyTo, headers)
   294  		}
   295  	}
   296  	return
   297  }
   298  
   299  func (S *Service) rePublish(channelId int, exchangeName string, errmsg error, routeKey string, data string, replyTo string, headers map[string]interface{}) (err error) {
   300  	//fmt.Println("rePublish")
   301  
   302  	ch := S.reDeclareExchange(channelId, exchangeName, errmsg)
   303  	err = ch.Publish(
   304  		exchangeName, // exchange
   305  		routeKey,     //severityFrom(os.Args), // routing key
   306  		false,        // mandatory
   307  		false,        // immediate
   308  		amqp.Publishing{
   309  			ReplyTo:         replyTo,
   310  			ContentEncoding: "utf-8",
   311  			DeliveryMode:    amqp.Persistent,
   312  			ContentType:     "application/json",
   313  			Body:            []byte(data),
   314  		})
   315  	return
   316  }
   317  
   318  func (S *Service) backChannelId(channelId int, ch *amqp.Channel) {
   319  	S.m.Lock()
   320  	defer S.m.Unlock()
   321  	S.idelChannels = append(S.idelChannels, channelId)
   322  	delete(S.busyChannels, channelId)
   323  	return
   324  }
   325  
   326  func (S *Service) PutIntoQueue(exchangeName string, routeKey string, notice string, replyTo string, headers map[string]interface{}) (message interface{}, puberr error) {
   327  	defer func() {
   328  		msg := recover()
   329  		if msg != nil {
   330  			//fmt.Println("msg: ",msg)
   331  			puberrMsg, _ := msg.(string)
   332  			//fmt.Println("ok: ",ok)
   333  			//fmt.Println("puberrMsg : ",puberrMsg)
   334  			puberr = errors.New(puberrMsg)
   335  			return
   336  		}
   337  	}()
   338  
   339  	ch, channelId := S.getChannel()
   340  	if ch == nil {
   341  		rand.Seed(time.Now().Unix())
   342  		index := rand.Intn(S.ConnectionNum)
   343  		cha := S.createChannel(index)
   344  		defer cha.ch.Close()
   345  		ch = cha.ch
   346  		//fmt.Println("ch: ",ch)
   347  	}
   348  	ch = S.declareExchange(ch, exchangeName, channelId)
   349  
   350  	data := notice //S.dataForm(notice)
   351  	var tryTime = 1
   352  
   353  	for {
   354  		puberr = S.publish(channelId, ch, exchangeName, routeKey, data, replyTo, headers)
   355  		if puberr != nil {
   356  			if tryTime <= retryCount {
   357  				//log.Printf("%s: %s", "Failed to publish a message, try again.", puberr)
   358  				tryTime++
   359  				continue
   360  			} else {
   361  				//log.Printf("%s: %s data: %s", "Failed to publish a message", puberr,data)
   362  				S.backChannelId(channelId, ch)
   363  				return notice, puberr
   364  			}
   365  		}
   366  
   367  		select {
   368  		case confirm := <-S.channels[channelId].notifyConfirm:
   369  			if confirm.Ack {
   370  				//log.Printf(" [%s] Sent %d message %s", routeKey, confirm.DeliveryTag, data)
   371  				S.backChannelId(channelId, ch)
   372  
   373  				return notice, nil
   374  			}
   375  		case <-time.After(waitConfirmTime):
   376  			//	log.Printf("message: %s data: %s", "Can not receive the confirm.", data)
   377  			S.backChannelId(channelId, ch)
   378  			confirmErr := errors.New("Can not receive the confirm . ")
   379  			return notice, confirmErr
   380  		}
   381  	}
   382  
   383  }
   384  func (S *Service) SetWaitConfirmTime(t time.Duration) {
   385  	waitConfirmTime = t
   386  }
   387  func (S *Service) SetRetryCount(c int) {
   388  	retryCount = c
   389  }