github.com/elfadel/cilium@v1.6.12/pkg/proxy/server_test.go (about)

     1  // This code is copied from github.com/optiopay/kafka to provide the testing
     2  // framework
     3  
     4  // +build !privileged_tests
     5  
     6  package proxy
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"net"
    12  	"strconv"
    13  	"time"
    14  
    15  	"github.com/cilium/cilium/pkg/lock"
    16  
    17  	"github.com/optiopay/kafka/proto"
    18  )
    19  
    20  const (
    21  	AnyRequest              = -1
    22  	ProduceRequest          = 0
    23  	FetchRequest            = 1
    24  	OffsetRequest           = 2
    25  	MetadataRequest         = 3
    26  	OffsetCommitRequest     = 8
    27  	OffsetFetchRequest      = 9
    28  	ConsumerMetadataRequest = 10
    29  )
    30  
    31  type Serializable interface {
    32  	Bytes(int16) ([]byte, error)
    33  }
    34  
    35  type RequestHandler func(request Serializable) (response Serializable)
    36  
// Server is a minimal Kafka broker stand-in for tests. It listens on the
// loopback interface, decodes incoming requests and dispatches them to the
// RequestHandler registered for their kind.
type Server struct {
	// Processed counts the requests that reached the built-in default
	// handler (defaultRequestHandler).
	Processed int

	// mu is held while handlers and clients are read or modified, and while
	// Start assigns ln.
	mu       lock.RWMutex
	ln       net.Listener
	clients  map[int64]net.Conn // open connections, keyed by an ID assigned in handleClient
	handlers map[int16]RequestHandler
}
    45  
    46  func NewServer() *Server {
    47  	srv := &Server{
    48  		clients:  make(map[int64]net.Conn),
    49  		handlers: make(map[int16]RequestHandler),
    50  	}
    51  	srv.handlers[AnyRequest] = srv.defaultRequestHandler
    52  	return srv
    53  }
    54  
    55  // Handle registers handler for given message kind. Handler registered with
    56  // AnyRequest kind will be used only if there is no precise handler for the
    57  // kind.
    58  func (srv *Server) Handle(reqKind int16, handler RequestHandler) {
    59  	srv.mu.Lock()
    60  	srv.handlers[reqKind] = handler
    61  	srv.mu.Unlock()
    62  }
    63  
    64  func (srv *Server) Address() string {
    65  	return srv.ln.Addr().String()
    66  }
    67  
    68  func (srv *Server) HostPort() (string, int) {
    69  	host, sport, err := net.SplitHostPort(srv.ln.Addr().String())
    70  	if err != nil {
    71  		panic(fmt.Sprintf("cannot split server address: %s", err))
    72  	}
    73  	port, err := strconv.Atoi(sport)
    74  	if err != nil {
    75  		panic(fmt.Sprintf("port '%s' is not a number: %s", sport, err))
    76  	}
    77  	if host == "" {
    78  		host = "localhost"
    79  	}
    80  	return host, port
    81  }
    82  
    83  func (srv *Server) Start() {
    84  	srv.mu.Lock()
    85  	defer srv.mu.Unlock()
    86  
    87  	if srv.ln != nil {
    88  		panic("server already started")
    89  	}
    90  	ln, err := net.Listen("tcp4", "127.0.0.1:")
    91  	if err != nil {
    92  		panic(fmt.Sprintf("cannot start server: %s", err))
    93  	}
    94  	srv.ln = ln
    95  
    96  	go func() {
    97  		for {
    98  			client, err := ln.Accept()
    99  			if err != nil {
   100  				return
   101  			}
   102  			go srv.handleClient(client)
   103  		}
   104  	}()
   105  }
   106  
   107  func (srv *Server) Close() {
   108  	srv.mu.Lock()
   109  	_ = srv.ln.Close()
   110  	for _, cli := range srv.clients {
   111  		_ = cli.Close()
   112  	}
   113  	srv.clients = make(map[int64]net.Conn)
   114  	srv.mu.Unlock()
   115  }
   116  
   117  func (srv *Server) handleClient(c net.Conn) {
   118  	clientID := time.Now().UnixNano()
   119  	srv.mu.Lock()
   120  	srv.clients[clientID] = c
   121  	srv.mu.Unlock()
   122  
   123  	defer func() {
   124  		srv.mu.Lock()
   125  		delete(srv.clients, clientID)
   126  		srv.mu.Unlock()
   127  	}()
   128  
   129  	for {
   130  		kind, b, err := proto.ReadReq(c)
   131  		if err != nil {
   132  			return
   133  		}
   134  		srv.mu.RLock()
   135  		fn, ok := srv.handlers[kind]
   136  		if !ok {
   137  			fn, ok = srv.handlers[AnyRequest]
   138  		}
   139  		srv.mu.RUnlock()
   140  
   141  		if !ok {
   142  			panic(fmt.Sprintf("no handler for %d", kind))
   143  		}
   144  
   145  		var request Serializable
   146  
   147  		switch kind {
   148  		case FetchRequest:
   149  			request, err = proto.ReadFetchReq(bytes.NewBuffer(b))
   150  		case ProduceRequest:
   151  			request, err = proto.ReadProduceReq(bytes.NewBuffer(b))
   152  		case OffsetRequest:
   153  			request, err = proto.ReadOffsetReq(bytes.NewBuffer(b))
   154  		case MetadataRequest:
   155  			request, err = proto.ReadMetadataReq(bytes.NewBuffer(b))
   156  		case ConsumerMetadataRequest:
   157  			request, err = proto.ReadConsumerMetadataReq(bytes.NewBuffer(b))
   158  		case OffsetCommitRequest:
   159  			request, err = proto.ReadOffsetCommitReq(bytes.NewBuffer(b))
   160  		case OffsetFetchRequest:
   161  			request, err = proto.ReadOffsetFetchReq(bytes.NewBuffer(b))
   162  		}
   163  
   164  		if err != nil {
   165  			panic(fmt.Sprintf("could not read message %d: %s", kind, err))
   166  		}
   167  
   168  		response := fn(request)
   169  		if response != nil {
   170  			b, err := response.Bytes(proto.KafkaV0)
   171  			if err != nil {
   172  				panic(fmt.Sprintf("cannot serialize %T: %s", response, err))
   173  			}
   174  			if _, err := c.Write(b); err != nil {
   175  				panic(fmt.Sprintf("cannot wirte to client: %s", err))
   176  			}
   177  		}
   178  	}
   179  }
   180  
   181  func (srv *Server) defaultRequestHandler(request Serializable) Serializable {
   182  	srv.mu.RLock()
   183  	defer srv.mu.RUnlock()
   184  
   185  	srv.Processed++
   186  
   187  	switch req := request.(type) {
   188  	case *proto.FetchReq:
   189  		resp := &proto.FetchResp{
   190  			CorrelationID: req.CorrelationID,
   191  			Topics:        make([]proto.FetchRespTopic, len(req.Topics)),
   192  		}
   193  		for ti, topic := range req.Topics {
   194  			resp.Topics[ti] = proto.FetchRespTopic{
   195  				Name:       topic.Name,
   196  				Partitions: make([]proto.FetchRespPartition, len(topic.Partitions)),
   197  			}
   198  			for pi, part := range topic.Partitions {
   199  				resp.Topics[ti].Partitions[pi] = proto.FetchRespPartition{
   200  					ID:        part.ID,
   201  					Err:       proto.ErrUnknownTopicOrPartition,
   202  					TipOffset: -1,
   203  					Messages:  []*proto.Message{},
   204  				}
   205  			}
   206  		}
   207  		return resp
   208  	case *proto.ProduceReq:
   209  		resp := &proto.ProduceResp{
   210  			CorrelationID: req.CorrelationID,
   211  		}
   212  		resp.Topics = make([]proto.ProduceRespTopic, len(req.Topics))
   213  		for ti, topic := range req.Topics {
   214  			resp.Topics[ti] = proto.ProduceRespTopic{
   215  				Name:       topic.Name,
   216  				Partitions: make([]proto.ProduceRespPartition, len(topic.Partitions)),
   217  			}
   218  			for pi, part := range topic.Partitions {
   219  				resp.Topics[ti].Partitions[pi] = proto.ProduceRespPartition{
   220  					ID:     part.ID,
   221  					Err:    proto.ErrUnknownTopicOrPartition,
   222  					Offset: -1,
   223  				}
   224  			}
   225  		}
   226  		return resp
   227  	case *proto.OffsetReq:
   228  		topics := make([]proto.OffsetRespTopic, len(req.Topics))
   229  		for ti := range req.Topics {
   230  			var topic = &topics[ti]
   231  			topic.Name = req.Topics[ti].Name
   232  			topic.Partitions = make([]proto.OffsetRespPartition, len(req.Topics[ti].Partitions))
   233  			for pi := range topic.Partitions {
   234  				var part = &topic.Partitions[pi]
   235  				part.ID = req.Topics[ti].Partitions[pi].ID
   236  				part.Err = proto.ErrUnknownTopicOrPartition
   237  			}
   238  		}
   239  
   240  		return &proto.OffsetResp{
   241  			CorrelationID: req.CorrelationID,
   242  			Topics:        topics,
   243  		}
   244  	case *proto.MetadataReq:
   245  		host, sport, err := net.SplitHostPort(srv.ln.Addr().String())
   246  		if err != nil {
   247  			panic(fmt.Sprintf("cannot split server address: %s", err))
   248  		}
   249  		port, err := strconv.Atoi(sport)
   250  		if err != nil {
   251  			panic(fmt.Sprintf("port '%s' is not a number: %s", sport, err))
   252  		}
   253  		if host == "" {
   254  			host = "localhost"
   255  		}
   256  		return &proto.MetadataResp{
   257  			CorrelationID: req.CorrelationID,
   258  			Brokers: []proto.MetadataRespBroker{
   259  				{NodeID: 1, Host: host, Port: int32(port)},
   260  			},
   261  			Topics: []proto.MetadataRespTopic{},
   262  		}
   263  	case *proto.ConsumerMetadataReq:
   264  		panic("not implemented")
   265  	case *proto.OffsetCommitReq:
   266  		panic("not implemented")
   267  	case *proto.OffsetFetchReq:
   268  		panic("not implemented")
   269  	default:
   270  		panic(fmt.Sprintf("unknown message type: %T", req))
   271  	}
   272  }