github.com/nsqio/nsq@v1.3.0/apps/nsq_to_http/nsq_to_http.go (about)

     1  // This is an NSQ client that reads the specified topic/channel
     2  // and performs HTTP requests (GET/POST) to the specified endpoints
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"flag"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"math/rand"
    13  	"net/http"
    14  	"net/url"
    15  	"os"
    16  	"os/signal"
    17  	"strings"
    18  	"sync/atomic"
    19  	"syscall"
    20  	"time"
    21  
    22  	"github.com/bitly/go-hostpool"
    23  	"github.com/bitly/timer_metrics"
    24  	"github.com/nsqio/go-nsq"
    25  	"github.com/nsqio/nsq/internal/app"
    26  	"github.com/nsqio/nsq/internal/http_api"
    27  	"github.com/nsqio/nsq/internal/version"
    28  )
    29  
    30  const (
    31  	ModeAll = iota
    32  	ModeRoundRobin
    33  	ModeHostPool
    34  )
    35  
    36  var (
    37  	showVersion = flag.Bool("version", false, "print version string")
    38  
    39  	topic       = flag.String("topic", "", "nsq topic")
    40  	channel     = flag.String("channel", "nsq_to_http", "nsq channel")
    41  	maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")
    42  
    43  	numPublishers      = flag.Int("n", 100, "number of concurrent publishers")
    44  	mode               = flag.String("mode", "hostpool", "the upstream request mode options: round-robin, hostpool (default), epsilon-greedy")
    45  	sample             = flag.Float64("sample", 1.0, "% of messages to publish (float b/w 0 -> 1)")
    46  	httpConnectTimeout = flag.Duration("http-client-connect-timeout", 2*time.Second, "timeout for HTTP connect")
    47  	httpRequestTimeout = flag.Duration("http-client-request-timeout", 20*time.Second, "timeout for HTTP request")
    48  	statusEvery        = flag.Int("status-every", 250, "the # of requests between logging status (per handler), 0 disables")
    49  	contentType        = flag.String("content-type", "application/octet-stream", "the Content-Type used for POST requests")
    50  
    51  	getAddrs           = app.StringArray{}
    52  	postAddrs          = app.StringArray{}
    53  	customHeaders      = app.StringArray{}
    54  	nsqdTCPAddrs       = app.StringArray{}
    55  	lookupdHTTPAddrs   = app.StringArray{}
    56  	validCustomHeaders map[string]string
    57  )
    58  
    59  func init() {
    60  	flag.Var(&postAddrs, "post", "HTTP address to make a POST request to.  data will be in the body (may be given multiple times)")
    61  	flag.Var(&customHeaders, "header", "Custom header for HTTP requests (may be given multiple times)")
    62  	flag.Var(&getAddrs, "get", "HTTP address to make a GET request to. '%s' will be printf replaced with data (may be given multiple times)")
    63  	flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
    64  	flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")
    65  }
    66  
    67  type Publisher interface {
    68  	Publish(string, []byte) error
    69  }
    70  
    71  type PublishHandler struct {
    72  	// 64bit atomic vars need to be first for proper alignment on 32bit platforms
    73  	counter uint64
    74  
    75  	Publisher
    76  	addresses app.StringArray
    77  	mode      int
    78  	hostPool  hostpool.HostPool
    79  
    80  	perAddressStatus map[string]*timer_metrics.TimerMetrics
    81  	timermetrics     *timer_metrics.TimerMetrics
    82  }
    83  
    84  func (ph *PublishHandler) HandleMessage(m *nsq.Message) error {
    85  	if *sample < 1.0 && rand.Float64() > *sample {
    86  		return nil
    87  	}
    88  
    89  	startTime := time.Now()
    90  	switch ph.mode {
    91  	case ModeAll:
    92  		for _, addr := range ph.addresses {
    93  			st := time.Now()
    94  			err := ph.Publish(addr, m.Body)
    95  			if err != nil {
    96  				return err
    97  			}
    98  			ph.perAddressStatus[addr].Status(st)
    99  		}
   100  	case ModeRoundRobin:
   101  		counter := atomic.AddUint64(&ph.counter, 1)
   102  		idx := counter % uint64(len(ph.addresses))
   103  		addr := ph.addresses[idx]
   104  		err := ph.Publish(addr, m.Body)
   105  		if err != nil {
   106  			return err
   107  		}
   108  		ph.perAddressStatus[addr].Status(startTime)
   109  	case ModeHostPool:
   110  		hostPoolResponse := ph.hostPool.Get()
   111  		addr := hostPoolResponse.Host()
   112  		err := ph.Publish(addr, m.Body)
   113  		hostPoolResponse.Mark(err)
   114  		if err != nil {
   115  			return err
   116  		}
   117  		ph.perAddressStatus[addr].Status(startTime)
   118  	}
   119  	ph.timermetrics.Status(startTime)
   120  
   121  	return nil
   122  }
   123  
   124  type PostPublisher struct{}
   125  
   126  func (p *PostPublisher) Publish(addr string, msg []byte) error {
   127  	buf := bytes.NewBuffer(msg)
   128  	resp, err := HTTPPost(addr, buf)
   129  	if err != nil {
   130  		return err
   131  	}
   132  	io.Copy(io.Discard, resp.Body)
   133  	resp.Body.Close()
   134  
   135  	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   136  		return fmt.Errorf("got status code %d", resp.StatusCode)
   137  	}
   138  	return nil
   139  }
   140  
   141  type GetPublisher struct{}
   142  
   143  func (p *GetPublisher) Publish(addr string, msg []byte) error {
   144  	endpoint := fmt.Sprintf(addr, url.QueryEscape(string(msg)))
   145  	resp, err := HTTPGet(endpoint)
   146  	if err != nil {
   147  		return err
   148  	}
   149  	io.Copy(io.Discard, resp.Body)
   150  	resp.Body.Close()
   151  
   152  	if resp.StatusCode != 200 {
   153  		return fmt.Errorf("got status code %d", resp.StatusCode)
   154  	}
   155  	return nil
   156  }
   157  
   158  func main() {
   159  	var publisher Publisher
   160  	var addresses app.StringArray
   161  	var selectedMode int
   162  
   163  	cfg := nsq.NewConfig()
   164  
   165  	flag.Var(&nsq.ConfigFlag{cfg}, "consumer-opt", "option to passthrough to nsq.Consumer (may be given multiple times, http://godoc.org/github.com/nsqio/go-nsq#Config)")
   166  	flag.Parse()
   167  
   168  	httpclient = &http.Client{Transport: http_api.NewDeadlineTransport(*httpConnectTimeout, *httpRequestTimeout), Timeout: *httpRequestTimeout}
   169  
   170  	if *showVersion {
   171  		fmt.Printf("nsq_to_http v%s\n", version.Binary)
   172  		return
   173  	}
   174  
   175  	if len(customHeaders) > 0 {
   176  		var err error
   177  		validCustomHeaders, err = parseCustomHeaders(customHeaders)
   178  		if err != nil {
   179  			log.Fatal("--header value format should be 'key=value'")
   180  		}
   181  	}
   182  
   183  	if *topic == "" || *channel == "" {
   184  		log.Fatal("--topic and --channel are required")
   185  	}
   186  
   187  	if *contentType != flag.Lookup("content-type").DefValue {
   188  		if len(postAddrs) == 0 {
   189  			log.Fatal("--content-type only used with --post")
   190  		}
   191  		if len(*contentType) == 0 {
   192  			log.Fatal("--content-type requires a value when used")
   193  		}
   194  	}
   195  
   196  	if len(nsqdTCPAddrs) == 0 && len(lookupdHTTPAddrs) == 0 {
   197  		log.Fatal("--nsqd-tcp-address or --lookupd-http-address required")
   198  	}
   199  	if len(nsqdTCPAddrs) > 0 && len(lookupdHTTPAddrs) > 0 {
   200  		log.Fatal("use --nsqd-tcp-address or --lookupd-http-address not both")
   201  	}
   202  
   203  	if len(getAddrs) == 0 && len(postAddrs) == 0 {
   204  		log.Fatal("--get or --post required")
   205  	}
   206  	if len(getAddrs) > 0 && len(postAddrs) > 0 {
   207  		log.Fatal("use --get or --post not both")
   208  	}
   209  	if len(getAddrs) > 0 {
   210  		for _, get := range getAddrs {
   211  			if strings.Count(get, "%s") != 1 {
   212  				log.Fatal("invalid GET address - must be a printf string")
   213  			}
   214  		}
   215  	}
   216  
   217  	switch *mode {
   218  	case "round-robin":
   219  		selectedMode = ModeRoundRobin
   220  	case "hostpool", "epsilon-greedy":
   221  		selectedMode = ModeHostPool
   222  	}
   223  
   224  	if *sample > 1.0 || *sample < 0.0 {
   225  		log.Fatal("ERROR: --sample must be between 0.0 and 1.0")
   226  	}
   227  
   228  	termChan := make(chan os.Signal, 1)
   229  	signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
   230  
   231  	if len(postAddrs) > 0 {
   232  		publisher = &PostPublisher{}
   233  		addresses = postAddrs
   234  	} else {
   235  		publisher = &GetPublisher{}
   236  		addresses = getAddrs
   237  	}
   238  
   239  	cfg.UserAgent = fmt.Sprintf("nsq_to_http/%s go-nsq/%s", version.Binary, nsq.VERSION)
   240  	cfg.MaxInFlight = *maxInFlight
   241  
   242  	consumer, err := nsq.NewConsumer(*topic, *channel, cfg)
   243  	if err != nil {
   244  		log.Fatal(err)
   245  	}
   246  
   247  	perAddressStatus := make(map[string]*timer_metrics.TimerMetrics)
   248  	if len(addresses) == 1 {
   249  		// disable since there is only one address
   250  		perAddressStatus[addresses[0]] = timer_metrics.NewTimerMetrics(0, "")
   251  	} else {
   252  		for _, a := range addresses {
   253  			perAddressStatus[a] = timer_metrics.NewTimerMetrics(*statusEvery,
   254  				fmt.Sprintf("[%s]:", a))
   255  		}
   256  	}
   257  
   258  	hostPool := hostpool.New(addresses)
   259  	if *mode == "epsilon-greedy" {
   260  		hostPool = hostpool.NewEpsilonGreedy(addresses, 0, &hostpool.LinearEpsilonValueCalculator{})
   261  	}
   262  
   263  	handler := &PublishHandler{
   264  		Publisher:        publisher,
   265  		addresses:        addresses,
   266  		mode:             selectedMode,
   267  		hostPool:         hostPool,
   268  		perAddressStatus: perAddressStatus,
   269  		timermetrics:     timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"),
   270  	}
   271  	consumer.AddConcurrentHandlers(handler, *numPublishers)
   272  
   273  	err = consumer.ConnectToNSQDs(nsqdTCPAddrs)
   274  	if err != nil {
   275  		log.Fatal(err)
   276  	}
   277  
   278  	err = consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
   279  	if err != nil {
   280  		log.Fatal(err)
   281  	}
   282  
   283  	for {
   284  		select {
   285  		case <-consumer.StopChan:
   286  			return
   287  		case <-termChan:
   288  			consumer.Stop()
   289  		}
   290  	}
   291  }
   292  
   293  func parseCustomHeaders(strs []string) (map[string]string, error) {
   294  	parsedHeaders := make(map[string]string)
   295  	for _, s := range strs {
   296  		sp := strings.SplitN(s, ":", 2)
   297  		if len(sp) != 2 {
   298  			return nil, fmt.Errorf("invalid header: %q", s)
   299  		}
   300  		key := strings.TrimSpace(sp[0])
   301  		val := strings.TrimSpace(sp[1])
   302  		if key == "" || val == "" {
   303  			return nil, fmt.Errorf("invalid header: %q", s)
   304  		}
   305  		parsedHeaders[key] = val
   306  
   307  	}
   308  	return parsedHeaders, nil
   309  }