golang.org/x/build@v0.0.0-20240506185731-218518f32b70/internal/rendezvous/rendezvous_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rendezvous
     6  
     7  import (
     8  	"context"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"golang.org/x/build/revdial/v2"
    16  )
    17  
    18  func TestNew(t *testing.T) {
    19  	ctx, cancel := context.WithCancel(context.Background())
    20  	defer cancel()
    21  	_ = New(ctx)
    22  }
    23  
    24  func TestPurgeExpiredRegistrations(t *testing.T) {
    25  	rdv := &Rendezvous{
    26  		m: make(map[string]*entry),
    27  	}
    28  	rdv.m["test"] = &entry{
    29  		deadline: time.Unix(0, 0),
    30  		ch:       make(chan *result, 1),
    31  	}
    32  	rdv.purgeExpiredRegistrations()
    33  	if len(rdv.m) != 0 {
    34  		t.Errorf("purgeExpiredRegistrations() did not purge expired entries: want 0 got %d", len(rdv.m))
    35  	}
    36  }
    37  
    38  func TestRegisterInstance(t *testing.T) {
    39  	ctx, cancel := context.WithCancel(context.Background())
    40  	defer cancel()
    41  	rdv := New(ctx)
    42  	rdv.RegisterInstance(ctx, "sample-1", time.Minute)
    43  	if len(rdv.m) != 1 {
    44  		t.Errorf("RegisterInstance: want 1, got %d", len(rdv.m))
    45  	}
    46  }
    47  
    48  func TestWaitForInstanceError(t *testing.T) {
    49  	testCases := []struct {
    50  		desc           string
    51  		headers        map[string]string
    52  		wantStatusCode int
    53  	}{
    54  		{desc: "missing host header", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token"}, wantStatusCode: 400},
    55  		{desc: "missing id header", headers: map[string]string{HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: 400},
    56  		{desc: "missing auth token", headers: map[string]string{HeaderID: "test-id", HeaderHostname: "test-hostname"}, wantStatusCode: 400},
    57  		{desc: "missing registration", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: 412},
    58  	}
    59  	for _, tc := range testCases {
    60  		t.Run(tc.desc, func(t *testing.T) {
    61  			rdv := &Rendezvous{
    62  				m: make(map[string]*entry),
    63  				validator: func(ctx context.Context, jwt string) bool {
    64  					return true
    65  				},
    66  			}
    67  			ts := httptest.NewTLSServer(http.HandlerFunc(rdv.HandleReverse))
    68  			defer ts.Close()
    69  			client := ts.Client()
    70  			req, err := http.NewRequest("GET", ts.URL, nil)
    71  			for k, v := range tc.headers {
    72  				req.Header.Set(k, v)
    73  			}
    74  			resp, err := client.Do(req)
    75  			if err != nil {
    76  				t.Errorf("client.Get(%s): %s", ts.URL, err)
    77  			}
    78  			if resp.StatusCode != tc.wantStatusCode {
    79  				t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, tc.wantStatusCode)
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  func TestWaitForInstaceErrorNonTLS(t *testing.T) {
    86  	rdv := &Rendezvous{
    87  		m: make(map[string]*entry),
    88  		validator: func(ctx context.Context, jwt string) bool {
    89  			return true
    90  		},
    91  	}
    92  	ts := httptest.NewServer(http.HandlerFunc(rdv.HandleReverse))
    93  	defer ts.Close()
    94  	client := ts.Client()
    95  	req, err := http.NewRequest("GET", ts.URL, nil)
    96  	resp, err := client.Do(req)
    97  	if err != nil {
    98  		t.Errorf("client.Get(%s): %s", ts.URL, err)
    99  	}
   100  	if resp.StatusCode != 500 {
   101  		t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, 500)
   102  	}
   103  }
   104  
   105  func TestWaitForInstaceRevdialError(t *testing.T) {
   106  	rdv := &Rendezvous{
   107  		m: make(map[string]*entry),
   108  		validator: func(ctx context.Context, jwt string) bool {
   109  			return true
   110  		},
   111  	}
   112  	instanceID := "test-id-3"
   113  	ctx := context.Background()
   114  	rdv.RegisterInstance(ctx, instanceID, 15*time.Second)
   115  	mux := http.NewServeMux()
   116  	mux.HandleFunc("/reverse", rdv.HandleReverse)
   117  	mux.Handle("/revdial", revdial.ConnHandler())
   118  	ts := httptest.NewTLSServer(mux)
   119  	defer ts.Close()
   120  	client := ts.Client()
   121  	req, err := http.NewRequest("GET", ts.URL+"/reverse", nil)
   122  	req.Header.Set(HeaderID, instanceID)
   123  	req.Header.Set(HeaderToken, "test-token")
   124  	req.Header.Set(HeaderHostname, "test-hostname")
   125  
   126  	var wg sync.WaitGroup
   127  	wg.Add(1)
   128  	go func() {
   129  		defer wg.Done()
   130  
   131  		_, _ = client.Do(req)
   132  	}()
   133  	_, err = rdv.WaitForInstance(ctx, instanceID)
   134  	if err == nil {
   135  		// expect a missing status endpoint
   136  		t.Fatal("WaitForInstance(): got nil, want error")
   137  	}
   138  	wg.Wait()
   139  }
   140  
   141  func TestDeregisterInstance(t *testing.T) {
   142  	rdv := &Rendezvous{
   143  		m: make(map[string]*entry),
   144  	}
   145  	id := "test-xyz"
   146  	rdv.m[id] = &entry{}
   147  	rdv.DeregisterInstance(context.Background(), id)
   148  	if len(rdv.m) != 0 {
   149  		t.Errorf("/deregusterInstance() did not remove the entry: want 0 got %d", len(rdv.m))
   150  	}
   151  }