github.com/DerekStrickland/consul@v1.4.5/connect/resolver_test.go (about) 1 package connect 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/hashicorp/consul/agent" 9 "github.com/hashicorp/consul/agent/connect" 10 "github.com/hashicorp/consul/api" 11 "github.com/stretchr/testify/require" 12 ) 13 14 func TestStaticResolver_Resolve(t *testing.T) { 15 type fields struct { 16 Addr string 17 CertURI connect.CertURI 18 } 19 tests := []struct { 20 name string 21 fields fields 22 }{ 23 { 24 name: "simples", 25 fields: fields{"1.2.3.4:80", connect.TestSpiffeIDService(t, "foo")}, 26 }, 27 } 28 for _, tt := range tests { 29 t.Run(tt.name, func(t *testing.T) { 30 sr := StaticResolver{ 31 Addr: tt.fields.Addr, 32 CertURI: tt.fields.CertURI, 33 } 34 addr, certURI, err := sr.Resolve(context.Background()) 35 require := require.New(t) 36 require.Nil(err) 37 require.Equal(sr.Addr, addr) 38 require.Equal(sr.CertURI, certURI) 39 }) 40 } 41 } 42 43 func TestConsulResolver_Resolve(t *testing.T) { 44 // Setup a local test agent to query 45 agent := agent.NewTestAgent(t, "test-consul", "") 46 defer agent.Shutdown() 47 48 cfg := api.DefaultConfig() 49 cfg.Address = agent.HTTPAddr() 50 client, err := api.NewClient(cfg) 51 require.Nil(t, err) 52 53 // Setup a service with a connect proxy instance 54 regSrv := &api.AgentServiceRegistration{ 55 Name: "web", 56 Port: 8080, 57 } 58 err = client.Agent().ServiceRegister(regSrv) 59 require.Nil(t, err) 60 61 regProxy := &api.AgentServiceRegistration{ 62 Kind: "connect-proxy", 63 Name: "web-proxy", 64 Port: 9090, 65 Proxy: &api.AgentServiceConnectProxyConfig{ 66 DestinationServiceName: "web", 67 }, 68 } 69 err = client.Agent().ServiceRegister(regProxy) 70 require.Nil(t, err) 71 72 // And another proxy so we can test handling with multiple endpoints returned 73 regProxy.Port = 9091 74 regProxy.ID = "web-proxy-2" 75 err = client.Agent().ServiceRegister(regProxy) 76 require.Nil(t, err) 77 78 // Add a native service 79 { 80 regSrv := &api.AgentServiceRegistration{ 81 Name: "db", 82 Port: 8080, 83 Connect: &api.AgentServiceConnect{ 84 Native: true, 85 }, 86 } 87 require.NoError(t, client.Agent().ServiceRegister(regSrv)) 88 } 89 90 // Add a prepared query 91 queryId, _, err := client.PreparedQuery().Create(&api.PreparedQueryDefinition{ 92 Name: "test-query", 93 Service: api.ServiceQuery{ 94 Service: "web", 95 Connect: true, 96 }, 97 }, nil) 98 require.NoError(t, err) 99 100 proxyAddrs := []string{ 101 agent.Config.AdvertiseAddrLAN.String() + ":9090", 102 agent.Config.AdvertiseAddrLAN.String() + ":9091", 103 } 104 105 type fields struct { 106 Namespace string 107 Name string 108 Type int 109 Datacenter string 110 } 111 tests := []struct { 112 name string 113 fields fields 114 timeout time.Duration 115 wantAddr string 116 wantCertURI connect.CertURI 117 wantErr bool 118 addrs []string 119 }{ 120 { 121 name: "basic service discovery", 122 fields: fields{ 123 Namespace: "default", 124 Name: "web", 125 Type: ConsulResolverTypeService, 126 }, 127 // Want empty host since we don't enforce trust domain outside of TLS and 128 // don't need to load the current one this way. 129 wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""), 130 wantErr: false, 131 addrs: proxyAddrs, 132 }, 133 { 134 name: "basic service with native service", 135 fields: fields{ 136 Namespace: "default", 137 Name: "db", 138 Type: ConsulResolverTypeService, 139 }, 140 // Want empty host since we don't enforce trust domain outside of TLS and 141 // don't need to load the current one this way. 142 wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "db", ""), 143 wantErr: false, 144 }, 145 { 146 name: "Bad Type errors", 147 fields: fields{ 148 Namespace: "default", 149 Name: "web", 150 Type: 123, 151 }, 152 wantErr: true, 153 }, 154 { 155 name: "Non-existent service errors", 156 fields: fields{ 157 Namespace: "default", 158 Name: "foo", 159 Type: ConsulResolverTypeService, 160 }, 161 wantErr: true, 162 }, 163 { 164 name: "timeout errors", 165 fields: fields{ 166 Namespace: "default", 167 Name: "web", 168 Type: ConsulResolverTypeService, 169 }, 170 timeout: 1 * time.Nanosecond, 171 wantErr: true, 172 }, 173 { 174 name: "prepared query by id", 175 fields: fields{ 176 Name: queryId, 177 Type: ConsulResolverTypePreparedQuery, 178 }, 179 // Want empty host since we don't enforce trust domain outside of TLS and 180 // don't need to load the current one this way. 181 wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""), 182 wantErr: false, 183 addrs: proxyAddrs, 184 }, 185 { 186 name: "prepared query by name", 187 fields: fields{ 188 Name: "test-query", 189 Type: ConsulResolverTypePreparedQuery, 190 }, 191 // Want empty host since we don't enforce trust domain outside of TLS and 192 // don't need to load the current one this way. 193 wantCertURI: connect.TestSpiffeIDServiceWithHost(t, "web", ""), 194 wantErr: false, 195 addrs: proxyAddrs, 196 }, 197 } 198 for _, tt := range tests { 199 t.Run(tt.name, func(t *testing.T) { 200 require := require.New(t) 201 cr := &ConsulResolver{ 202 Client: client, 203 Namespace: tt.fields.Namespace, 204 Name: tt.fields.Name, 205 Type: tt.fields.Type, 206 Datacenter: tt.fields.Datacenter, 207 } 208 // WithCancel just to have a cancel func in scope to assign in the if 209 // clause. 210 ctx, cancel := context.WithCancel(context.Background()) 211 if tt.timeout > 0 { 212 ctx, cancel = context.WithTimeout(ctx, tt.timeout) 213 } 214 defer cancel() 215 gotAddr, gotCertURI, err := cr.Resolve(ctx) 216 if tt.wantErr { 217 require.NotNil(err) 218 return 219 } 220 221 require.Nil(err) 222 require.Equal(tt.wantCertURI, gotCertURI) 223 if len(tt.addrs) > 0 { 224 require.Contains(tt.addrs, gotAddr) 225 } 226 }) 227 } 228 } 229 230 func TestConsulResolverFromAddrFunc(t *testing.T) { 231 // Don't need an actual instance since we don't do the service discovery but 232 // we do want to assert the client is pass through correctly. 233 client, err := api.NewClient(api.DefaultConfig()) 234 require.NoError(t, err) 235 236 tests := []struct { 237 name string 238 addr string 239 want Resolver 240 wantErr string 241 }{ 242 { 243 name: "service", 244 addr: "foo.service.consul", 245 want: &ConsulResolver{ 246 Client: client, 247 Namespace: "default", 248 Name: "foo", 249 Type: ConsulResolverTypeService, 250 }, 251 }, 252 { 253 name: "query", 254 addr: "foo.query.consul", 255 want: &ConsulResolver{ 256 Client: client, 257 Namespace: "default", 258 Name: "foo", 259 Type: ConsulResolverTypePreparedQuery, 260 }, 261 }, 262 { 263 name: "service with dc", 264 addr: "foo.service.dc2.consul", 265 want: &ConsulResolver{ 266 Client: client, 267 Datacenter: "dc2", 268 Namespace: "default", 269 Name: "foo", 270 Type: ConsulResolverTypeService, 271 }, 272 }, 273 { 274 name: "query with dc", 275 addr: "foo.query.dc2.consul", 276 want: &ConsulResolver{ 277 Client: client, 278 Datacenter: "dc2", 279 Namespace: "default", 280 Name: "foo", 281 Type: ConsulResolverTypePreparedQuery, 282 }, 283 }, 284 { 285 name: "invalid host:port", 286 addr: "%%%", 287 wantErr: "invalid Consul DNS domain", 288 }, 289 { 290 name: "custom domain", 291 addr: "foo.service.my-consul.com", 292 wantErr: "invalid Consul DNS domain", 293 }, 294 { 295 name: "unsupported query type", 296 addr: "foo.connect.consul", 297 wantErr: "unsupported Consul DNS domain", 298 }, 299 { 300 name: "unsupported query type and datacenter", 301 addr: "foo.connect.dc1.consul", 302 wantErr: "unsupported Consul DNS domain", 303 }, 304 { 305 name: "unsupported query type and datacenter", 306 addr: "foo.connect.dc1.consul", 307 wantErr: "unsupported Consul DNS domain", 308 }, 309 { 310 name: "unsupported tag filter", 311 addr: "tag1.foo.service.consul", 312 wantErr: "unsupported Consul DNS domain", 313 }, 314 { 315 name: "unsupported tag filter with DC", 316 addr: "tag1.foo.service.dc1.consul", 317 wantErr: "unsupported Consul DNS domain", 318 }, 319 } 320 for _, tt := range tests { 321 t.Run(tt.name, func(t *testing.T) { 322 require := require.New(t) 323 324 fn := ConsulResolverFromAddrFunc(client) 325 got, gotErr := fn(tt.addr) 326 if tt.wantErr != "" { 327 require.Error(gotErr) 328 require.Contains(gotErr.Error(), tt.wantErr) 329 } else { 330 require.NoError(gotErr) 331 require.Equal(tt.want, got) 332 } 333 }) 334 } 335 }