github.com/Jeffail/benthos/v3@v3.65.0/lib/input/http_server.go (about)

     1  package input
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"mime"
    10  	"mime/multipart"
    11  	"net/http"
    12  	"net/textproto"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/Jeffail/benthos/v3/internal/bloblang/field"
    19  	"github.com/Jeffail/benthos/v3/internal/docs"
    20  	httpdocs "github.com/Jeffail/benthos/v3/internal/http/docs"
    21  	"github.com/Jeffail/benthos/v3/internal/interop"
    22  	imetadata "github.com/Jeffail/benthos/v3/internal/metadata"
    23  	"github.com/Jeffail/benthos/v3/internal/shutdown"
    24  	"github.com/Jeffail/benthos/v3/internal/tracing"
    25  	"github.com/Jeffail/benthos/v3/lib/log"
    26  	"github.com/Jeffail/benthos/v3/lib/message"
    27  	"github.com/Jeffail/benthos/v3/lib/message/metadata"
    28  	"github.com/Jeffail/benthos/v3/lib/message/roundtrip"
    29  	"github.com/Jeffail/benthos/v3/lib/metrics"
    30  	"github.com/Jeffail/benthos/v3/lib/types"
    31  	httputil "github.com/Jeffail/benthos/v3/lib/util/http"
    32  	"github.com/Jeffail/benthos/v3/lib/util/throttle"
    33  	"github.com/gorilla/mux"
    34  	"github.com/gorilla/websocket"
    35  )
    36  
    37  //------------------------------------------------------------------------------
    38  
    39  func init() {
    40  	corsSpec := httpdocs.ServerCORSFieldSpec()
    41  	corsSpec.Description += " Only valid with a custom `address`."
    42  
    43  	Constructors[TypeHTTPServer] = TypeSpec{
    44  		constructor: fromSimpleConstructor(NewHTTPServer),
    45  		Summary: `
    46  Receive messages POSTed over HTTP(S). HTTP 2.0 is supported when using TLS, which is enabled when key and cert files are specified.`,
    47  		Description: `
    48  If the ` + "`address`" + ` config field is left blank the [service-wide HTTP server](/docs/components/http/about) will be used.
    49  
    50  The field ` + "`rate_limit`" + ` allows you to specify an optional ` + "[`rate_limit` resource](/docs/components/rate_limits/about)" + `, which will be applied to each HTTP request made and each websocket payload received.
    51  
    52  When the rate limit is breached HTTP requests will have a 429 response returned with a Retry-After header. Websocket payloads will be dropped and an optional response payload will be sent as per ` + "`ws_rate_limit_message`" + `.
    53  
    54  ### Responses
    55  
    56  It's possible to return a response for each message received using [synchronous responses](/docs/guides/sync_responses). When doing so you can customise headers with the ` + "`sync_response` field `headers`" + `, which can also use [function interpolation](/docs/configuration/interpolation#bloblang-queries) in the value based on the response message contents.
    57  
    58  ### Endpoints
    59  
    60  The following fields specify endpoints that are registered for sending messages, and support path parameters of the form ` + "`/{foo}`" + `, which are added to ingested messages as metadata:
    61  
    62  #### ` + "`path` (defaults to `/post`)" + `
    63  
    64  This endpoint expects POST requests where the entire request body is consumed as a single message.
    65  
    66  If the request contains a multipart ` + "`content-type`" + ` header as per [rfc1341](https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html) then the multiple parts are consumed as a batch of messages, where each body part is a message of the batch.
    67  
    68  #### ` + "`ws_path` (defaults to `/post/ws`)" + `
    69  
    70  Creates a websocket connection, where payloads received on the socket are passed through the pipeline as a batch of one message.
    71  
    72  You may specify an optional ` + "`ws_welcome_message`" + `, which is a static payload to be sent to all clients once a websocket connection is first established.
    73  
    74  It's also possible to specify a ` + "`ws_rate_limit_message`" + `, which is a static payload to be sent to clients that have triggered the servers rate limit.
    75  
    76  ### Metadata
    77  
    78  This input adds the following metadata fields to each message:
    79  
    80  ` + "``` text" + `
    81  - http_server_user_agent
    82  - http_server_request_path
    83  - http_server_verb
    84  - All headers (only first values are taken)
    85  - All query parameters
    86  - All path parameters
    87  - All cookies
    88  ` + "```" + `
    89  
    90  You can access these metadata fields using [function interpolation](/docs/configuration/interpolation#metadata).`,
    91  		FieldSpecs: docs.FieldSpecs{
    92  			docs.FieldCommon("address", "An alternative address to host from. If left empty the service wide address is used."),
    93  			docs.FieldCommon("path", "The endpoint path to listen for POST requests."),
    94  			docs.FieldCommon("ws_path", "The endpoint path to create websocket connections from."),
    95  			docs.FieldAdvanced("ws_welcome_message", "An optional message to deliver to fresh websocket connections."),
    96  			docs.FieldAdvanced("ws_rate_limit_message", "An optional message to delivery to websocket connections that are rate limited."),
    97  			docs.FieldCommon("allowed_verbs", "An array of verbs that are allowed for the `path` endpoint.").AtVersion("3.33.0").Array(),
    98  			docs.FieldCommon("timeout", "Timeout for requests. If a consumed messages takes longer than this to be delivered the connection is closed, but the message may still be delivered."),
    99  			docs.FieldCommon("rate_limit", "An optional [rate limit](/docs/components/rate_limits/about) to throttle requests by."),
   100  			docs.FieldAdvanced("cert_file", "Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."),
   101  			docs.FieldAdvanced("key_file", "Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."),
   102  			corsSpec,
   103  			docs.FieldAdvanced("sync_response", "Customise messages returned via [synchronous responses](/docs/guides/sync_responses).").WithChildren(
   104  				docs.FieldCommon(
   105  					"status",
   106  					"Specify the status code to return with synchronous responses. This is a string value, which allows you to customize it based on resulting payloads and their metadata.",
   107  					"200", `${! json("status") }`, `${! meta("status") }`,
   108  				).IsInterpolated(),
   109  				docs.FieldString("headers", "Specify headers to return with synchronous responses.").IsInterpolated().Map().HasDefault(map[string]string{
   110  					"Content-Type": "application/octet-stream",
   111  				}),
   112  				docs.FieldCommon("metadata_headers", "Specify criteria for which metadata values are added to the response as headers.").WithChildren(imetadata.IncludeFilterDocs()...),
   113  			),
   114  		},
   115  		Categories: []Category{
   116  			CategoryNetwork,
   117  		},
   118  	}
   119  }
   120  
   121  //------------------------------------------------------------------------------
   122  
// HTTPServerResponseConfig provides config fields for customising the response
// given from successful requests.
type HTTPServerResponseConfig struct {
	// Status is an interpolated expression resolving to the HTTP status code of
	// a synchronous response.
	Status          string                        `json:"status" yaml:"status"`
	// Headers maps response header names to interpolated value expressions.
	Headers         map[string]string             `json:"headers" yaml:"headers"`
	// ExtractMetadata selects which metadata keys of the response message are
	// copied onto the response as headers.
	ExtractMetadata imetadata.IncludeFilterConfig `json:"metadata_headers" yaml:"metadata_headers"`
}
   130  
   131  // NewHTTPServerResponseConfig creates a new HTTPServerConfig with default values.
   132  func NewHTTPServerResponseConfig() HTTPServerResponseConfig {
   133  	return HTTPServerResponseConfig{
   134  		Status: "200",
   135  		Headers: map[string]string{
   136  			"Content-Type": "application/octet-stream",
   137  		},
   138  		ExtractMetadata: imetadata.NewIncludeFilterConfig(),
   139  	}
   140  }
   141  
// HTTPServerConfig contains configuration for the HTTPServer input type.
type HTTPServerConfig struct {
	// Address, when non-empty, starts a dedicated server; otherwise endpoints
	// are registered on the service-wide server.
	Address            string                   `json:"address" yaml:"address"`
	// Path is the POST endpoint; an empty value skips registering it.
	Path               string                   `json:"path" yaml:"path"`
	// WSPath is the websocket endpoint; an empty value skips registering it.
	WSPath             string                   `json:"ws_path" yaml:"ws_path"`
	// WSWelcomeMessage is sent once to each new websocket connection when set.
	WSWelcomeMessage   string                   `json:"ws_welcome_message" yaml:"ws_welcome_message"`
	// WSRateLimitMessage is sent to websocket clients that breach the rate limit.
	WSRateLimitMessage string                   `json:"ws_rate_limit_message" yaml:"ws_rate_limit_message"`
	// AllowedVerbs lists the HTTP methods accepted by the Path endpoint.
	AllowedVerbs       []string                 `json:"allowed_verbs" yaml:"allowed_verbs"`
	// Timeout is a duration string parsed with time.ParseDuration.
	Timeout            string                   `json:"timeout" yaml:"timeout"`
	// RateLimit names an optional rate limit resource.
	RateLimit          string                   `json:"rate_limit" yaml:"rate_limit"`
	// CertFile and KeyFile enable TLS on a dedicated server when either is set.
	CertFile           string                   `json:"cert_file" yaml:"cert_file"`
	KeyFile            string                   `json:"key_file" yaml:"key_file"`
	CORS               httpdocs.ServerCORS      `json:"cors" yaml:"cors"`
	Response           HTTPServerResponseConfig `json:"sync_response" yaml:"sync_response"`
}
   157  
   158  // NewHTTPServerConfig creates a new HTTPServerConfig with default values.
   159  func NewHTTPServerConfig() HTTPServerConfig {
   160  	return HTTPServerConfig{
   161  		Address:            "",
   162  		Path:               "/post",
   163  		WSPath:             "/post/ws",
   164  		WSWelcomeMessage:   "",
   165  		WSRateLimitMessage: "",
   166  		AllowedVerbs: []string{
   167  			"POST",
   168  		},
   169  		Timeout:   "5s",
   170  		RateLimit: "",
   171  		CertFile:  "",
   172  		KeyFile:   "",
   173  		CORS:      httpdocs.NewServerCORS(),
   174  		Response:  NewHTTPServerResponseConfig(),
   175  	}
   176  }
   177  
   178  //------------------------------------------------------------------------------
   179  
// HTTPServer is an input type that registers a range of HTTP endpoints where
// requests can send messages through Benthos. The endpoints are registered on
// the general Benthos HTTP server by default. It is also possible to specify a
// custom address to bind a new server to which the endpoints will be registered
// on instead.
type HTTPServer struct {
	conf  HTTPServerConfig
	stats metrics.Type
	log   log.Modular
	mgr   types.Manager

	// mux and server are only non-nil when a custom address is configured;
	// otherwise endpoints are registered through mgr.
	mux     *http.ServeMux
	server  *http.Server
	timeout time.Duration

	// Parsed sync_response expressions; responseHeaders is keyed by the
	// lower-cased header name.
	responseStatus  *field.Expression
	responseHeaders map[string]*field.Expression
	metaFilter      *imetadata.IncludeFilter

	// handlerWG tracks in-flight handler calls so that shutdown waits for
	// them before transactions is closed.
	handlerWG    sync.WaitGroup
	transactions chan types.Transaction

	shutSig *shutdown.Signaller

	// allowedVerbs is the set of methods accepted by the `path` endpoint.
	allowedVerbs map[string]struct{}

	// TODO: V4 Reduce this way down
	mCount         metrics.StatCounter
	mLatency       metrics.StatTimer
	mRateLimited   metrics.StatCounter
	mWSRateLimited metrics.StatCounter
	mRcvd          metrics.StatCounter
	mPartsRcvd     metrics.StatCounter
	mWSCount       metrics.StatCounter
	mTimeout       metrics.StatCounter
	mErr           metrics.StatCounter
	mWSErr         metrics.StatCounter
	mSucc          metrics.StatCounter
	mWSSucc        metrics.StatCounter
	mAsyncErr      metrics.StatCounter
	mAsyncSucc     metrics.StatCounter
}
   222  
   223  // NewHTTPServer creates a new HTTPServer input type.
   224  func NewHTTPServer(conf Config, mgr types.Manager, log log.Modular, stats metrics.Type) (Type, error) {
   225  	var mux *http.ServeMux
   226  	var server *http.Server
   227  
   228  	var err error
   229  	if len(conf.HTTPServer.Address) > 0 {
   230  		mux = http.NewServeMux()
   231  		server = &http.Server{Addr: conf.HTTPServer.Address}
   232  		if server.Handler, err = conf.HTTPServer.CORS.WrapHandler(mux); err != nil {
   233  			return nil, fmt.Errorf("bad CORS configuration: %w", err)
   234  		}
   235  	}
   236  
   237  	var timeout time.Duration
   238  	if len(conf.HTTPServer.Timeout) > 0 {
   239  		if timeout, err = time.ParseDuration(conf.HTTPServer.Timeout); err != nil {
   240  			return nil, fmt.Errorf("failed to parse timeout string: %v", err)
   241  		}
   242  	}
   243  
   244  	verbs := map[string]struct{}{}
   245  	for _, v := range conf.HTTPServer.AllowedVerbs {
   246  		verbs[v] = struct{}{}
   247  	}
   248  	if len(verbs) == 0 {
   249  		return nil, errors.New("must provide at least one allowed verb")
   250  	}
   251  
   252  	h := HTTPServer{
   253  		shutSig:         shutdown.NewSignaller(),
   254  		conf:            conf.HTTPServer,
   255  		stats:           stats,
   256  		log:             log,
   257  		mgr:             mgr,
   258  		mux:             mux,
   259  		server:          server,
   260  		timeout:         timeout,
   261  		responseHeaders: map[string]*field.Expression{},
   262  		transactions:    make(chan types.Transaction),
   263  
   264  		allowedVerbs: verbs,
   265  
   266  		mCount:         stats.GetCounter("count"),
   267  		mLatency:       stats.GetTimer("latency"),
   268  		mRateLimited:   stats.GetCounter("rate_limited"),
   269  		mWSRateLimited: stats.GetCounter("ws.rate_limited"),
   270  		mRcvd:          stats.GetCounter("batch.received"),
   271  		mPartsRcvd:     stats.GetCounter("received"),
   272  		mWSCount:       stats.GetCounter("ws.count"),
   273  		mTimeout:       stats.GetCounter("send.timeout"),
   274  		mErr:           stats.GetCounter("send.error"),
   275  		mWSErr:         stats.GetCounter("ws.send.error"),
   276  		mSucc:          stats.GetCounter("send.success"),
   277  		mWSSucc:        stats.GetCounter("ws.send.success"),
   278  		mAsyncErr:      stats.GetCounter("send.async_error"),
   279  		mAsyncSucc:     stats.GetCounter("send.async_success"),
   280  	}
   281  
   282  	if h.responseStatus, err = interop.NewBloblangField(mgr, h.conf.Response.Status); err != nil {
   283  		return nil, fmt.Errorf("failed to parse response status expression: %v", err)
   284  	}
   285  	for k, v := range h.conf.Response.Headers {
   286  		if h.responseHeaders[strings.ToLower(k)], err = interop.NewBloblangField(mgr, v); err != nil {
   287  			return nil, fmt.Errorf("failed to parse response header '%v' expression: %v", k, err)
   288  		}
   289  	}
   290  
   291  	if h.metaFilter, err = h.conf.Response.ExtractMetadata.CreateFilter(); err != nil {
   292  		return nil, fmt.Errorf("failed to construct metadata filter: %w", err)
   293  	}
   294  
   295  	postHdlr := httputil.GzipHandler(h.postHandler)
   296  	wsHdlr := httputil.GzipHandler(h.wsHandler)
   297  	if mux != nil {
   298  		if len(h.conf.Path) > 0 {
   299  			mux.HandleFunc(h.conf.Path, postHdlr)
   300  		}
   301  		if len(h.conf.WSPath) > 0 {
   302  			mux.HandleFunc(h.conf.WSPath, wsHdlr)
   303  		}
   304  	} else {
   305  		if len(h.conf.Path) > 0 {
   306  			mgr.RegisterEndpoint(
   307  				h.conf.Path, "Post a message into Benthos.", postHdlr,
   308  			)
   309  		}
   310  		if len(h.conf.WSPath) > 0 {
   311  			mgr.RegisterEndpoint(
   312  				h.conf.WSPath, "Post messages via websocket into Benthos.", wsHdlr,
   313  			)
   314  		}
   315  	}
   316  
   317  	if h.conf.RateLimit != "" {
   318  		if err := interop.ProbeRateLimit(context.Background(), h.mgr, h.conf.RateLimit); err != nil {
   319  			return nil, err
   320  		}
   321  	}
   322  
   323  	go h.loop()
   324  	return &h, nil
   325  }
   326  
   327  //------------------------------------------------------------------------------
   328  
   329  func (h *HTTPServer) extractMessageFromRequest(r *http.Request) (types.Message, error) {
   330  	msg := message.New(nil)
   331  
   332  	contentType := r.Header.Get("Content-Type")
   333  	if contentType == "" {
   334  		contentType = "application/octet-stream"
   335  	}
   336  
   337  	mediaType, params, err := mime.ParseMediaType(contentType)
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  
   342  	if strings.HasPrefix(mediaType, "multipart/") {
   343  		mr := multipart.NewReader(r.Body, params["boundary"])
   344  		for {
   345  			var p *multipart.Part
   346  			if p, err = mr.NextPart(); err != nil {
   347  				if err == io.EOF {
   348  					break
   349  				}
   350  				return nil, err
   351  			}
   352  			var msgBytes []byte
   353  			if msgBytes, err = io.ReadAll(p); err != nil {
   354  				return nil, err
   355  			}
   356  			msg.Append(message.NewPart(msgBytes))
   357  		}
   358  	} else {
   359  		var msgBytes []byte
   360  		if msgBytes, err = io.ReadAll(r.Body); err != nil {
   361  			return nil, err
   362  		}
   363  		msg.Append(message.NewPart(msgBytes))
   364  	}
   365  
   366  	meta := metadata.New(nil)
   367  	meta.Set("http_server_user_agent", r.UserAgent())
   368  	meta.Set("http_server_request_path", r.URL.Path)
   369  	meta.Set("http_server_verb", r.Method)
   370  	for k, v := range r.Header {
   371  		if len(v) > 0 {
   372  			meta.Set(k, v[0])
   373  		}
   374  	}
   375  	for k, v := range r.URL.Query() {
   376  		if len(v) > 0 {
   377  			meta.Set(k, v[0])
   378  		}
   379  	}
   380  	for k, v := range mux.Vars(r) {
   381  		meta.Set(k, v)
   382  	}
   383  	for _, c := range r.Cookies() {
   384  		meta.Set(c.Name, c.Value)
   385  	}
   386  	message.SetAllMetadata(msg, meta)
   387  
   388  	textMapGeneric := map[string]interface{}{}
   389  	for k, vals := range r.Header {
   390  		for _, v := range vals {
   391  			textMapGeneric[k] = v
   392  		}
   393  	}
   394  
   395  	_ = tracing.InitSpansFromParentTextMap("input_http_server_post", textMapGeneric, msg)
   396  	return msg, nil
   397  }
   398  
   399  func (h *HTTPServer) postHandler(w http.ResponseWriter, r *http.Request) {
   400  	h.handlerWG.Add(1)
   401  	defer h.handlerWG.Done()
   402  	defer r.Body.Close()
   403  
   404  	if _, exists := h.allowedVerbs[r.Method]; !exists {
   405  		http.Error(w, "Incorrect method", http.StatusMethodNotAllowed)
   406  		return
   407  	}
   408  
   409  	if h.conf.RateLimit != "" {
   410  		var tUntil time.Duration
   411  		var err error
   412  		if rerr := interop.AccessRateLimit(r.Context(), h.mgr, h.conf.RateLimit, func(rl types.RateLimit) {
   413  			tUntil, err = rl.Access()
   414  		}); rerr != nil {
   415  			http.Error(w, "Server error", http.StatusBadGateway)
   416  			h.log.Warnf("Failed to access rate limit: %v\n", rerr)
   417  			return
   418  		}
   419  		if err != nil {
   420  			http.Error(w, "Server error", http.StatusBadGateway)
   421  			h.log.Warnf("Failed to access rate limit: %v\n", err)
   422  			return
   423  		} else if tUntil > 0 {
   424  			w.Header().Add("Retry-After", strconv.Itoa(int(tUntil.Seconds())))
   425  			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
   426  			h.mRateLimited.Incr(1)
   427  			return
   428  		}
   429  	}
   430  
   431  	msg, err := h.extractMessageFromRequest(r)
   432  	if err != nil {
   433  		http.Error(w, "Bad request", http.StatusBadRequest)
   434  		h.log.Warnf("Request read failed: %v\n", err)
   435  		return
   436  	}
   437  	defer tracing.FinishSpans(msg)
   438  
   439  	store := roundtrip.NewResultStore()
   440  	roundtrip.AddResultStore(msg, store)
   441  
   442  	h.mCount.Incr(1)
   443  	h.mPartsRcvd.Incr(int64(msg.Len()))
   444  	h.mRcvd.Incr(1)
   445  	h.log.Tracef("Consumed %v messages from POST to '%v'.\n", msg.Len(), h.conf.Path)
   446  
   447  	resChan := make(chan types.Response, 1)
   448  	select {
   449  	case h.transactions <- types.NewTransaction(msg, resChan):
   450  	case <-time.After(h.timeout):
   451  		h.mTimeout.Incr(1)
   452  		http.Error(w, "Request timed out", http.StatusRequestTimeout)
   453  		return
   454  	case <-r.Context().Done():
   455  		h.mTimeout.Incr(1)
   456  		http.Error(w, "Request timed out", http.StatusRequestTimeout)
   457  		return
   458  	case <-h.shutSig.CloseAtLeisureChan():
   459  		http.Error(w, "Server closing", http.StatusServiceUnavailable)
   460  		return
   461  	}
   462  
   463  	select {
   464  	case res, open := <-resChan:
   465  		if !open {
   466  			http.Error(w, "Server closing", http.StatusServiceUnavailable)
   467  			return
   468  		} else if res.Error() != nil {
   469  			h.mErr.Incr(1)
   470  			http.Error(w, res.Error().Error(), http.StatusBadGateway)
   471  			return
   472  		}
   473  		tTaken := time.Since(msg.CreatedAt()).Nanoseconds()
   474  		h.mLatency.Timing(tTaken)
   475  		h.mSucc.Incr(1)
   476  	case <-time.After(h.timeout):
   477  		h.mTimeout.Incr(1)
   478  		http.Error(w, "Request timed out", http.StatusRequestTimeout)
   479  		return
   480  	case <-r.Context().Done():
   481  		h.mTimeout.Incr(1)
   482  		http.Error(w, "Request timed out", http.StatusRequestTimeout)
   483  		return
   484  	case <-h.shutSig.CloseNowChan():
   485  		http.Error(w, "Server closing", http.StatusServiceUnavailable)
   486  		return
   487  	}
   488  
   489  	responseMsg := message.New(nil)
   490  	for _, resMsg := range store.Get() {
   491  		resMsg.Iter(func(i int, part types.Part) error {
   492  			responseMsg.Append(part)
   493  			return nil
   494  		})
   495  	}
   496  	if responseMsg.Len() > 0 {
   497  		for k, v := range h.responseHeaders {
   498  			w.Header().Set(k, v.String(0, responseMsg))
   499  		}
   500  
   501  		statusCode := 200
   502  		if statusCodeStr := h.responseStatus.String(0, responseMsg); statusCodeStr != "200" {
   503  			if statusCode, err = strconv.Atoi(statusCodeStr); err != nil {
   504  				h.log.Errorf("Failed to parse sync response status code expression: %v\n", err)
   505  				w.WriteHeader(http.StatusBadGateway)
   506  				return
   507  			}
   508  		}
   509  
   510  		if plen := responseMsg.Len(); plen == 1 {
   511  			part := responseMsg.Get(0)
   512  			part.Metadata().Iter(func(k, v string) error {
   513  				if h.metaFilter.Match(k) {
   514  					w.Header().Set(k, v)
   515  					return nil
   516  				}
   517  				return nil
   518  			})
   519  			payload := part.Get()
   520  			if w.Header().Get("Content-Type") == "" {
   521  				w.Header().Set("Content-Type", http.DetectContentType(payload))
   522  			}
   523  			w.WriteHeader(statusCode)
   524  			w.Write(payload)
   525  		} else if plen > 1 {
   526  			customContentType, customContentTypeExists := h.responseHeaders["content-type"]
   527  
   528  			var buf bytes.Buffer
   529  			writer := multipart.NewWriter(&buf)
   530  
   531  			var merr error
   532  			for i := 0; i < plen && merr == nil; i++ {
   533  				part := responseMsg.Get(i)
   534  				part.Metadata().Iter(func(k, v string) error {
   535  					if h.metaFilter.Match(k) {
   536  						w.Header().Set(k, v)
   537  						return nil
   538  					}
   539  					return nil
   540  				})
   541  				payload := part.Get()
   542  
   543  				mimeHeader := textproto.MIMEHeader{}
   544  				if customContentTypeExists {
   545  					mimeHeader.Set("Content-Type", customContentType.String(i, responseMsg))
   546  				} else {
   547  					mimeHeader.Set("Content-Type", http.DetectContentType(payload))
   548  				}
   549  
   550  				var partWriter io.Writer
   551  				if partWriter, merr = writer.CreatePart(mimeHeader); merr == nil {
   552  					_, merr = io.Copy(partWriter, bytes.NewReader(payload))
   553  				}
   554  			}
   555  
   556  			merr = writer.Close()
   557  			if merr == nil {
   558  				w.Header().Del("Content-Type")
   559  				w.Header().Add("Content-Type", writer.FormDataContentType())
   560  				w.WriteHeader(statusCode)
   561  				buf.WriteTo(w)
   562  			} else {
   563  				h.log.Errorf("Failed to return sync response: %v\n", merr)
   564  				w.WriteHeader(http.StatusBadGateway)
   565  			}
   566  		}
   567  	}
   568  }
   569  
   570  func (h *HTTPServer) wsHandler(w http.ResponseWriter, r *http.Request) {
   571  	h.handlerWG.Add(1)
   572  	defer h.handlerWG.Done()
   573  
   574  	var err error
   575  	defer func() {
   576  		if err != nil {
   577  			http.Error(w, "Bad request", http.StatusBadRequest)
   578  			h.log.Warnf("Websocket request failed: %v\n", err)
   579  		}
   580  	}()
   581  
   582  	upgrader := websocket.Upgrader{}
   583  
   584  	var ws *websocket.Conn
   585  	if ws, err = upgrader.Upgrade(w, r, nil); err != nil {
   586  		return
   587  	}
   588  	defer ws.Close()
   589  
   590  	resChan := make(chan types.Response, 1)
   591  	throt := throttle.New(throttle.OptCloseChan(h.shutSig.CloseAtLeisureChan()))
   592  
   593  	if welMsg := h.conf.WSWelcomeMessage; len(welMsg) > 0 {
   594  		if err = ws.WriteMessage(websocket.BinaryMessage, []byte(welMsg)); err != nil {
   595  			h.log.Errorf("Failed to send welcome message: %v\n", err)
   596  		}
   597  	}
   598  
   599  	var msgBytes []byte
   600  	for !h.shutSig.ShouldCloseAtLeisure() {
   601  		if msgBytes == nil {
   602  			if _, msgBytes, err = ws.ReadMessage(); err != nil {
   603  				return
   604  			}
   605  			h.mWSCount.Incr(1)
   606  			h.mCount.Incr(1)
   607  		}
   608  
   609  		if h.conf.RateLimit != "" {
   610  			var tUntil time.Duration
   611  			if rerr := interop.AccessRateLimit(r.Context(), h.mgr, h.conf.RateLimit, func(rl types.RateLimit) {
   612  				tUntil, err = rl.Access()
   613  			}); rerr != nil {
   614  				h.log.Warnf("Failed to access rate limit: %v\n", rerr)
   615  				err = rerr
   616  			}
   617  			if err != nil || tUntil > 0 {
   618  				if err != nil {
   619  					h.log.Warnf("Failed to access rate limit: %v\n", err)
   620  				}
   621  				if rlMsg := h.conf.WSRateLimitMessage; len(rlMsg) > 0 {
   622  					if err = ws.WriteMessage(websocket.BinaryMessage, []byte(rlMsg)); err != nil {
   623  						h.log.Errorf("Failed to send rate limit message: %v\n", err)
   624  					}
   625  				}
   626  				h.mWSRateLimited.Incr(1)
   627  				continue
   628  			}
   629  		}
   630  
   631  		msg := message.New([][]byte{msgBytes})
   632  
   633  		meta := msg.Get(0).Metadata()
   634  		meta.Set("http_server_user_agent", r.UserAgent())
   635  		for k, v := range r.Header {
   636  			if len(v) > 0 {
   637  				meta.Set(k, v[0])
   638  			}
   639  		}
   640  		for k, v := range r.URL.Query() {
   641  			if len(v) > 0 {
   642  				meta.Set(k, v[0])
   643  			}
   644  		}
   645  		for k, v := range mux.Vars(r) {
   646  			meta.Set(k, v)
   647  		}
   648  		for _, c := range r.Cookies() {
   649  			meta.Set(c.Name, c.Value)
   650  		}
   651  		tracing.InitSpans("input_http_server_websocket", msg)
   652  
   653  		store := roundtrip.NewResultStore()
   654  		roundtrip.AddResultStore(msg, store)
   655  
   656  		select {
   657  		case h.transactions <- types.NewTransaction(msg, resChan):
   658  		case <-h.shutSig.CloseAtLeisureChan():
   659  			return
   660  		}
   661  		select {
   662  		case res, open := <-resChan:
   663  			if !open {
   664  				return
   665  			}
   666  			if res.Error() != nil {
   667  				h.mWSErr.Incr(1)
   668  				h.mErr.Incr(1)
   669  				throt.Retry()
   670  			} else {
   671  				tTaken := time.Since(msg.CreatedAt()).Nanoseconds()
   672  				h.mLatency.Timing(tTaken)
   673  				h.mWSSucc.Incr(1)
   674  				h.mSucc.Incr(1)
   675  				msgBytes = nil
   676  				throt.Reset()
   677  			}
   678  		case <-h.shutSig.CloseNowChan():
   679  			return
   680  		}
   681  
   682  		for _, responseMsg := range store.Get() {
   683  			if err := responseMsg.Iter(func(i int, part types.Part) error {
   684  				return ws.WriteMessage(websocket.TextMessage, part.Get())
   685  			}); err != nil {
   686  				h.log.Errorf("Failed to send sync response over websocket: %v\n", err)
   687  			}
   688  		}
   689  
   690  		tracing.FinishSpans(msg)
   691  	}
   692  }
   693  
   694  //------------------------------------------------------------------------------
   695  
   696  func (h *HTTPServer) loop() {
   697  	mRunning := h.stats.GetGauge("running")
   698  
   699  	defer func() {
   700  		if h.server != nil {
   701  			if err := h.server.Shutdown(context.Background()); err != nil {
   702  				h.log.Errorf("Failed to gracefully terminate http_server: %v\n", err)
   703  			}
   704  		} else {
   705  			if len(h.conf.Path) > 0 {
   706  				h.mgr.RegisterEndpoint(h.conf.Path, "Does nothing.", http.NotFound)
   707  			}
   708  			if len(h.conf.WSPath) > 0 {
   709  				h.mgr.RegisterEndpoint(h.conf.WSPath, "Does nothing.", http.NotFound)
   710  			}
   711  		}
   712  
   713  		h.handlerWG.Wait()
   714  		mRunning.Decr(1)
   715  
   716  		close(h.transactions)
   717  		h.shutSig.ShutdownComplete()
   718  	}()
   719  	mRunning.Incr(1)
   720  
   721  	if h.server != nil {
   722  		go func() {
   723  			if len(h.conf.KeyFile) > 0 || len(h.conf.CertFile) > 0 {
   724  				h.log.Infof(
   725  					"Receiving HTTPS messages at: https://%s\n",
   726  					h.conf.Address+h.conf.Path,
   727  				)
   728  				if err := h.server.ListenAndServeTLS(
   729  					h.conf.CertFile, h.conf.KeyFile,
   730  				); err != http.ErrServerClosed {
   731  					h.log.Errorf("Server error: %v\n", err)
   732  				}
   733  			} else {
   734  				h.log.Infof(
   735  					"Receiving HTTP messages at: http://%s\n",
   736  					h.conf.Address+h.conf.Path,
   737  				)
   738  				if err := h.server.ListenAndServe(); err != http.ErrServerClosed {
   739  					h.log.Errorf("Server error: %v\n", err)
   740  				}
   741  			}
   742  		}()
   743  	}
   744  
   745  	<-h.shutSig.CloseAtLeisureChan()
   746  }
   747  
   748  // TransactionChan returns a transactions channel for consuming messages from
   749  // this input.
   750  func (h *HTTPServer) TransactionChan() <-chan types.Transaction {
   751  	return h.transactions
   752  }
   753  
   754  // Connected returns a boolean indicating whether this input is currently
   755  // connected to its target.
   756  func (h *HTTPServer) Connected() bool {
   757  	return true
   758  }
   759  
   760  // CloseAsync shuts down the HTTPServer input and stops processing requests.
   761  func (h *HTTPServer) CloseAsync() {
   762  	h.shutSig.CloseAtLeisure()
   763  }
   764  
   765  // WaitForClose blocks until the HTTPServer input has closed down.
   766  func (h *HTTPServer) WaitForClose(timeout time.Duration) error {
   767  	go func() {
   768  		<-time.After(timeout - time.Second)
   769  		h.shutSig.CloseNow()
   770  	}()
   771  	select {
   772  	case <-h.shutSig.HasClosedChan():
   773  	case <-time.After(timeout):
   774  		return types.ErrTimeout
   775  	}
   776  	return nil
   777  }
   778  
   779  //------------------------------------------------------------------------------