github.com/Jeffail/benthos/v3@v3.65.0/lib/input/socket_server.go (about)

     1  package input
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/Jeffail/benthos/v3/internal/codec"
    13  	"github.com/Jeffail/benthos/v3/internal/docs"
    14  	"github.com/Jeffail/benthos/v3/lib/log"
    15  	"github.com/Jeffail/benthos/v3/lib/message"
    16  	"github.com/Jeffail/benthos/v3/lib/metrics"
    17  	"github.com/Jeffail/benthos/v3/lib/types"
    18  )
    19  
    20  //------------------------------------------------------------------------------
    21  
    22  func init() {
    23  	Constructors[TypeSocketServer] = TypeSpec{
    24  		constructor: fromSimpleConstructor(NewSocketServer),
    25  		Summary:     `Creates a server that receives a stream of messages over a tcp, udp or unix socket.`,
    26  		Description: `
    27  The field ` + "`max_buffer`" + ` specifies the maximum amount of memory to allocate _per connection_ for buffering lines of data. If a line of data from a connection exceeds this value then the connection will be closed.`,
    28  		FieldSpecs: docs.FieldSpecs{
    29  			docs.FieldCommon("network", "A network type to accept (unix|tcp|udp).").HasOptions(
    30  				"unix", "tcp", "udp",
    31  			),
    32  			docs.FieldCommon("address", "The address to listen from.", "/tmp/benthos.sock", "0.0.0.0:6000"),
    33  			codec.ReaderDocs.AtVersion("3.42.0"),
    34  			docs.FieldAdvanced("max_buffer", "The maximum message buffer size. Must exceed the largest message to be consumed."),
    35  			docs.FieldDeprecated("multipart"),
    36  			docs.FieldDeprecated("delimiter"),
    37  		},
    38  		Categories: []Category{
    39  			CategoryNetwork,
    40  		},
    41  	}
    42  }
    43  
    44  //------------------------------------------------------------------------------
    45  
    46  // SocketServerConfig contains configuration for the SocketServer input type.
    47  type SocketServerConfig struct {
    48  	Network   string `json:"network" yaml:"network"`
    49  	Address   string `json:"address" yaml:"address"`
    50  	Codec     string `json:"codec" yaml:"codec"`
    51  	MaxBuffer int    `json:"max_buffer" yaml:"max_buffer"`
    52  	Multipart bool   `json:"multipart" yaml:"multipart"`
    53  	Delim     string `json:"delimiter" yaml:"delimiter"`
    54  }
    55  
    56  // NewSocketServerConfig creates a new SocketServerConfig with default values.
    57  func NewSocketServerConfig() SocketServerConfig {
    58  	return SocketServerConfig{
    59  		Network:   "unix",
    60  		Address:   "/tmp/benthos.sock",
    61  		Codec:     "lines",
    62  		MaxBuffer: 1000000,
    63  
    64  		// TODO: V4 Remove these fields
    65  		Multipart: false,
    66  		Delim:     "",
    67  	}
    68  }
    69  
    70  //------------------------------------------------------------------------------
    71  
    72  type wrapPacketConn struct {
    73  	net.PacketConn
    74  }
    75  
    76  func (w *wrapPacketConn) Read(p []byte) (n int, err error) {
    77  	n, _, err = w.ReadFrom(p)
    78  	return
    79  }
    80  
    81  // SocketServer is an input type that binds to an address and consumes streams of
    82  // messages over Socket.
    83  type SocketServer struct {
    84  	conf  SocketServerConfig
    85  	stats metrics.Type
    86  	log   log.Modular
    87  
    88  	codecCtor codec.ReaderConstructor
    89  	listener  net.Listener
    90  	conn      net.PacketConn
    91  
    92  	retriesMut   sync.RWMutex
    93  	transactions chan types.Transaction
    94  
    95  	ctx        context.Context
    96  	closeFn    func()
    97  	closedChan chan struct{}
    98  
    99  	mLatency metrics.StatTimer
   100  }
   101  
   102  // NewSocketServer creates a new SocketServer input type.
   103  func NewSocketServer(conf Config, mgr types.Manager, log log.Modular, stats metrics.Type) (Type, error) {
   104  	var ln net.Listener
   105  	var cn net.PacketConn
   106  	var err error
   107  
   108  	sconf := conf.SocketServer
   109  	if len(sconf.Delim) > 0 {
   110  		sconf.Codec = "delim:" + sconf.Delim
   111  	}
   112  	if sconf.Multipart && !strings.HasSuffix(sconf.Codec, "/multipart") {
   113  		sconf.Codec += "/multipart"
   114  	}
   115  
   116  	codecConf := codec.NewReaderConfig()
   117  	codecConf.MaxScanTokenSize = sconf.MaxBuffer
   118  	ctor, err := codec.GetReader(sconf.Codec, codecConf)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	switch sconf.Network {
   124  	case "tcp", "unix":
   125  		ln, err = net.Listen(sconf.Network, sconf.Address)
   126  	case "udp":
   127  		cn, err = net.ListenPacket(sconf.Network, sconf.Address)
   128  	default:
   129  		return nil, fmt.Errorf("socket network '%v' is not supported by this input", sconf.Network)
   130  	}
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	t := SocketServer{
   136  		conf:  conf.SocketServer,
   137  		stats: stats,
   138  		log:   log,
   139  
   140  		codecCtor: ctor,
   141  		listener:  ln,
   142  		conn:      cn,
   143  
   144  		transactions: make(chan types.Transaction),
   145  		closedChan:   make(chan struct{}),
   146  
   147  		mLatency: stats.GetTimer("latency"),
   148  	}
   149  	t.ctx, t.closeFn = context.WithCancel(context.Background())
   150  
   151  	if ln == nil {
   152  		go t.udpLoop()
   153  	} else {
   154  		go t.loop()
   155  	}
   156  	return &t, nil
   157  }
   158  
   159  //------------------------------------------------------------------------------
   160  
   161  // Addr returns the underlying Socket listeners address.
   162  func (t *SocketServer) Addr() net.Addr {
   163  	if t.listener != nil {
   164  		return t.listener.Addr()
   165  	}
   166  	return t.conn.LocalAddr()
   167  }
   168  
   169  func (t *SocketServer) sendMsg(msg types.Message) bool {
   170  	tStarted := time.Now()
   171  
   172  	// Block whilst retries are happening
   173  	t.retriesMut.Lock()
   174  	// nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock
   175  	t.retriesMut.Unlock()
   176  
   177  	resChan := make(chan types.Response)
   178  	select {
   179  	case t.transactions <- types.NewTransaction(msg, resChan):
   180  	case <-t.ctx.Done():
   181  		return false
   182  	}
   183  
   184  	go func() {
   185  		hasLocked := false
   186  		defer func() {
   187  			if hasLocked {
   188  				t.retriesMut.RUnlock()
   189  			}
   190  		}()
   191  		for {
   192  			select {
   193  			case res, open := <-resChan:
   194  				if !open {
   195  					return
   196  				}
   197  				var sendErr error
   198  				if res != nil {
   199  					sendErr = res.Error()
   200  				}
   201  				if sendErr == nil || sendErr == types.ErrTypeClosed {
   202  					if sendErr == nil {
   203  						t.mLatency.Timing(time.Since(tStarted).Nanoseconds())
   204  					}
   205  					return
   206  				}
   207  				if !hasLocked {
   208  					hasLocked = true
   209  					t.retriesMut.RLock()
   210  				}
   211  				t.log.Errorf("failed to send message: %v\n", sendErr)
   212  
   213  				// Wait before attempting again
   214  				select {
   215  				case <-time.After(time.Second):
   216  				case <-t.ctx.Done():
   217  					return
   218  				}
   219  
   220  				// And then resend the transaction
   221  				select {
   222  				case t.transactions <- types.NewTransaction(msg, resChan):
   223  				case <-t.ctx.Done():
   224  					return
   225  				}
   226  			case <-t.ctx.Done():
   227  				return
   228  			}
   229  		}
   230  	}()
   231  	return true
   232  }
   233  
   234  func (t *SocketServer) loop() {
   235  	var (
   236  		mCount     = t.stats.GetCounter("count")
   237  		mRcvd      = t.stats.GetCounter("batch.received")
   238  		mPartsRcvd = t.stats.GetCounter("received")
   239  	)
   240  
   241  	var wg sync.WaitGroup
   242  
   243  	defer func() {
   244  		wg.Wait()
   245  
   246  		t.retriesMut.Lock()
   247  		// nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock
   248  		t.retriesMut.Unlock()
   249  
   250  		t.listener.Close()
   251  
   252  		close(t.transactions)
   253  		close(t.closedChan)
   254  	}()
   255  
   256  	t.log.Infof("Receiving %v socket messages from address: %v\n", t.conf.Network, t.listener.Addr())
   257  
   258  	go func() {
   259  		<-t.ctx.Done()
   260  		t.listener.Close()
   261  	}()
   262  
   263  acceptLoop:
   264  	for {
   265  		conn, err := t.listener.Accept()
   266  		if err != nil {
   267  			if !strings.Contains(err.Error(), "use of closed network connection") {
   268  				t.log.Errorf("Failed to accept Socket connection: %v\n", err)
   269  			}
   270  			select {
   271  			case <-time.After(time.Second):
   272  				continue acceptLoop
   273  			case <-t.ctx.Done():
   274  				return
   275  			}
   276  		}
   277  		connCtx, connDone := context.WithCancel(t.ctx)
   278  		go func() {
   279  			<-connCtx.Done()
   280  			conn.Close()
   281  		}()
   282  		wg.Add(1)
   283  		go func(c net.Conn) {
   284  			defer func() {
   285  				connDone()
   286  				wg.Done()
   287  				c.Close()
   288  			}()
   289  			codec, err := t.codecCtor("", c, func(ctx context.Context, err error) error {
   290  				return nil
   291  			})
   292  			if err != nil {
   293  				t.log.Errorf("Failed to create codec for new connection: %v\n", err)
   294  				return
   295  			}
   296  
   297  			for {
   298  				parts, ackFn, err := codec.Next(t.ctx)
   299  				if err != nil {
   300  					if err != io.EOF && err != types.ErrTimeout {
   301  						t.log.Errorf("Connection dropped due to: %v\n", err)
   302  					}
   303  					return
   304  				}
   305  				mCount.Incr(1)
   306  				mRcvd.Incr(1)
   307  				mPartsRcvd.Incr(int64(len(parts)))
   308  
   309  				// We simply bounce rejected messages in a loop downstream so
   310  				// there's no benefit to aggregating acks.
   311  				_ = ackFn(t.ctx, nil)
   312  
   313  				msg := message.New(nil)
   314  				msg.Append(parts...)
   315  				if !t.sendMsg(msg) {
   316  					return
   317  				}
   318  			}
   319  		}(conn)
   320  	}
   321  }
   322  
   323  func (t *SocketServer) udpLoop() {
   324  	var (
   325  		mCount     = t.stats.GetCounter("count")
   326  		mRcvd      = t.stats.GetCounter("batch.received")
   327  		mPartsRcvd = t.stats.GetCounter("received")
   328  	)
   329  
   330  	defer func() {
   331  		t.retriesMut.Lock()
   332  		// nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock
   333  		t.retriesMut.Unlock()
   334  
   335  		close(t.transactions)
   336  		close(t.closedChan)
   337  	}()
   338  
   339  	codec, err := t.codecCtor("", &wrapPacketConn{PacketConn: t.conn}, func(ctx context.Context, err error) error {
   340  		return nil
   341  	})
   342  	if err != nil {
   343  		t.log.Errorf("Connection error due to: %v\n", err)
   344  		return
   345  	}
   346  
   347  	go func() {
   348  		<-t.ctx.Done()
   349  		codec.Close(context.Background())
   350  		t.conn.Close()
   351  	}()
   352  
   353  	t.log.Infof("Receiving udp socket messages from address: %v\n", t.conn.LocalAddr())
   354  
   355  	for {
   356  		parts, ackFn, err := codec.Next(t.ctx)
   357  		if err != nil {
   358  			if err != io.EOF && err != types.ErrTimeout {
   359  				t.log.Errorf("Connection dropped due to: %v\n", err)
   360  			}
   361  			return
   362  		}
   363  		mCount.Incr(1)
   364  		mRcvd.Incr(1)
   365  		mPartsRcvd.Incr(int64(len(parts)))
   366  
   367  		// We simply bounce rejected messages in a loop downstream so
   368  		// there's no benefit to aggregating acks.
   369  		_ = ackFn(t.ctx, nil)
   370  
   371  		msg := message.New(nil)
   372  		msg.Append(parts...)
   373  		if !t.sendMsg(msg) {
   374  			return
   375  		}
   376  	}
   377  }
   378  
   379  // TransactionChan returns a transactions channel for consuming messages from
   380  // this input.
   381  func (t *SocketServer) TransactionChan() <-chan types.Transaction {
   382  	return t.transactions
   383  }
   384  
   385  // Connected returns a boolean indicating whether this input is currently
   386  // connected to its target.
   387  func (t *SocketServer) Connected() bool {
   388  	return true
   389  }
   390  
   391  // CloseAsync shuts down the SocketServer input and stops processing requests.
   392  func (t *SocketServer) CloseAsync() {
   393  	t.closeFn()
   394  }
   395  
   396  // WaitForClose blocks until the SocketServer input has closed down.
   397  func (t *SocketServer) WaitForClose(timeout time.Duration) error {
   398  	select {
   399  	case <-t.closedChan:
   400  	case <-time.After(timeout):
   401  		return types.ErrTimeout
   402  	}
   403  	return nil
   404  }
   405  
   406  //------------------------------------------------------------------------------