github.com/nsqio/nsq@v1.3.0/apps/nsq_to_nsq/nsq_to_nsq.go (about)

// This is an NSQ client that reads the specified topic(s) on the given channel
// and re-publishes the messages to one or more destination nsqd instances via TCP.
     3  
     4  package main
     5  
     6  import (
     7  	"encoding/json"
     8  	"errors"
     9  	"flag"
    10  	"fmt"
    11  	"log"
    12  	"os"
    13  	"os/signal"
    14  	"strconv"
    15  	"sync/atomic"
    16  	"syscall"
    17  	"time"
    18  
    19  	"github.com/bitly/go-hostpool"
    20  	"github.com/bitly/timer_metrics"
    21  	"github.com/nsqio/go-nsq"
    22  	"github.com/nsqio/nsq/internal/app"
    23  	"github.com/nsqio/nsq/internal/protocol"
    24  	"github.com/nsqio/nsq/internal/version"
    25  )
    26  
    27  const (
    28  	ModeRoundRobin = iota
    29  	ModeHostPool
    30  )
    31  
    32  var (
    33  	showVersion = flag.Bool("version", false, "print version string")
    34  	channel     = flag.String("channel", "nsq_to_nsq", "nsq channel")
    35  	destTopic   = flag.String("destination-topic", "", "use this destination topic for all consumed topics (default is consumed topic name)")
    36  	maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")
    37  
    38  	statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per destination), 0 disables")
    39  	mode        = flag.String("mode", "hostpool", "the upstream request mode options: round-robin, hostpool (default), epsilon-greedy")
    40  
    41  	nsqdTCPAddrs        = app.StringArray{}
    42  	lookupdHTTPAddrs    = app.StringArray{}
    43  	destNsqdTCPAddrs    = app.StringArray{}
    44  	whitelistJSONFields = app.StringArray{}
    45  	topics              = app.StringArray{}
    46  
    47  	requireJSONField = flag.String("require-json-field", "", "for JSON messages: only pass messages that contain this field")
    48  	requireJSONValue = flag.String("require-json-value", "", "for JSON messages: only pass messages in which the required field has this value")
    49  )
    50  
    51  func init() {
    52  	flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
    53  	flag.Var(&destNsqdTCPAddrs, "destination-nsqd-tcp-address", "destination nsqd TCP address (may be given multiple times)")
    54  	flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")
    55  	flag.Var(&topics, "topic", "nsq topic (may be given multiple times)")
    56  	flag.Var(&whitelistJSONFields, "whitelist-json-field", "for JSON messages: pass this field (may be given multiple times)")
    57  }
    58  
    59  type PublishHandler struct {
    60  	// 64bit atomic vars need to be first for proper alignment on 32bit platforms
    61  	counter uint64
    62  
    63  	addresses app.StringArray
    64  	producers map[string]*nsq.Producer
    65  	mode      int
    66  	hostPool  hostpool.HostPool
    67  	respChan  chan *nsq.ProducerTransaction
    68  
    69  	requireJSONValueParsed   bool
    70  	requireJSONValueIsNumber bool
    71  	requireJSONNumber        float64
    72  
    73  	perAddressStatus map[string]*timer_metrics.TimerMetrics
    74  	timermetrics     *timer_metrics.TimerMetrics
    75  }
    76  
    77  type TopicHandler struct {
    78  	publishHandler   *PublishHandler
    79  	destinationTopic string
    80  }
    81  
    82  func (ph *PublishHandler) responder() {
    83  	var msg *nsq.Message
    84  	var startTime time.Time
    85  	var address string
    86  	var hostPoolResponse hostpool.HostPoolResponse
    87  
    88  	for t := range ph.respChan {
    89  		switch ph.mode {
    90  		case ModeRoundRobin:
    91  			msg = t.Args[0].(*nsq.Message)
    92  			startTime = t.Args[1].(time.Time)
    93  			hostPoolResponse = nil
    94  			address = t.Args[2].(string)
    95  		case ModeHostPool:
    96  			msg = t.Args[0].(*nsq.Message)
    97  			startTime = t.Args[1].(time.Time)
    98  			hostPoolResponse = t.Args[2].(hostpool.HostPoolResponse)
    99  			address = hostPoolResponse.Host()
   100  		}
   101  
   102  		success := t.Error == nil
   103  
   104  		if hostPoolResponse != nil {
   105  			if !success {
   106  				hostPoolResponse.Mark(errors.New("failed"))
   107  			} else {
   108  				hostPoolResponse.Mark(nil)
   109  			}
   110  		}
   111  
   112  		if success {
   113  			msg.Finish()
   114  		} else {
   115  			msg.Requeue(-1)
   116  		}
   117  
   118  		ph.perAddressStatus[address].Status(startTime)
   119  		ph.timermetrics.Status(startTime)
   120  	}
   121  }
   122  
   123  func (ph *PublishHandler) shouldPassMessage(js map[string]interface{}) (bool, bool) {
   124  	pass := true
   125  	backoff := false
   126  
   127  	if *requireJSONField == "" {
   128  		return pass, backoff
   129  	}
   130  
   131  	if *requireJSONValue != "" && !ph.requireJSONValueParsed {
   132  		// cache conversion in case needed while filtering json
   133  		var err error
   134  		ph.requireJSONNumber, err = strconv.ParseFloat(*requireJSONValue, 64)
   135  		ph.requireJSONValueIsNumber = (err == nil)
   136  		ph.requireJSONValueParsed = true
   137  	}
   138  
   139  	v, ok := js[*requireJSONField]
   140  	if !ok {
   141  		pass = false
   142  		if *requireJSONValue != "" {
   143  			log.Printf("ERROR: missing field to check required value")
   144  			backoff = true
   145  		}
   146  	} else if *requireJSONValue != "" {
   147  		// if command-line argument can't convert to float, then it can't match a number
   148  		// if it can, also integers (up to 2^53 or so) can be compared as float64
   149  		if s, ok := v.(string); ok {
   150  			if s != *requireJSONValue {
   151  				pass = false
   152  			}
   153  		} else if ph.requireJSONValueIsNumber {
   154  			f, ok := v.(float64)
   155  			if !ok || f != ph.requireJSONNumber {
   156  				pass = false
   157  			}
   158  		} else {
   159  			// json value wasn't a plain string, and argument wasn't a number
   160  			// give up on comparisons of other types
   161  			pass = false
   162  		}
   163  	}
   164  
   165  	return pass, backoff
   166  }
   167  
   168  func filterMessage(js map[string]interface{}, rawMsg []byte) ([]byte, error) {
   169  	if len(whitelistJSONFields) == 0 {
   170  		// no change
   171  		return rawMsg, nil
   172  	}
   173  
   174  	newMsg := make(map[string]interface{}, len(whitelistJSONFields))
   175  
   176  	for _, key := range whitelistJSONFields {
   177  		value, ok := js[key]
   178  		if ok {
   179  			// avoid printing int as float (go 1.0)
   180  			switch tvalue := value.(type) {
   181  			case float64:
   182  				ivalue := int64(tvalue)
   183  				if float64(ivalue) == tvalue {
   184  					newMsg[key] = ivalue
   185  				} else {
   186  					newMsg[key] = tvalue
   187  				}
   188  			default:
   189  				newMsg[key] = value
   190  			}
   191  		}
   192  	}
   193  
   194  	newRawMsg, err := json.Marshal(newMsg)
   195  	if err != nil {
   196  		return nil, fmt.Errorf("unable to marshal filtered message %v", newMsg)
   197  	}
   198  	return newRawMsg, nil
   199  }
   200  
   201  func (t *TopicHandler) HandleMessage(m *nsq.Message) error {
   202  	return t.publishHandler.HandleMessage(m, t.destinationTopic)
   203  }
   204  
   205  func (ph *PublishHandler) HandleMessage(m *nsq.Message, destinationTopic string) error {
   206  	var err error
   207  	msgBody := m.Body
   208  
   209  	if *requireJSONField != "" || len(whitelistJSONFields) > 0 {
   210  		var js map[string]interface{}
   211  		err = json.Unmarshal(msgBody, &js)
   212  		if err != nil {
   213  			log.Printf("ERROR: Unable to decode json: %s", msgBody)
   214  			return nil
   215  		}
   216  
   217  		if pass, backoff := ph.shouldPassMessage(js); !pass {
   218  			if backoff {
   219  				return errors.New("backoff")
   220  			}
   221  			return nil
   222  		}
   223  
   224  		msgBody, err = filterMessage(js, msgBody)
   225  
   226  		if err != nil {
   227  			log.Printf("ERROR: filterMessage() failed: %s", err)
   228  			return err
   229  		}
   230  	}
   231  
   232  	startTime := time.Now()
   233  
   234  	switch ph.mode {
   235  	case ModeRoundRobin:
   236  		counter := atomic.AddUint64(&ph.counter, 1)
   237  		idx := counter % uint64(len(ph.addresses))
   238  		addr := ph.addresses[idx]
   239  		p := ph.producers[addr]
   240  		err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, addr)
   241  	case ModeHostPool:
   242  		hostPoolResponse := ph.hostPool.Get()
   243  		p := ph.producers[hostPoolResponse.Host()]
   244  		err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse)
   245  		if err != nil {
   246  			hostPoolResponse.Mark(err)
   247  		}
   248  	}
   249  
   250  	if err != nil {
   251  		return err
   252  	}
   253  	m.DisableAutoResponse()
   254  	return nil
   255  }
   256  
   257  func main() {
   258  	var selectedMode int
   259  
   260  	cCfg := nsq.NewConfig()
   261  	pCfg := nsq.NewConfig()
   262  
   263  	flag.Var(&nsq.ConfigFlag{cCfg}, "consumer-opt", "option to passthrough to nsq.Consumer (may be given multiple times, see http://godoc.org/github.com/nsqio/go-nsq#Config)")
   264  	flag.Var(&nsq.ConfigFlag{pCfg}, "producer-opt", "option to passthrough to nsq.Producer (may be given multiple times, see http://godoc.org/github.com/nsqio/go-nsq#Config)")
   265  
   266  	flag.Parse()
   267  
   268  	if *showVersion {
   269  		fmt.Printf("nsq_to_nsq v%s\n", version.Binary)
   270  		return
   271  	}
   272  
   273  	if len(topics) == 0 || *channel == "" {
   274  		log.Fatal("--topic and --channel are required")
   275  	}
   276  
   277  	for _, topic := range topics {
   278  		if !protocol.IsValidTopicName(topic) {
   279  			log.Fatal("--topic is invalid")
   280  		}
   281  	}
   282  
   283  	if *destTopic != "" && !protocol.IsValidTopicName(*destTopic) {
   284  		log.Fatal("--destination-topic is invalid")
   285  	}
   286  
   287  	if !protocol.IsValidChannelName(*channel) {
   288  		log.Fatal("--channel is invalid")
   289  	}
   290  
   291  	if len(nsqdTCPAddrs) == 0 && len(lookupdHTTPAddrs) == 0 {
   292  		log.Fatal("--nsqd-tcp-address or --lookupd-http-address required")
   293  	}
   294  	if len(nsqdTCPAddrs) > 0 && len(lookupdHTTPAddrs) > 0 {
   295  		log.Fatal("use --nsqd-tcp-address or --lookupd-http-address not both")
   296  	}
   297  
   298  	if len(destNsqdTCPAddrs) == 0 {
   299  		log.Fatal("--destination-nsqd-tcp-address required")
   300  	}
   301  
   302  	switch *mode {
   303  	case "round-robin":
   304  		selectedMode = ModeRoundRobin
   305  	case "hostpool", "epsilon-greedy":
   306  		selectedMode = ModeHostPool
   307  	}
   308  
   309  	termChan := make(chan os.Signal, 1)
   310  	signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
   311  
   312  	defaultUA := fmt.Sprintf("nsq_to_nsq/%s go-nsq/%s", version.Binary, nsq.VERSION)
   313  
   314  	cCfg.UserAgent = defaultUA
   315  	cCfg.MaxInFlight = *maxInFlight
   316  	pCfg.UserAgent = defaultUA
   317  
   318  	producers := make(map[string]*nsq.Producer)
   319  	for _, addr := range destNsqdTCPAddrs {
   320  		producer, err := nsq.NewProducer(addr, pCfg)
   321  		if err != nil {
   322  			log.Fatalf("failed creating producer %s", err)
   323  		}
   324  		producers[addr] = producer
   325  	}
   326  
   327  	perAddressStatus := make(map[string]*timer_metrics.TimerMetrics)
   328  	if len(destNsqdTCPAddrs) == 1 {
   329  		// disable since there is only one address
   330  		perAddressStatus[destNsqdTCPAddrs[0]] = timer_metrics.NewTimerMetrics(0, "")
   331  	} else {
   332  		for _, a := range destNsqdTCPAddrs {
   333  			perAddressStatus[a] = timer_metrics.NewTimerMetrics(*statusEvery,
   334  				fmt.Sprintf("[%s]:", a))
   335  		}
   336  	}
   337  
   338  	hostPool := hostpool.New(destNsqdTCPAddrs)
   339  	if *mode == "epsilon-greedy" {
   340  		hostPool = hostpool.NewEpsilonGreedy(destNsqdTCPAddrs, 0, &hostpool.LinearEpsilonValueCalculator{})
   341  	}
   342  
   343  	var consumerList []*nsq.Consumer
   344  
   345  	publisher := &PublishHandler{
   346  		addresses:        destNsqdTCPAddrs,
   347  		producers:        producers,
   348  		mode:             selectedMode,
   349  		hostPool:         hostPool,
   350  		respChan:         make(chan *nsq.ProducerTransaction, len(destNsqdTCPAddrs)),
   351  		perAddressStatus: perAddressStatus,
   352  		timermetrics:     timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"),
   353  	}
   354  
   355  	for _, topic := range topics {
   356  		consumer, err := nsq.NewConsumer(topic, *channel, cCfg)
   357  		consumerList = append(consumerList, consumer)
   358  		if err != nil {
   359  			log.Fatal(err)
   360  		}
   361  
   362  		publishTopic := topic
   363  		if *destTopic != "" {
   364  			publishTopic = *destTopic
   365  		}
   366  		topicHandler := &TopicHandler{
   367  			publishHandler:   publisher,
   368  			destinationTopic: publishTopic,
   369  		}
   370  		consumer.AddConcurrentHandlers(topicHandler, len(destNsqdTCPAddrs))
   371  	}
   372  	for i := 0; i < len(destNsqdTCPAddrs); i++ {
   373  		go publisher.responder()
   374  	}
   375  
   376  	for _, consumer := range consumerList {
   377  		err := consumer.ConnectToNSQDs(nsqdTCPAddrs)
   378  		if err != nil {
   379  			log.Fatal(err)
   380  		}
   381  	}
   382  
   383  	for _, consumer := range consumerList {
   384  		err := consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
   385  		if err != nil {
   386  			log.Fatal(err)
   387  		}
   388  	}
   389  
   390  	<-termChan // wait for signal
   391  
   392  	for _, consumer := range consumerList {
   393  		consumer.Stop()
   394  	}
   395  	for _, consumer := range consumerList {
   396  		<-consumer.StopChan
   397  	}
   398  }