github.com/0chain/gosdk@v1.17.11/zcncore/transaction_query_test.go (about)

     1  package zcncore
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"log"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  const (
    20  	keyServerAddr = "serverAddr"
    21  	addrPrefix    = "http://localhost"
    22  )
    23  
    24  var tq *TransactionQuery
    25  var numSharders int
    26  var avgTimeToFindSharder float32
    27  var maxTimePerIteration float32
    28  var sharders []string
    29  
// SharderHealthStatus is the JSON body a mock sharder server writes in
// response to a health-check request for a sharder that is not marked
// offline. Field order and JSON tags define the encoded response.
type SharderHealthStatus struct {
	// Host is the listen address the serving mock server stored in its base
	// context under keyServerAddr.
	Host         string `json:"host"`
	// HealthStatus is set to "healthy" by the mock handler in this file.
	HealthStatus string `json:"health"`
}
    34  
    35  func TestMain(m *testing.M) {
    36  	numSharders = 10
    37  	sharders = make([]string, 0)
    38  	for i := 0; i < numSharders; i++ {
    39  		port := fmt.Sprintf(":600%d", i)
    40  		sharders = append(sharders, addrPrefix+port)
    41  	}
    42  	startMockSharderServers(sharders)
    43  	// wait for 2s for all servers to start
    44  	time.Sleep(2 * time.Second)
    45  	exitVal := m.Run()
    46  	os.Exit(exitVal)
    47  }
    48  
    49  func TestGetRandomSharder(t *testing.T) {
    50  	var err error
    51  	tq, err = NewTransactionQuery(sharders, []string{})
    52  	if err != nil {
    53  		t.Fatalf("Failed to create new transaction query: %v", err)
    54  	}
    55  
    56  	for _, tc := range []struct {
    57  		name           string
    58  		onlineSharders []string
    59  		expectedErr    error
    60  		setupContext   func(ctx context.Context) context.Context
    61  	}{
    62  		{
    63  			name:           "context deadline exceeded",
    64  			onlineSharders: []string{"http://localhost:6009"},
    65  			expectedErr:    context.DeadlineExceeded,
    66  			setupContext: func(ct context.Context) context.Context {
    67  				ctx, cancel := context.WithTimeout(ct, 100*time.Microsecond)
    68  				go func() {
    69  					<-ctx.Done()
    70  					cancel()
    71  				}()
    72  				return ctx
    73  			},
    74  		},
    75  		{
    76  			name:           "all sharders online",
    77  			onlineSharders: sharders,
    78  			expectedErr:    nil,
    79  		},
    80  		{
    81  			name:           "only one sharder online",
    82  			onlineSharders: []string{"http://localhost:6000"},
    83  			expectedErr:    nil,
    84  		},
    85  		{
    86  			name:           "few sharders online",
    87  			onlineSharders: []string{"http://localhost:6001", "http://localhost:6006", "http://localhost:6009"},
    88  			expectedErr:    nil,
    89  		},
    90  		{
    91  			name:           "all sharders offline",
    92  			onlineSharders: []string{},
    93  			expectedErr:    ErrNoOnlineSharders,
    94  		},
    95  	} {
    96  		t.Run(tc.name, func(t *testing.T) {
    97  			tq.Reset()
    98  
    99  			for _, s := range sharders {
   100  				if !contains(tc.onlineSharders, s) {
   101  					tq.Lock()
   102  					tq.offline[s] = true
   103  					tq.Unlock()
   104  				}
   105  			}
   106  			ctx := context.Background()
   107  			if tc.setupContext != nil {
   108  				ctx = tc.setupContext(ctx)
   109  			}
   110  			sharder, err := tq.getRandomSharderWithHealthcheck(ctx)
   111  			if tc.expectedErr == nil {
   112  				require.NoError(t, err)
   113  				require.Subset(t, tc.onlineSharders, []string{sharder})
   114  			} else {
   115  				require.EqualError(t, err, tc.expectedErr.Error())
   116  			}
   117  		})
   118  	}
   119  }
   120  
   121  // Maybe replace this with the standard go benchmark later on
   122  func TestGetRandomSharderAndBenchmark(t *testing.T) {
   123  	var err error
   124  	tq, err = NewTransactionQuery(sharders, []string{})
   125  	if err != nil {
   126  		t.Fatalf("Failed to create new transaction query: %v", err)
   127  	}
   128  
   129  	done := make(chan struct{})
   130  	go startAndStopShardersRandomly(done)
   131  	fetchRandomSharderAndBenchmark(t)
   132  	close(done)
   133  }
   134  
   135  func startMockSharderServers(sharders []string) {
   136  	for i := range sharders {
   137  		url := fmt.Sprintf(":600%d", i)
   138  		go func(url string) {
   139  			ctx, cancel := context.WithCancel(context.Background())
   140  			mx := http.NewServeMux()
   141  			mx.HandleFunc(SharderEndpointHealthCheck, getSharderHealth)
   142  			httpServer := &http.Server{
   143  				Addr:    url,
   144  				Handler: mx,
   145  				BaseContext: func(l net.Listener) context.Context {
   146  					ctx := context.WithValue(ctx, keyServerAddr, url) // nolint
   147  					return ctx
   148  				},
   149  			}
   150  			log.Printf("Starting sharder server at: %v", url)
   151  			err := httpServer.ListenAndServe()
   152  			if errors.Is(err, http.ErrServerClosed) {
   153  				log.Printf("server %v closed\n", httpServer.Addr)
   154  			} else if err != nil {
   155  				log.Printf("error listening for server one: %s\n", err)
   156  			}
   157  			cancel()
   158  		}(url)
   159  	}
   160  }
   161  
   162  func getSharderHealth(w http.ResponseWriter, req *http.Request) {
   163  	ctx := req.Context()
   164  	sharderHost := ctx.Value(keyServerAddr).(string)
   165  	tq.RLock()
   166  	_, ok := tq.offline[sharderHost]
   167  	tq.RUnlock()
   168  	if ok {
   169  		errorAny(w, 404, fmt.Sprintf("sharder %v is offline", sharderHost))
   170  	} else {
   171  		healthStatus := &SharderHealthStatus{
   172  			Host:         sharderHost,
   173  			HealthStatus: "healthy",
   174  		}
   175  		err := json.NewEncoder(w).Encode(healthStatus)
   176  		if err != nil {
   177  			errorAny(w, http.StatusInternalServerError, "failed to encode json")
   178  		}
   179  	}
   180  }
   181  
   182  func startAndStopShardersRandomly(done chan struct{}) {
   183  	for {
   184  		select {
   185  		case <-time.After(5 * time.Millisecond):
   186  			tq.Lock()
   187  			// mark a random sharder offline every 5ms
   188  			randGen := rand.New(rand.NewSource(time.Now().UnixNano()))
   189  			randomSharder := tq.sharders[randGen.Intn(numSharders)]
   190  			tq.offline[randomSharder] = true
   191  			tq.Unlock()
   192  
   193  		case <-time.After(3 * time.Millisecond):
   194  			tq.Lock()
   195  			// mark a random sharder online every 3ms
   196  			randGen := rand.New(rand.NewSource(time.Now().UnixNano()))
   197  			randomSharder := tq.sharders[randGen.Intn(numSharders)]
   198  			delete(tq.offline, randomSharder)
   199  			tq.Unlock()
   200  
   201  		case <-time.After(5 * time.Second):
   202  			//Randomly mark all sharders online every 5s
   203  			tq.Lock()
   204  			tq.Reset()
   205  			tq.Unlock()
   206  		case <-done:
   207  			return
   208  		}
   209  	}
   210  }
   211  
   212  func fetchRandomSharderAndBenchmark(t *testing.T) {
   213  	numIterations := 5
   214  	for i := 0; i < numIterations; i++ {
   215  		// Sleep for sometime to have some random sharders started and stopped
   216  		time.Sleep(20 * time.Millisecond)
   217  		ctx := context.Background()
   218  		start := time.Now()
   219  		_, err := tq.getRandomSharderWithHealthcheck(ctx)
   220  		if err != nil {
   221  			t.Fatalf("Failed to get a random sharder err: %v", err)
   222  		}
   223  		end := float32(time.Since(start) / time.Microsecond)
   224  		if end > maxTimePerIteration {
   225  			maxTimePerIteration = end
   226  		}
   227  		avgTimeToFindSharder += end
   228  
   229  	}
   230  	avgTimeToFindSharder = (avgTimeToFindSharder / float32(numIterations)) / 1000
   231  	maxTimePerIteration /= 1000
   232  	t.Logf("Average time to find a random sharder: %vms and max time for an iteration: %vms", avgTimeToFindSharder, maxTimePerIteration)
   233  }
   234  
   235  func errorAny(w http.ResponseWriter, status int, msg string) {
   236  	httpMsg := fmt.Sprintf("%d %s", status, http.StatusText(status))
   237  	if msg != "" {
   238  		httpMsg = fmt.Sprintf("%s - %s", httpMsg, msg)
   239  	}
   240  	http.Error(w, httpMsg, status)
   241  }
   242  
   243  func contains(list []string, e string) bool {
   244  	for _, l := range list {
   245  		if l == e {
   246  			return true
   247  		}
   248  	}
   249  	return false
   250  }