github.com/Jeffail/benthos/v3@v3.65.0/lib/input/socket_server.go (about) 1 package input 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/Jeffail/benthos/v3/internal/codec" 13 "github.com/Jeffail/benthos/v3/internal/docs" 14 "github.com/Jeffail/benthos/v3/lib/log" 15 "github.com/Jeffail/benthos/v3/lib/message" 16 "github.com/Jeffail/benthos/v3/lib/metrics" 17 "github.com/Jeffail/benthos/v3/lib/types" 18 ) 19 20 //------------------------------------------------------------------------------ 21 22 func init() { 23 Constructors[TypeSocketServer] = TypeSpec{ 24 constructor: fromSimpleConstructor(NewSocketServer), 25 Summary: `Creates a server that receives a stream of messages over a tcp, udp or unix socket.`, 26 Description: ` 27 The field ` + "`max_buffer`" + ` specifies the maximum amount of memory to allocate _per connection_ for buffering lines of data. If a line of data from a connection exceeds this value then the connection will be closed.`, 28 FieldSpecs: docs.FieldSpecs{ 29 docs.FieldCommon("network", "A network type to accept (unix|tcp|udp).").HasOptions( 30 "unix", "tcp", "udp", 31 ), 32 docs.FieldCommon("address", "The address to listen from.", "/tmp/benthos.sock", "0.0.0.0:6000"), 33 codec.ReaderDocs.AtVersion("3.42.0"), 34 docs.FieldAdvanced("max_buffer", "The maximum message buffer size. Must exceed the largest message to be consumed."), 35 docs.FieldDeprecated("multipart"), 36 docs.FieldDeprecated("delimiter"), 37 }, 38 Categories: []Category{ 39 CategoryNetwork, 40 }, 41 } 42 } 43 44 //------------------------------------------------------------------------------ 45 46 // SocketServerConfig contains configuration for the SocketServer input type. 47 type SocketServerConfig struct { 48 Network string `json:"network" yaml:"network"` 49 Address string `json:"address" yaml:"address"` 50 Codec string `json:"codec" yaml:"codec"` 51 MaxBuffer int `json:"max_buffer" yaml:"max_buffer"` 52 Multipart bool `json:"multipart" yaml:"multipart"` 53 Delim string `json:"delimiter" yaml:"delimiter"` 54 } 55 56 // NewSocketServerConfig creates a new SocketServerConfig with default values. 57 func NewSocketServerConfig() SocketServerConfig { 58 return SocketServerConfig{ 59 Network: "unix", 60 Address: "/tmp/benthos.sock", 61 Codec: "lines", 62 MaxBuffer: 1000000, 63 64 // TODO: V4 Remove these fields 65 Multipart: false, 66 Delim: "", 67 } 68 } 69 70 //------------------------------------------------------------------------------ 71 72 type wrapPacketConn struct { 73 net.PacketConn 74 } 75 76 func (w *wrapPacketConn) Read(p []byte) (n int, err error) { 77 n, _, err = w.ReadFrom(p) 78 return 79 } 80 81 // SocketServer is an input type that binds to an address and consumes streams of 82 // messages over Socket. 83 type SocketServer struct { 84 conf SocketServerConfig 85 stats metrics.Type 86 log log.Modular 87 88 codecCtor codec.ReaderConstructor 89 listener net.Listener 90 conn net.PacketConn 91 92 retriesMut sync.RWMutex 93 transactions chan types.Transaction 94 95 ctx context.Context 96 closeFn func() 97 closedChan chan struct{} 98 99 mLatency metrics.StatTimer 100 } 101 102 // NewSocketServer creates a new SocketServer input type. 103 func NewSocketServer(conf Config, mgr types.Manager, log log.Modular, stats metrics.Type) (Type, error) { 104 var ln net.Listener 105 var cn net.PacketConn 106 var err error 107 108 sconf := conf.SocketServer 109 if len(sconf.Delim) > 0 { 110 sconf.Codec = "delim:" + sconf.Delim 111 } 112 if sconf.Multipart && !strings.HasSuffix(sconf.Codec, "/multipart") { 113 sconf.Codec += "/multipart" 114 } 115 116 codecConf := codec.NewReaderConfig() 117 codecConf.MaxScanTokenSize = sconf.MaxBuffer 118 ctor, err := codec.GetReader(sconf.Codec, codecConf) 119 if err != nil { 120 return nil, err 121 } 122 123 switch sconf.Network { 124 case "tcp", "unix": 125 ln, err = net.Listen(sconf.Network, sconf.Address) 126 case "udp": 127 cn, err = net.ListenPacket(sconf.Network, sconf.Address) 128 default: 129 return nil, fmt.Errorf("socket network '%v' is not supported by this input", sconf.Network) 130 } 131 if err != nil { 132 return nil, err 133 } 134 135 t := SocketServer{ 136 conf: conf.SocketServer, 137 stats: stats, 138 log: log, 139 140 codecCtor: ctor, 141 listener: ln, 142 conn: cn, 143 144 transactions: make(chan types.Transaction), 145 closedChan: make(chan struct{}), 146 147 mLatency: stats.GetTimer("latency"), 148 } 149 t.ctx, t.closeFn = context.WithCancel(context.Background()) 150 151 if ln == nil { 152 go t.udpLoop() 153 } else { 154 go t.loop() 155 } 156 return &t, nil 157 } 158 159 //------------------------------------------------------------------------------ 160 161 // Addr returns the underlying Socket listeners address. 162 func (t *SocketServer) Addr() net.Addr { 163 if t.listener != nil { 164 return t.listener.Addr() 165 } 166 return t.conn.LocalAddr() 167 } 168 169 func (t *SocketServer) sendMsg(msg types.Message) bool { 170 tStarted := time.Now() 171 172 // Block whilst retries are happening 173 t.retriesMut.Lock() 174 // nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock 175 t.retriesMut.Unlock() 176 177 resChan := make(chan types.Response) 178 select { 179 case t.transactions <- types.NewTransaction(msg, resChan): 180 case <-t.ctx.Done(): 181 return false 182 } 183 184 go func() { 185 hasLocked := false 186 defer func() { 187 if hasLocked { 188 t.retriesMut.RUnlock() 189 } 190 }() 191 for { 192 select { 193 case res, open := <-resChan: 194 if !open { 195 return 196 } 197 var sendErr error 198 if res != nil { 199 sendErr = res.Error() 200 } 201 if sendErr == nil || sendErr == types.ErrTypeClosed { 202 if sendErr == nil { 203 t.mLatency.Timing(time.Since(tStarted).Nanoseconds()) 204 } 205 return 206 } 207 if !hasLocked { 208 hasLocked = true 209 t.retriesMut.RLock() 210 } 211 t.log.Errorf("failed to send message: %v\n", sendErr) 212 213 // Wait before attempting again 214 select { 215 case <-time.After(time.Second): 216 case <-t.ctx.Done(): 217 return 218 } 219 220 // And then resend the transaction 221 select { 222 case t.transactions <- types.NewTransaction(msg, resChan): 223 case <-t.ctx.Done(): 224 return 225 } 226 case <-t.ctx.Done(): 227 return 228 } 229 } 230 }() 231 return true 232 } 233 234 func (t *SocketServer) loop() { 235 var ( 236 mCount = t.stats.GetCounter("count") 237 mRcvd = t.stats.GetCounter("batch.received") 238 mPartsRcvd = t.stats.GetCounter("received") 239 ) 240 241 var wg sync.WaitGroup 242 243 defer func() { 244 wg.Wait() 245 246 t.retriesMut.Lock() 247 // nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock 248 t.retriesMut.Unlock() 249 250 t.listener.Close() 251 252 close(t.transactions) 253 close(t.closedChan) 254 }() 255 256 t.log.Infof("Receiving %v socket messages from address: %v\n", t.conf.Network, t.listener.Addr()) 257 258 go func() { 259 <-t.ctx.Done() 260 t.listener.Close() 261 }() 262 263 acceptLoop: 264 for { 265 conn, err := t.listener.Accept() 266 if err != nil { 267 if !strings.Contains(err.Error(), "use of closed network connection") { 268 t.log.Errorf("Failed to accept Socket connection: %v\n", err) 269 } 270 select { 271 case <-time.After(time.Second): 272 continue acceptLoop 273 case <-t.ctx.Done(): 274 return 275 } 276 } 277 connCtx, connDone := context.WithCancel(t.ctx) 278 go func() { 279 <-connCtx.Done() 280 conn.Close() 281 }() 282 wg.Add(1) 283 go func(c net.Conn) { 284 defer func() { 285 connDone() 286 wg.Done() 287 c.Close() 288 }() 289 codec, err := t.codecCtor("", c, func(ctx context.Context, err error) error { 290 return nil 291 }) 292 if err != nil { 293 t.log.Errorf("Failed to create codec for new connection: %v\n", err) 294 return 295 } 296 297 for { 298 parts, ackFn, err := codec.Next(t.ctx) 299 if err != nil { 300 if err != io.EOF && err != types.ErrTimeout { 301 t.log.Errorf("Connection dropped due to: %v\n", err) 302 } 303 return 304 } 305 mCount.Incr(1) 306 mRcvd.Incr(1) 307 mPartsRcvd.Incr(int64(len(parts))) 308 309 // We simply bounce rejected messages in a loop downstream so 310 // there's no benefit to aggregating acks. 311 _ = ackFn(t.ctx, nil) 312 313 msg := message.New(nil) 314 msg.Append(parts...) 315 if !t.sendMsg(msg) { 316 return 317 } 318 } 319 }(conn) 320 } 321 } 322 323 func (t *SocketServer) udpLoop() { 324 var ( 325 mCount = t.stats.GetCounter("count") 326 mRcvd = t.stats.GetCounter("batch.received") 327 mPartsRcvd = t.stats.GetCounter("received") 328 ) 329 330 defer func() { 331 t.retriesMut.Lock() 332 // nolint:staticcheck, gocritic // Ignore SA2001 empty critical section, Ignore badLock 333 t.retriesMut.Unlock() 334 335 close(t.transactions) 336 close(t.closedChan) 337 }() 338 339 codec, err := t.codecCtor("", &wrapPacketConn{PacketConn: t.conn}, func(ctx context.Context, err error) error { 340 return nil 341 }) 342 if err != nil { 343 t.log.Errorf("Connection error due to: %v\n", err) 344 return 345 } 346 347 go func() { 348 <-t.ctx.Done() 349 codec.Close(context.Background()) 350 t.conn.Close() 351 }() 352 353 t.log.Infof("Receiving udp socket messages from address: %v\n", t.conn.LocalAddr()) 354 355 for { 356 parts, ackFn, err := codec.Next(t.ctx) 357 if err != nil { 358 if err != io.EOF && err != types.ErrTimeout { 359 t.log.Errorf("Connection dropped due to: %v\n", err) 360 } 361 return 362 } 363 mCount.Incr(1) 364 mRcvd.Incr(1) 365 mPartsRcvd.Incr(int64(len(parts))) 366 367 // We simply bounce rejected messages in a loop downstream so 368 // there's no benefit to aggregating acks. 369 _ = ackFn(t.ctx, nil) 370 371 msg := message.New(nil) 372 msg.Append(parts...) 373 if !t.sendMsg(msg) { 374 return 375 } 376 } 377 } 378 379 // TransactionChan returns a transactions channel for consuming messages from 380 // this input. 381 func (t *SocketServer) TransactionChan() <-chan types.Transaction { 382 return t.transactions 383 } 384 385 // Connected returns a boolean indicating whether this input is currently 386 // connected to its target. 387 func (t *SocketServer) Connected() bool { 388 return true 389 } 390 391 // CloseAsync shuts down the SocketServer input and stops processing requests. 392 func (t *SocketServer) CloseAsync() { 393 t.closeFn() 394 } 395 396 // WaitForClose blocks until the SocketServer input has closed down. 397 func (t *SocketServer) WaitForClose(timeout time.Duration) error { 398 select { 399 case <-t.closedChan: 400 case <-time.After(timeout): 401 return types.ErrTimeout 402 } 403 return nil 404 } 405 406 //------------------------------------------------------------------------------