golang.org/x/build@v0.0.0-20240506185731-218518f32b70/internal/rendezvous/rendezvous_test.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rendezvous 6 7 import ( 8 "context" 9 "net/http" 10 "net/http/httptest" 11 "sync" 12 "testing" 13 "time" 14 15 "golang.org/x/build/revdial/v2" 16 ) 17 18 func TestNew(t *testing.T) { 19 ctx, cancel := context.WithCancel(context.Background()) 20 defer cancel() 21 _ = New(ctx) 22 } 23 24 func TestPurgeExpiredRegistrations(t *testing.T) { 25 rdv := &Rendezvous{ 26 m: make(map[string]*entry), 27 } 28 rdv.m["test"] = &entry{ 29 deadline: time.Unix(0, 0), 30 ch: make(chan *result, 1), 31 } 32 rdv.purgeExpiredRegistrations() 33 if len(rdv.m) != 0 { 34 t.Errorf("purgeExpiredRegistrations() did not purge expired entries: want 0 got %d", len(rdv.m)) 35 } 36 } 37 38 func TestRegisterInstance(t *testing.T) { 39 ctx, cancel := context.WithCancel(context.Background()) 40 defer cancel() 41 rdv := New(ctx) 42 rdv.RegisterInstance(ctx, "sample-1", time.Minute) 43 if len(rdv.m) != 1 { 44 t.Errorf("RegisterInstance: want 1, got %d", len(rdv.m)) 45 } 46 } 47 48 func TestWaitForInstanceError(t *testing.T) { 49 testCases := []struct { 50 desc string 51 headers map[string]string 52 wantStatusCode int 53 }{ 54 {desc: "missing host header", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token"}, wantStatusCode: 400}, 55 {desc: "missing id header", headers: map[string]string{HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: 400}, 56 {desc: "missing auth token", headers: map[string]string{HeaderID: "test-id", HeaderHostname: "test-hostname"}, wantStatusCode: 400}, 57 {desc: "missing registration", headers: map[string]string{HeaderID: "test-id", HeaderToken: "test-token", HeaderHostname: "test-hostname"}, wantStatusCode: 412}, 58 } 59 for _, tc := range testCases { 60 t.Run(tc.desc, func(t *testing.T) { 61 rdv := &Rendezvous{ 62 m: make(map[string]*entry), 63 validator: func(ctx context.Context, jwt string) bool { 64 return true 65 }, 66 } 67 ts := httptest.NewTLSServer(http.HandlerFunc(rdv.HandleReverse)) 68 defer ts.Close() 69 client := ts.Client() 70 req, err := http.NewRequest("GET", ts.URL, nil) 71 for k, v := range tc.headers { 72 req.Header.Set(k, v) 73 } 74 resp, err := client.Do(req) 75 if err != nil { 76 t.Errorf("client.Get(%s): %s", ts.URL, err) 77 } 78 if resp.StatusCode != tc.wantStatusCode { 79 t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, tc.wantStatusCode) 80 } 81 }) 82 } 83 } 84 85 func TestWaitForInstaceErrorNonTLS(t *testing.T) { 86 rdv := &Rendezvous{ 87 m: make(map[string]*entry), 88 validator: func(ctx context.Context, jwt string) bool { 89 return true 90 }, 91 } 92 ts := httptest.NewServer(http.HandlerFunc(rdv.HandleReverse)) 93 defer ts.Close() 94 client := ts.Client() 95 req, err := http.NewRequest("GET", ts.URL, nil) 96 resp, err := client.Do(req) 97 if err != nil { 98 t.Errorf("client.Get(%s): %s", ts.URL, err) 99 } 100 if resp.StatusCode != 500 { 101 t.Fatalf("resp.StatusCode: got %d, want %d", resp.StatusCode, 500) 102 } 103 } 104 105 func TestWaitForInstaceRevdialError(t *testing.T) { 106 rdv := &Rendezvous{ 107 m: make(map[string]*entry), 108 validator: func(ctx context.Context, jwt string) bool { 109 return true 110 }, 111 } 112 instanceID := "test-id-3" 113 ctx := context.Background() 114 rdv.RegisterInstance(ctx, instanceID, 15*time.Second) 115 mux := http.NewServeMux() 116 mux.HandleFunc("/reverse", rdv.HandleReverse) 117 mux.Handle("/revdial", revdial.ConnHandler()) 118 ts := httptest.NewTLSServer(mux) 119 defer ts.Close() 120 client := ts.Client() 121 req, err := http.NewRequest("GET", ts.URL+"/reverse", nil) 122 req.Header.Set(HeaderID, instanceID) 123 req.Header.Set(HeaderToken, "test-token") 124 req.Header.Set(HeaderHostname, "test-hostname") 125 126 var wg sync.WaitGroup 127 wg.Add(1) 128 go func() { 129 defer wg.Done() 130 131 _, _ = client.Do(req) 132 }() 133 _, err = rdv.WaitForInstance(ctx, instanceID) 134 if err == nil { 135 // expect a missing status endpoint 136 t.Fatal("WaitForInstance(): got nil, want error") 137 } 138 wg.Wait() 139 } 140 141 func TestDeregisterInstance(t *testing.T) { 142 rdv := &Rendezvous{ 143 m: make(map[string]*entry), 144 } 145 id := "test-xyz" 146 rdv.m[id] = &entry{} 147 rdv.DeregisterInstance(context.Background(), id) 148 if len(rdv.m) != 0 { 149 t.Errorf("/deregusterInstance() did not remove the entry: want 0 got %d", len(rdv.m)) 150 } 151 }