github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/shared/mlog/tcp_test.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package mlog
     5  
     6  import (
     7  	"bytes"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/stretchr/testify/require"
    17  	"github.com/wiggin77/merror"
    18  )
    19  
    20  const (
    21  	testPort = 18066
    22  )
    23  
    24  func TestNewTcpTarget(t *testing.T) {
    25  	target := LogTarget{
    26  		Type:         "tcp",
    27  		Format:       "json",
    28  		Levels:       []LogLevel{LvlInfo},
    29  		Options:      []byte(`{"IP": "localhost", "Port": 18066}`),
    30  		MaxQueueSize: 1000,
    31  	}
    32  	targets := map[string]*LogTarget{"tcp_test": &target}
    33  
    34  	t.Run("logging", func(t *testing.T) {
    35  		buf := &buffer{}
    36  		server, err := newSocketServer(testPort, buf)
    37  		require.NoError(t, err)
    38  
    39  		data := []string{"I drink your milkshake!", "We don't need no badges!", "You can't fight in here! This is the war room!"}
    40  
    41  		logger := newLogr()
    42  		err = logrAddTargets(logger, targets)
    43  		require.NoError(t, err)
    44  
    45  		for _, s := range data {
    46  			logger.Info(s)
    47  		}
    48  		err = logger.Logr().Flush()
    49  		require.NoError(t, err)
    50  		err = logger.Logr().Shutdown()
    51  		require.NoError(t, err)
    52  
    53  		err = server.waitForAnyConnection()
    54  		require.NoError(t, err)
    55  
    56  		err = server.stopServer(true)
    57  		require.NoError(t, err)
    58  
    59  		sdata := buf.String()
    60  		for _, s := range data {
    61  			require.Contains(t, sdata, s)
    62  		}
    63  	})
    64  }
    65  
    66  // socketServer is a simple socket server used for testing TCP log targets.
    67  // Note: There is more synchronization here than normally needed to avoid flaky tests.
    68  //       For example, it's possible for a unit test to create a socketServer, attempt
    69  //       writing to it, and stop the socket server before "go ss.listen()" gets scheduled.
    70  type socketServer struct {
    71  	listener net.Listener
    72  	anyConn  chan struct{}
    73  	buf      *buffer
    74  	conns    map[string]*socketServerConn
    75  	mux      sync.Mutex
    76  }
    77  
    78  type socketServerConn struct {
    79  	raddy string
    80  	conn  net.Conn
    81  	done  chan struct{}
    82  }
    83  
    84  func newSocketServer(port int, buf *buffer) (*socketServer, error) {
    85  	ss := &socketServer{
    86  		buf:     buf,
    87  		conns:   make(map[string]*socketServerConn),
    88  		anyConn: make(chan struct{}),
    89  	}
    90  
    91  	addy := fmt.Sprintf(":%d", port)
    92  	l, err := net.Listen("tcp4", addy)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	ss.listener = l
    97  
    98  	go ss.listen()
    99  	return ss, nil
   100  }
   101  
   102  func (ss *socketServer) listen() {
   103  	for {
   104  		conn, err := ss.listener.Accept()
   105  		if err != nil {
   106  			return
   107  		}
   108  		sconn := &socketServerConn{raddy: conn.RemoteAddr().String(), conn: conn, done: make(chan struct{})}
   109  		ss.registerConnection(sconn)
   110  		go ss.handleConnection(sconn)
   111  	}
   112  }
   113  
   114  func (ss *socketServer) waitForAnyConnection() error {
   115  	var err error
   116  	select {
   117  	case <-ss.anyConn:
   118  	case <-time.After(5 * time.Second):
   119  		err = errors.New("wait for any connection timed out")
   120  	}
   121  	return err
   122  }
   123  
   124  func (ss *socketServer) handleConnection(sconn *socketServerConn) {
   125  	close(ss.anyConn)
   126  	defer ss.unregisterConnection(sconn)
   127  	buf := make([]byte, 1024)
   128  
   129  	for {
   130  		n, err := sconn.conn.Read(buf)
   131  		if n > 0 {
   132  			ss.buf.Write(buf[:n])
   133  		}
   134  		if err == io.EOF {
   135  			ss.signalDone(sconn)
   136  			return
   137  		}
   138  	}
   139  }
   140  
   141  func (ss *socketServer) registerConnection(sconn *socketServerConn) {
   142  	ss.mux.Lock()
   143  	defer ss.mux.Unlock()
   144  	ss.conns[sconn.raddy] = sconn
   145  }
   146  
   147  func (ss *socketServer) unregisterConnection(sconn *socketServerConn) {
   148  	ss.mux.Lock()
   149  	defer ss.mux.Unlock()
   150  	delete(ss.conns, sconn.raddy)
   151  }
   152  
   153  func (ss *socketServer) signalDone(sconn *socketServerConn) {
   154  	ss.mux.Lock()
   155  	defer ss.mux.Unlock()
   156  	close(sconn.done)
   157  }
   158  
   159  func (ss *socketServer) stopServer(wait bool) error {
   160  	errs := merror.New()
   161  	ss.listener.Close()
   162  
   163  	ss.mux.Lock()
   164  	// defensive copy; no more connections can be accepted so copy will stay current.
   165  	conns := make(map[string]*socketServerConn, len(ss.conns))
   166  	for k, v := range ss.conns {
   167  		conns[k] = v
   168  	}
   169  	ss.mux.Unlock()
   170  
   171  	for _, sconn := range conns {
   172  		if wait {
   173  			select {
   174  			case <-sconn.done:
   175  			case <-time.After(time.Second * 5):
   176  				errs.Append(errors.New("timed out"))
   177  			}
   178  		}
   179  	}
   180  	return errs.ErrorOrNil()
   181  }
   182  
   183  type buffer struct {
   184  	buf bytes.Buffer
   185  	mux sync.Mutex
   186  }
   187  
   188  func (b *buffer) Write(p []byte) (n int, err error) {
   189  	b.mux.Lock()
   190  	defer b.mux.Unlock()
   191  	return b.buf.Write(p)
   192  }
   193  
   194  func (b *buffer) String() string {
   195  	b.mux.Lock()
   196  	defer b.mux.Unlock()
   197  	return b.buf.String()
   198  }