github.com/medzin/terraform@v0.11.11/svchost/disco/disco_test.go (about)

     1  package disco
     2  
     3  import (
     4  	"crypto/tls"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"os"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/hashicorp/terraform/svchost"
    13  	"github.com/hashicorp/terraform/svchost/auth"
    14  )
    15  
    16  func TestMain(m *testing.M) {
    17  	// During all tests we override the HTTP transport we use for discovery
    18  	// so it'll tolerate the locally-generated TLS certificates we use
    19  	// for test URLs.
    20  	httpTransport = &http.Transport{
    21  		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    22  	}
    23  
    24  	os.Exit(m.Run())
    25  }
    26  
    27  func TestDiscover(t *testing.T) {
    28  	t.Run("happy path", func(t *testing.T) {
    29  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
    30  			resp := []byte(`
    31  {
    32  "thingy.v1": "http://example.com/foo",
    33  "wotsit.v2": "http://example.net/bar"
    34  }
    35  `)
    36  			w.Header().Add("Content-Type", "application/json")
    37  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
    38  			w.Write(resp)
    39  		})
    40  		defer close()
    41  
    42  		givenHost := "localhost" + portStr
    43  		host, err := svchost.ForComparison(givenHost)
    44  		if err != nil {
    45  			t.Fatalf("test server hostname is invalid: %s", err)
    46  		}
    47  
    48  		d := New()
    49  		discovered, err := d.Discover(host)
    50  		if err != nil {
    51  			t.Fatalf("unexpected discovery error: %s", err)
    52  		}
    53  
    54  		gotURL, err := discovered.ServiceURL("thingy.v1")
    55  		if err != nil {
    56  			t.Fatalf("unexpected service URL error: %s", err)
    57  		}
    58  		if gotURL == nil {
    59  			t.Fatalf("found no URL for thingy.v1")
    60  		}
    61  		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
    62  			t.Fatalf("wrong result %q; want %q", got, want)
    63  		}
    64  	})
    65  	t.Run("chunked encoding", func(t *testing.T) {
    66  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
    67  			resp := []byte(`
    68  {
    69  "thingy.v1": "http://example.com/foo",
    70  "wotsit.v2": "http://example.net/bar"
    71  }
    72  `)
    73  			w.Header().Add("Content-Type", "application/json")
    74  			// We're going to force chunked encoding here -- and thus prevent
    75  			// the server from predicting the length -- so we can make sure
    76  			// our client is tolerant of servers using this encoding.
    77  			w.Write(resp[:5])
    78  			w.(http.Flusher).Flush()
    79  			w.Write(resp[5:])
    80  			w.(http.Flusher).Flush()
    81  		})
    82  		defer close()
    83  
    84  		givenHost := "localhost" + portStr
    85  		host, err := svchost.ForComparison(givenHost)
    86  		if err != nil {
    87  			t.Fatalf("test server hostname is invalid: %s", err)
    88  		}
    89  
    90  		d := New()
    91  		discovered, err := d.Discover(host)
    92  		if err != nil {
    93  			t.Fatalf("unexpected discovery error: %s", err)
    94  		}
    95  
    96  		gotURL, err := discovered.ServiceURL("wotsit.v2")
    97  		if err != nil {
    98  			t.Fatalf("unexpected service URL error: %s", err)
    99  		}
   100  		if gotURL == nil {
   101  			t.Fatalf("found no URL for wotsit.v2")
   102  		}
   103  		if got, want := gotURL.String(), "http://example.net/bar"; got != want {
   104  			t.Fatalf("wrong result %q; want %q", got, want)
   105  		}
   106  	})
   107  	t.Run("with credentials", func(t *testing.T) {
   108  		var authHeaderText string
   109  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   110  			resp := []byte(`{}`)
   111  			authHeaderText = r.Header.Get("Authorization")
   112  			w.Header().Add("Content-Type", "application/json")
   113  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
   114  			w.Write(resp)
   115  		})
   116  		defer close()
   117  
   118  		givenHost := "localhost" + portStr
   119  		host, err := svchost.ForComparison(givenHost)
   120  		if err != nil {
   121  			t.Fatalf("test server hostname is invalid: %s", err)
   122  		}
   123  
   124  		d := New()
   125  		d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
   126  			host: map[string]interface{}{
   127  				"token": "abc123",
   128  			},
   129  		}))
   130  		d.Discover(host)
   131  		if got, want := authHeaderText, "Bearer abc123"; got != want {
   132  			t.Fatalf("wrong Authorization header\ngot:  %s\nwant: %s", got, want)
   133  		}
   134  	})
   135  	t.Run("forced services override", func(t *testing.T) {
   136  		forced := map[string]interface{}{
   137  			"thingy.v1": "http://example.net/foo",
   138  			"wotsit.v2": "/foo",
   139  		}
   140  
   141  		d := New()
   142  		d.ForceHostServices(svchost.Hostname("example.com"), forced)
   143  
   144  		givenHost := "example.com"
   145  		host, err := svchost.ForComparison(givenHost)
   146  		if err != nil {
   147  			t.Fatalf("test server hostname is invalid: %s", err)
   148  		}
   149  
   150  		discovered, err := d.Discover(host)
   151  		if err != nil {
   152  			t.Fatalf("unexpected discovery error: %s", err)
   153  		}
   154  		{
   155  			gotURL, err := discovered.ServiceURL("thingy.v1")
   156  			if err != nil {
   157  				t.Fatalf("unexpected service URL error: %s", err)
   158  			}
   159  			if gotURL == nil {
   160  				t.Fatalf("found no URL for thingy.v1")
   161  			}
   162  			if got, want := gotURL.String(), "http://example.net/foo"; got != want {
   163  				t.Fatalf("wrong result %q; want %q", got, want)
   164  			}
   165  		}
   166  		{
   167  			gotURL, err := discovered.ServiceURL("wotsit.v2")
   168  			if err != nil {
   169  				t.Fatalf("unexpected service URL error: %s", err)
   170  			}
   171  			if gotURL == nil {
   172  				t.Fatalf("found no URL for wotsit.v2")
   173  			}
   174  			if got, want := gotURL.String(), "https://example.com/foo"; got != want {
   175  				t.Fatalf("wrong result %q; want %q", got, want)
   176  			}
   177  		}
   178  	})
   179  	t.Run("not JSON", func(t *testing.T) {
   180  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   181  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   182  			w.Header().Add("Content-Type", "application/octet-stream")
   183  			w.Write(resp)
   184  		})
   185  		defer close()
   186  
   187  		givenHost := "localhost" + portStr
   188  		host, err := svchost.ForComparison(givenHost)
   189  		if err != nil {
   190  			t.Fatalf("test server hostname is invalid: %s", err)
   191  		}
   192  
   193  		d := New()
   194  		discovered, err := d.Discover(host)
   195  		if err == nil {
   196  			t.Fatalf("expected a discovery error")
   197  		}
   198  
   199  		// Returned discovered should be nil.
   200  		if discovered != nil {
   201  			t.Errorf("discovered not nil; should be")
   202  		}
   203  	})
   204  	t.Run("malformed JSON", func(t *testing.T) {
   205  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   206  			resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
   207  			w.Header().Add("Content-Type", "application/json")
   208  			w.Write(resp)
   209  		})
   210  		defer close()
   211  
   212  		givenHost := "localhost" + portStr
   213  		host, err := svchost.ForComparison(givenHost)
   214  		if err != nil {
   215  			t.Fatalf("test server hostname is invalid: %s", err)
   216  		}
   217  
   218  		d := New()
   219  		discovered, err := d.Discover(host)
   220  		if err == nil {
   221  			t.Fatalf("expected a discovery error")
   222  		}
   223  
   224  		// Returned discovered should be nil.
   225  		if discovered != nil {
   226  			t.Errorf("discovered not nil; should be")
   227  		}
   228  	})
   229  	t.Run("JSON with redundant charset", func(t *testing.T) {
   230  		// The JSON RFC defines no parameters for the application/json
   231  		// MIME type, but some servers have a weird tendency to just add
   232  		// "charset" to everything, so we'll make sure we ignore it successfully.
   233  		// (JSON uses content sniffing for encoding detection, not media type params.)
   234  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   235  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   236  			w.Header().Add("Content-Type", "application/json; charset=latin-1")
   237  			w.Write(resp)
   238  		})
   239  		defer close()
   240  
   241  		givenHost := "localhost" + portStr
   242  		host, err := svchost.ForComparison(givenHost)
   243  		if err != nil {
   244  			t.Fatalf("test server hostname is invalid: %s", err)
   245  		}
   246  
   247  		d := New()
   248  		discovered, err := d.Discover(host)
   249  		if err != nil {
   250  			t.Fatalf("unexpected discovery error: %s", err)
   251  		}
   252  
   253  		if discovered.services == nil {
   254  			t.Errorf("response is empty; shouldn't be")
   255  		}
   256  	})
   257  	t.Run("no discovery doc", func(t *testing.T) {
   258  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   259  			w.WriteHeader(404)
   260  		})
   261  		defer close()
   262  
   263  		givenHost := "localhost" + portStr
   264  		host, err := svchost.ForComparison(givenHost)
   265  		if err != nil {
   266  			t.Fatalf("test server hostname is invalid: %s", err)
   267  		}
   268  
   269  		d := New()
   270  		discovered, err := d.Discover(host)
   271  		if err != nil {
   272  			t.Fatalf("unexpected discovery error: %s", err)
   273  		}
   274  
   275  		// Returned discovered.services should be nil (empty).
   276  		if discovered.services != nil {
   277  			t.Errorf("discovered.services not nil (empty); should be")
   278  		}
   279  	})
   280  	t.Run("redirect", func(t *testing.T) {
   281  		// For this test, we have two servers and one redirects to the other
   282  		portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
   283  			// This server is the one that returns a real response.
   284  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   285  			w.Header().Add("Content-Type", "application/json")
   286  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
   287  			w.Write(resp)
   288  		})
   289  		portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
   290  			// This server is the one that redirects.
   291  			http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
   292  		})
   293  		defer close1()
   294  		defer close2()
   295  
   296  		givenHost := "localhost" + portStr2
   297  		host, err := svchost.ForComparison(givenHost)
   298  		if err != nil {
   299  			t.Fatalf("test server hostname is invalid: %s", err)
   300  		}
   301  
   302  		d := New()
   303  		discovered, err := d.Discover(host)
   304  		if err != nil {
   305  			t.Fatalf("unexpected discovery error: %s", err)
   306  		}
   307  
   308  		gotURL, err := discovered.ServiceURL("thingy.v1")
   309  		if err != nil {
   310  			t.Fatalf("unexpected service URL error: %s", err)
   311  		}
   312  		if gotURL == nil {
   313  			t.Fatalf("found no URL for thingy.v1")
   314  		}
   315  		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
   316  			t.Fatalf("wrong result %q; want %q", got, want)
   317  		}
   318  
   319  		// The base URL for the host object should be the URL we redirected to,
   320  		// rather than the we redirected _from_.
   321  		gotBaseURL := discovered.discoURL.String()
   322  		wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
   323  		if gotBaseURL != wantBaseURL {
   324  			t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
   325  		}
   326  
   327  	})
   328  }
   329  
   330  func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
   331  	server := httptest.NewTLSServer(http.HandlerFunc(
   332  		func(w http.ResponseWriter, r *http.Request) {
   333  			// Test server always returns 404 if the URL isn't what we expect
   334  			if r.URL.Path != "/.well-known/terraform.json" {
   335  				w.WriteHeader(404)
   336  				w.Write([]byte("not found"))
   337  				return
   338  			}
   339  
   340  			// If the URL is correct then the given hander decides the response
   341  			h(w, r)
   342  		},
   343  	))
   344  
   345  	serverURL, _ := url.Parse(server.URL)
   346  
   347  	portStr = serverURL.Port()
   348  	if portStr != "" {
   349  		portStr = ":" + portStr
   350  	}
   351  
   352  	close = func() {
   353  		server.Close()
   354  	}
   355  
   356  	return portStr, close
   357  }