github.com/google/cloudprober@v0.11.3/servers/udp/udp_test.go (about)

     1  // Copyright 2017 The Cloudprober Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/google/cloudprober/logger"
	configpb "github.com/google/cloudprober/servers/udp/proto"
)
    31  
    32  // Return true if the underlying error indicates a udp.Client timeout.
    33  // In our case, we're using the ReadTimeout- time until response is read.
    34  func isClientTimeout(err error) bool {
    35  	e, ok := err.(*net.OpError)
    36  	return ok && e != nil && e.Timeout()
    37  }
    38  
    39  func sendAndTestResponse(t *testing.T, c *configpb.ServerConf, conn net.Conn) {
    40  	size := rand.Intn(1024)
    41  	data := make([]byte, size)
    42  	rand.Read(data)
    43  
    44  	var err error
    45  	m, err := conn.Write(data)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	if m < len(data) {
    50  		t.Errorf("Wrote only %d of %d bytes", m, len(data))
    51  	}
    52  
    53  	timeout := time.Duration(100) * time.Millisecond
    54  	conn.SetReadDeadline(time.Now().Add(timeout))
    55  
    56  	switch c.GetType() {
    57  	case configpb.ServerConf_ECHO:
    58  		rcvd := make([]byte, size)
    59  		n, err := conn.Read(rcvd)
    60  		if err != nil {
    61  			t.Fatal(err)
    62  		}
    63  
    64  		if m != n {
    65  			t.Errorf("Sent %d bytes, got %d bytes", m, n)
    66  		}
    67  		if !bytes.Equal(data, rcvd) {
    68  			t.Errorf("Data mismatch: Sent '%v', Got '%v'", data, rcvd)
    69  		}
    70  	case configpb.ServerConf_DISCARD:
    71  		rcvd := make([]byte, size)
    72  		n, err := conn.Read(rcvd)
    73  		if err != nil {
    74  			if isClientTimeout(err) {
    75  				// Success, timed out with no response
    76  				return
    77  			}
    78  			t.Fatal(err)
    79  		}
    80  		if n > 0 {
    81  			t.Errorf("Received data (%v)! (Should be discarded)", rcvd)
    82  		}
    83  	}
    84  }
    85  
    86  func TestEchoServer(t *testing.T) {
    87  	testConfig := &configpb.ServerConf{
    88  		Port: proto.Int32(int32(0)),
    89  		Type: configpb.ServerConf_ECHO.Enum(),
    90  	}
    91  	testServer(t, testConfig)
    92  }
    93  
    94  func TestDiscardServer(t *testing.T) {
    95  	testConfig := &configpb.ServerConf{
    96  		Port: proto.Int32(int32(0)),
    97  		Type: configpb.ServerConf_DISCARD.Enum(),
    98  	}
    99  	testServer(t, testConfig)
   100  }
   101  
   102  func testServer(t *testing.T, testConfig *configpb.ServerConf) {
   103  	l := &logger.Logger{}
   104  	server, err := New(context.Background(), testConfig, l)
   105  	if err != nil {
   106  		t.Fatalf("Error creating a new server: %v", err)
   107  	}
   108  	serverAddr := fmt.Sprintf("localhost:%d", server.conn.LocalAddr().(*net.UDPAddr).Port)
   109  	go server.Start(context.Background(), nil)
   110  	// try 100 Samples
   111  	for i := 0; i < 100; i++ {
   112  		t.Logf("Creating connection %d to %s", i, serverAddr)
   113  		conn, err := net.Dial("udp", serverAddr)
   114  		if err != nil {
   115  			t.Fatal(err)
   116  		}
   117  		sendAndTestResponse(t, testConfig, conn)
   118  		conn.Close()
   119  	}
   120  	// try 10 samples on the same connection
   121  	t.Logf("Creating many-packet connection to %s", serverAddr)
   122  	conn, err := net.Dial("udp", serverAddr)
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  	defer conn.Close()
   127  	for i := 0; i < 10; i++ {
   128  		sendAndTestResponse(t, testConfig, conn)
   129  	}
   130  }
   131  
   132  func TestServerStop(t *testing.T) {
   133  	t.Run("ECHO mode", func(t *testing.T) {
   134  		testServerStopWithConfig(t, &configpb.ServerConf{
   135  			Port: proto.Int32(int32(0)),
   136  			Type: configpb.ServerConf_ECHO.Enum(),
   137  		})
   138  	})
   139  	t.Run("Discard mode", func(t *testing.T) {
   140  		testServerStopWithConfig(t, &configpb.ServerConf{
   141  			Port: proto.Int32(int32(0)),
   142  			Type: configpb.ServerConf_DISCARD.Enum(),
   143  		})
   144  	})
   145  }
   146  
   147  func testServerStopWithConfig(t *testing.T, testConfig *configpb.ServerConf) {
   148  	t.Helper()
   149  
   150  	server, err := New(context.Background(), testConfig, &logger.Logger{})
   151  	if err != nil {
   152  		t.Fatalf("Error creating a new server: %v", err)
   153  	}
   154  	serverAddr := fmt.Sprintf("localhost:%d", server.conn.LocalAddr().(*net.UDPAddr).Port)
   155  
   156  	var wg sync.WaitGroup
   157  	ctx, cancelF := context.WithCancel(context.Background())
   158  
   159  	wg.Add(1)
   160  	go func() {
   161  		server.Start(ctx, nil)
   162  		wg.Done()
   163  	}()
   164  
   165  	go func() {
   166  		time.Sleep(1 * time.Second)
   167  		cancelF()
   168  	}()
   169  
   170  	conn, err := net.Dial("udp", serverAddr)
   171  	if err != nil {
   172  		t.Errorf("Error connecting to test UDP server (%s): %v", serverAddr, err)
   173  	}
   174  	conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
   175  	for i := 0; true; i++ {
   176  		_, err := conn.Write(make([]byte, 10))
   177  		if err == nil {
   178  			continue
   179  		}
   180  		t.Logf("Stopped writing packet due to error: %v, sent %d packets", err, i+1)
   181  		break
   182  	}
   183  
   184  	wg.Wait()
   185  }