github.com/Jeffail/benthos/v3@v3.65.0/lib/output/http_server.go (about)

     1  package output
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"mime/multipart"
    10  	"net/http"
    11  	"net/textproto"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/Jeffail/benthos/v3/internal/batch"
    16  	"github.com/Jeffail/benthos/v3/internal/docs"
    17  	httpdocs "github.com/Jeffail/benthos/v3/internal/http/docs"
    18  	"github.com/Jeffail/benthos/v3/lib/log"
    19  	"github.com/Jeffail/benthos/v3/lib/message"
    20  	"github.com/Jeffail/benthos/v3/lib/metrics"
    21  	"github.com/Jeffail/benthos/v3/lib/response"
    22  	"github.com/Jeffail/benthos/v3/lib/types"
    23  	"github.com/gorilla/websocket"
    24  )
    25  
    26  //------------------------------------------------------------------------------
    27  
    28  func init() {
    29  	corsSpec := httpdocs.ServerCORSFieldSpec()
    30  	corsSpec.Description += " Only valid with a custom `address`."
    31  
    32  	Constructors[TypeHTTPServer] = TypeSpec{
    33  		constructor: fromSimpleConstructor(NewHTTPServer),
    34  		Summary: `
    35  Sets up an HTTP server that will send messages over HTTP(S) GET requests. HTTP 2.0 is supported when using TLS, which is enabled when key and cert files are specified.`,
    36  		Description: `
    37  Sets up an HTTP server that will send messages over HTTP(S) GET requests. If the ` + "`address`" + ` config field is left blank the [service-wide HTTP server](/docs/components/http/about) will be used.
    38  
    39  Three endpoints will be registered at the paths specified by the fields ` + "`path`, `stream_path` and `ws_path`" + `. Which allow you to consume a single message batch, a continuous stream of line delimited messages, or a websocket of messages for each request respectively.
    40  
    41  When messages are batched the ` + "`path`" + ` endpoint encodes the batch according to [RFC1341](https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html). This behaviour can be overridden by [archiving your batches](/docs/configuration/batching#post-batch-processing).`,
    42  		FieldSpecs: docs.FieldSpecs{
    43  			docs.FieldCommon("address", "An optional address to listen from. If left empty the service wide HTTP server is used."),
    44  			docs.FieldCommon("path", "The path from which discrete messages can be consumed."),
    45  			docs.FieldCommon("stream_path", "The path from which a continuous stream of messages can be consumed."),
    46  			docs.FieldCommon("ws_path", "The path from which websocket connections can be established."),
    47  			docs.FieldCommon("allowed_verbs", "An array of verbs that are allowed for the `path` and `stream_path` HTTP endpoint.").Array(),
    48  			docs.FieldAdvanced("timeout", "The maximum time to wait before a blocking, inactive connection is dropped (only applies to the `path` endpoint)."),
    49  			docs.FieldAdvanced("cert_file", "An optional certificate file to use for TLS connections. Only applicable when an `address` is specified."),
    50  			docs.FieldAdvanced("key_file", "An optional certificate key file to use for TLS connections. Only applicable when an `address` is specified."),
    51  			corsSpec,
    52  		},
    53  		Categories: []Category{
    54  			CategoryNetwork,
    55  		},
    56  	}
    57  }
    58  
    59  //------------------------------------------------------------------------------
    60  
// HTTPServerConfig contains configuration fields for the HTTPServer output
// type.
type HTTPServerConfig struct {
	// Address, when non-empty, makes the output run its own HTTP server on
	// this address; when empty the endpoints are registered with the
	// service-wide HTTP server instead.
	Address      string              `json:"address" yaml:"address"`
	// Path serves a single message batch per request (empty disables it).
	Path         string              `json:"path" yaml:"path"`
	// StreamPath serves a continuous stream of messages (empty disables it).
	StreamPath   string              `json:"stream_path" yaml:"stream_path"`
	// WSPath serves messages over a websocket (empty disables it).
	WSPath       string              `json:"ws_path" yaml:"ws_path"`
	// AllowedVerbs lists the HTTP methods accepted by Path and StreamPath;
	// at least one must be provided.
	AllowedVerbs []string            `json:"allowed_verbs" yaml:"allowed_verbs"`
	// Timeout is a time.ParseDuration string bounding how long a request to
	// Path waits for a message.
	Timeout      string              `json:"timeout" yaml:"timeout"`
	// CertFile and KeyFile enable TLS when either is set; they only apply
	// when Address is set.
	CertFile     string              `json:"cert_file" yaml:"cert_file"`
	KeyFile      string              `json:"key_file" yaml:"key_file"`
	// CORS wraps the handler of the dedicated server (Address must be set).
	CORS         httpdocs.ServerCORS `json:"cors" yaml:"cors"`
}
    74  
    75  // NewHTTPServerConfig creates a new HTTPServerConfig with default values.
    76  func NewHTTPServerConfig() HTTPServerConfig {
    77  	return HTTPServerConfig{
    78  		Address:    "",
    79  		Path:       "/get",
    80  		StreamPath: "/get/stream",
    81  		WSPath:     "/get/ws",
    82  		AllowedVerbs: []string{
    83  			"GET",
    84  		},
    85  		Timeout:  "5s",
    86  		CertFile: "",
    87  		KeyFile:  "",
    88  		CORS:     httpdocs.NewServerCORS(),
    89  	}
    90  }
    91  
    92  //------------------------------------------------------------------------------
    93  
// HTTPServer is an output type that serves HTTPServer GET requests.
//
// Message batches read from the transactions channel are handed out over up
// to three endpoints: a single-batch endpoint, a continuous stream endpoint
// and a websocket endpoint. The output either runs its own HTTP server or
// registers its endpoints with the service-wide server.
type HTTPServer struct {
	// running is 1 while the output accepts requests and 0 once it has been
	// closed; it is only accessed through sync/atomic.
	running int32

	conf  Config
	stats metrics.Type
	log   log.Modular

	// mux and server are only non-nil when a custom address is configured.
	// timeout bounds how long the single-batch endpoint waits for a message.
	mux     *http.ServeMux
	server  *http.Server
	timeout time.Duration

	// transactions is the channel assigned by Consume from which the
	// handlers pull message batches.
	transactions <-chan types.Transaction

	// closeChan is selected on by handlers so they can abandon pending work
	// during shutdown; closedChan is closed once the output has shut down
	// and is what WaitForClose blocks on.
	closeChan  chan struct{}
	closedChan chan struct{}

	// allowedVerbs is the set of HTTP methods accepted by the `path` and
	// `stream_path` endpoints.
	allowedVerbs map[string]struct{}

	// Metrics shared across all endpoints. mRunning tracks whether the
	// dedicated server (if any) is serving.
	mRunning       metrics.StatGauge
	mCount         metrics.StatCounter
	mPartsCount    metrics.StatCounter
	mSendSucc      metrics.StatCounter
	mPartsSendSucc metrics.StatCounter
	mSent          metrics.StatCounter
	mPartsSent     metrics.StatCounter

	// Metrics for the single-batch (`path`) endpoint.
	mGetReqRcvd  metrics.StatCounter
	mGetCount    metrics.StatCounter
	mGetSendSucc metrics.StatCounter

	// Metrics for the websocket (`ws_path`) endpoint.
	mWSReqRcvd  metrics.StatCounter
	mWSCount    metrics.StatCounter
	mWSSendSucc metrics.StatCounter
	mWSSendErr  metrics.StatCounter

	// Metrics for the stream (`stream_path`) endpoint.
	mStrmReqRcvd  metrics.StatCounter
	mStrmErrCast  metrics.StatCounter
	mStrmErrWrong metrics.StatCounter
	mStrmClosed   metrics.StatCounter
	mStrmCount    metrics.StatCounter
	mStrmErrWrite metrics.StatCounter
	mStrmSndSucc  metrics.StatCounter
}
   138  
   139  // NewHTTPServer creates a new HTTPServer output type.
   140  func NewHTTPServer(conf Config, mgr types.Manager, log log.Modular, stats metrics.Type) (Type, error) {
   141  	var mux *http.ServeMux
   142  	var server *http.Server
   143  
   144  	var err error
   145  	if len(conf.HTTPServer.Address) > 0 {
   146  		mux = http.NewServeMux()
   147  		server = &http.Server{Addr: conf.HTTPServer.Address}
   148  		if server.Handler, err = conf.HTTPServer.CORS.WrapHandler(mux); err != nil {
   149  			return nil, fmt.Errorf("bad CORS configuration: %w", err)
   150  		}
   151  	}
   152  
   153  	verbs := map[string]struct{}{}
   154  	for _, v := range conf.HTTPServer.AllowedVerbs {
   155  		verbs[v] = struct{}{}
   156  	}
   157  	if len(verbs) == 0 {
   158  		return nil, errors.New("must provide at least one allowed verb")
   159  	}
   160  
   161  	h := HTTPServer{
   162  		running:    1,
   163  		conf:       conf,
   164  		stats:      stats,
   165  		log:        log,
   166  		mux:        mux,
   167  		server:     server,
   168  		closeChan:  make(chan struct{}),
   169  		closedChan: make(chan struct{}),
   170  
   171  		allowedVerbs: verbs,
   172  
   173  		mRunning:       stats.GetGauge("running"),
   174  		mCount:         stats.GetCounter("count"),
   175  		mPartsCount:    stats.GetCounter("parts.count"),
   176  		mSendSucc:      stats.GetCounter("send.success"),
   177  		mPartsSendSucc: stats.GetCounter("parts.send.success"),
   178  		mSent:          stats.GetCounter("batch.sent"),
   179  		mPartsSent:     stats.GetCounter("sent"),
   180  		mGetReqRcvd:    stats.GetCounter("get.request.received"),
   181  		mGetCount:      stats.GetCounter("get.count"),
   182  		mGetSendSucc:   stats.GetCounter("get.send.success"),
   183  		mWSCount:       stats.GetCounter("ws.count"),
   184  		mWSReqRcvd:     stats.GetCounter("stream.request.received"),
   185  		mWSSendSucc:    stats.GetCounter("ws.send.success"),
   186  		mWSSendErr:     stats.GetCounter("ws.send.error"),
   187  		mStrmReqRcvd:   stats.GetCounter("stream.request.received"),
   188  		mStrmErrCast:   stats.GetCounter("stream.error.cast_flusher"),
   189  		mStrmErrWrong:  stats.GetCounter("stream.error.wrong_method"),
   190  		mStrmClosed:    stats.GetCounter("stream.client_closed"),
   191  		mStrmCount:     stats.GetCounter("stream.count"),
   192  		mStrmErrWrite:  stats.GetCounter("stream.error.write"),
   193  		mStrmSndSucc:   stats.GetCounter("stream.send.success"),
   194  	}
   195  
   196  	if tout := conf.HTTPServer.Timeout; len(tout) > 0 {
   197  		if h.timeout, err = time.ParseDuration(tout); err != nil {
   198  			return nil, fmt.Errorf("failed to parse timeout string: %v", err)
   199  		}
   200  	}
   201  
   202  	if mux != nil {
   203  		if len(h.conf.HTTPServer.Path) > 0 {
   204  			h.mux.HandleFunc(h.conf.HTTPServer.Path, h.getHandler)
   205  		}
   206  		if len(h.conf.HTTPServer.StreamPath) > 0 {
   207  			h.mux.HandleFunc(h.conf.HTTPServer.StreamPath, h.streamHandler)
   208  		}
   209  		if len(h.conf.HTTPServer.WSPath) > 0 {
   210  			h.mux.HandleFunc(h.conf.HTTPServer.WSPath, h.wsHandler)
   211  		}
   212  	} else {
   213  		if len(h.conf.HTTPServer.Path) > 0 {
   214  			mgr.RegisterEndpoint(
   215  				h.conf.HTTPServer.Path, "Read a single message from Benthos.",
   216  				h.getHandler,
   217  			)
   218  		}
   219  		if len(h.conf.HTTPServer.StreamPath) > 0 {
   220  			mgr.RegisterEndpoint(
   221  				h.conf.HTTPServer.StreamPath,
   222  				"Read a continuous stream of messages from Benthos.",
   223  				h.streamHandler,
   224  			)
   225  		}
   226  		if len(h.conf.HTTPServer.WSPath) > 0 {
   227  			mgr.RegisterEndpoint(
   228  				h.conf.HTTPServer.WSPath,
   229  				"Read messages from Benthos via websockets.",
   230  				h.wsHandler,
   231  			)
   232  		}
   233  	}
   234  
   235  	return &h, nil
   236  }
   237  
   238  //------------------------------------------------------------------------------
   239  
   240  func (h *HTTPServer) getHandler(w http.ResponseWriter, r *http.Request) {
   241  	h.mGetReqRcvd.Incr(1)
   242  
   243  	if atomic.LoadInt32(&h.running) != 1 {
   244  		http.Error(w, "Server closed", http.StatusServiceUnavailable)
   245  		return
   246  	}
   247  
   248  	if _, exists := h.allowedVerbs[r.Method]; !exists {
   249  		http.Error(w, "Incorrect method", http.StatusMethodNotAllowed)
   250  		return
   251  	}
   252  
   253  	tStart := time.Now()
   254  
   255  	var ts types.Transaction
   256  	var open bool
   257  	var err error
   258  
   259  	select {
   260  	case ts, open = <-h.transactions:
   261  		if !open {
   262  			http.Error(w, "Server closed", http.StatusServiceUnavailable)
   263  			go h.CloseAsync()
   264  			return
   265  		}
   266  		h.mGetCount.Incr(1)
   267  		h.mCount.Incr(1)
   268  		h.mPartsCount.Incr(int64(ts.Payload.Len()))
   269  	case <-time.After(h.timeout - time.Since(tStart)):
   270  		http.Error(w, "Timed out waiting for message", http.StatusRequestTimeout)
   271  		return
   272  	}
   273  
   274  	if ts.Payload.Len() > 1 {
   275  		body := &bytes.Buffer{}
   276  		writer := multipart.NewWriter(body)
   277  
   278  		for i := 0; i < ts.Payload.Len() && err == nil; i++ {
   279  			var part io.Writer
   280  			if part, err = writer.CreatePart(textproto.MIMEHeader{
   281  				"Content-Type": []string{"application/octet-stream"},
   282  			}); err == nil {
   283  				_, err = io.Copy(part, bytes.NewReader(ts.Payload.Get(i).Get()))
   284  			}
   285  		}
   286  
   287  		writer.Close()
   288  		w.Header().Add("Content-Type", writer.FormDataContentType())
   289  		w.Write(body.Bytes())
   290  	} else {
   291  		w.Header().Add("Content-Type", "application/octet-stream")
   292  		w.Write(ts.Payload.Get(0).Get())
   293  	}
   294  
   295  	h.mSendSucc.Incr(1)
   296  	h.mPartsSendSucc.Incr(int64(ts.Payload.Len()))
   297  	h.mSent.Incr(1)
   298  	h.mPartsSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload)))
   299  	h.mGetSendSucc.Incr(1)
   300  
   301  	select {
   302  	case ts.ResponseChan <- response.NewAck():
   303  	case <-h.closeChan:
   304  		return
   305  	}
   306  }
   307  
   308  func (h *HTTPServer) streamHandler(w http.ResponseWriter, r *http.Request) {
   309  	h.mStrmReqRcvd.Incr(1)
   310  
   311  	flusher, ok := w.(http.Flusher)
   312  	if !ok {
   313  		http.Error(w, "Server error", http.StatusInternalServerError)
   314  		h.mStrmErrCast.Incr(1)
   315  		h.log.Errorln("Failed to cast response writer to flusher")
   316  		return
   317  	}
   318  
   319  	if _, exists := h.allowedVerbs[r.Method]; !exists {
   320  		http.Error(w, "Incorrect method", http.StatusMethodNotAllowed)
   321  		h.mStrmErrWrong.Incr(1)
   322  		return
   323  	}
   324  
   325  	for atomic.LoadInt32(&h.running) == 1 {
   326  		var ts types.Transaction
   327  		var open bool
   328  
   329  		select {
   330  		case ts, open = <-h.transactions:
   331  			if !open {
   332  				go h.CloseAsync()
   333  				return
   334  			}
   335  		case <-r.Context().Done():
   336  			h.mStrmClosed.Incr(1)
   337  			return
   338  		}
   339  		h.mStrmCount.Incr(1)
   340  		h.mCount.Incr(1)
   341  
   342  		var data []byte
   343  		if ts.Payload.Len() == 1 {
   344  			data = ts.Payload.Get(0).Get()
   345  		} else {
   346  			data = append(bytes.Join(message.GetAllBytes(ts.Payload), []byte("\n")), byte('\n'))
   347  		}
   348  
   349  		_, err := w.Write(data)
   350  		select {
   351  		case ts.ResponseChan <- response.NewError(err):
   352  		case <-h.closeChan:
   353  			return
   354  		}
   355  
   356  		if err != nil {
   357  			h.mStrmErrWrite.Incr(1)
   358  			return
   359  		}
   360  
   361  		w.Write([]byte("\n"))
   362  		flusher.Flush()
   363  		h.mStrmSndSucc.Incr(1)
   364  		h.mSendSucc.Incr(1)
   365  		h.mPartsSendSucc.Incr(int64(ts.Payload.Len()))
   366  		h.mSent.Incr(1)
   367  		h.mPartsSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload)))
   368  	}
   369  }
   370  
   371  func (h *HTTPServer) wsHandler(w http.ResponseWriter, r *http.Request) {
   372  	h.mWSReqRcvd.Incr(1)
   373  
   374  	var err error
   375  	defer func() {
   376  		if err != nil {
   377  			http.Error(w, "Bad request", http.StatusBadRequest)
   378  			h.log.Warnf("Websocket request failed: %v\n", err)
   379  			return
   380  		}
   381  	}()
   382  
   383  	upgrader := websocket.Upgrader{}
   384  
   385  	var ws *websocket.Conn
   386  	if ws, err = upgrader.Upgrade(w, r, nil); err != nil {
   387  		return
   388  	}
   389  	defer ws.Close()
   390  
   391  	for atomic.LoadInt32(&h.running) == 1 {
   392  		var ts types.Transaction
   393  		var open bool
   394  
   395  		select {
   396  		case ts, open = <-h.transactions:
   397  			if !open {
   398  				go h.CloseAsync()
   399  				return
   400  			}
   401  		case <-r.Context().Done():
   402  			h.mStrmClosed.Incr(1)
   403  			return
   404  		case <-h.closeChan:
   405  			return
   406  		}
   407  		h.mWSCount.Incr(1)
   408  		h.mCount.Incr(1)
   409  
   410  		var werr error
   411  		for _, msg := range message.GetAllBytes(ts.Payload) {
   412  			if werr = ws.WriteMessage(websocket.BinaryMessage, msg); werr != nil {
   413  				break
   414  			}
   415  			h.mWSSendSucc.Incr(1)
   416  			h.mSendSucc.Incr(1)
   417  			h.mPartsSendSucc.Incr(int64(ts.Payload.Len()))
   418  			h.mSent.Incr(1)
   419  			h.mPartsSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload)))
   420  		}
   421  
   422  		if werr != nil {
   423  			h.mWSSendErr.Incr(1)
   424  		}
   425  		select {
   426  		case ts.ResponseChan <- response.NewError(werr):
   427  		case <-h.closeChan:
   428  			return
   429  		}
   430  	}
   431  }
   432  
   433  //------------------------------------------------------------------------------
   434  
   435  // Consume assigns a messages channel for the output to read.
   436  func (h *HTTPServer) Consume(ts <-chan types.Transaction) error {
   437  	if h.transactions != nil {
   438  		return types.ErrAlreadyStarted
   439  	}
   440  	h.transactions = ts
   441  
   442  	if h.server != nil {
   443  		go func() {
   444  			h.mRunning.Incr(1)
   445  
   446  			if len(h.conf.HTTPServer.KeyFile) > 0 || len(h.conf.HTTPServer.CertFile) > 0 {
   447  				h.log.Infof(
   448  					"Serving messages through HTTPS GET request at: https://%s\n",
   449  					h.conf.HTTPServer.Address+h.conf.HTTPServer.Path,
   450  				)
   451  				if err := h.server.ListenAndServeTLS(
   452  					h.conf.HTTPServer.CertFile, h.conf.HTTPServer.KeyFile,
   453  				); err != http.ErrServerClosed {
   454  					h.log.Errorf("Server error: %v\n", err)
   455  				}
   456  			} else {
   457  				h.log.Infof(
   458  					"Serving messages through HTTP GET request at: http://%s\n",
   459  					h.conf.HTTPServer.Address+h.conf.HTTPServer.Path,
   460  				)
   461  				if err := h.server.ListenAndServe(); err != http.ErrServerClosed {
   462  					h.log.Errorf("Server error: %v\n", err)
   463  				}
   464  			}
   465  
   466  			h.mRunning.Decr(1)
   467  
   468  			atomic.StoreInt32(&h.running, 0)
   469  			close(h.closeChan)
   470  			close(h.closedChan)
   471  		}()
   472  	}
   473  	return nil
   474  }
   475  
   476  // Connected returns a boolean indicating whether this output is currently
   477  // connected to its target.
   478  func (h *HTTPServer) Connected() bool {
   479  	// Always return true as this is fuzzy right now.
   480  	return true
   481  }
   482  
   483  // CloseAsync shuts down the HTTPServer output and stops processing requests.
   484  func (h *HTTPServer) CloseAsync() {
   485  	if atomic.CompareAndSwapInt32(&h.running, 1, 0) {
   486  		if h.server != nil {
   487  			h.server.Shutdown(context.Background())
   488  		} else {
   489  			close(h.closedChan)
   490  		}
   491  	}
   492  }
   493  
   494  // WaitForClose blocks until the HTTPServer output has closed down.
   495  func (h *HTTPServer) WaitForClose(timeout time.Duration) error {
   496  	select {
   497  	case <-h.closedChan:
   498  	case <-time.After(timeout):
   499  		return types.ErrTimeout
   500  	}
   501  	return nil
   502  }
   503  
   504  //------------------------------------------------------------------------------