github.com/hooklift/terraform@v0.11.0-beta1.0.20171117000744-6786c1361ffe/svchost/disco/disco_test.go (about)

     1  package disco
     2  
     3  import (
     4  	"crypto/tls"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"os"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/hashicorp/terraform/svchost"
    13  	"github.com/hashicorp/terraform/svchost/auth"
    14  )
    15  
    16  func TestMain(m *testing.M) {
    17  	// During all tests we override the HTTP transport we use for discovery
    18  	// so it'll tolerate the locally-generated TLS certificates we use
    19  	// for test URLs.
    20  	httpTransport = &http.Transport{
    21  		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    22  	}
    23  
    24  	os.Exit(m.Run())
    25  }
    26  
    27  func TestDiscover(t *testing.T) {
    28  	t.Run("happy path", func(t *testing.T) {
    29  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
    30  			resp := []byte(`
    31  {
    32  "thingy.v1": "http://example.com/foo",
    33  "wotsit.v2": "http://example.net/bar"
    34  }
    35  `)
    36  			w.Header().Add("Content-Type", "application/json")
    37  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
    38  			w.Write(resp)
    39  		})
    40  		defer close()
    41  
    42  		givenHost := "localhost" + portStr
    43  		host, err := svchost.ForComparison(givenHost)
    44  		if err != nil {
    45  			t.Fatalf("test server hostname is invalid: %s", err)
    46  		}
    47  
    48  		d := NewDisco()
    49  		discovered := d.Discover(host)
    50  		gotURL := discovered.ServiceURL("thingy.v1")
    51  		if gotURL == nil {
    52  			t.Fatalf("found no URL for thingy.v1")
    53  		}
    54  		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
    55  			t.Fatalf("wrong result %q; want %q", got, want)
    56  		}
    57  	})
    58  	t.Run("chunked encoding", func(t *testing.T) {
    59  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
    60  			resp := []byte(`
    61  {
    62  "thingy.v1": "http://example.com/foo",
    63  "wotsit.v2": "http://example.net/bar"
    64  }
    65  `)
    66  			w.Header().Add("Content-Type", "application/json")
    67  			// We're going to force chunked encoding here -- and thus prevent
    68  			// the server from predicting the length -- so we can make sure
    69  			// our client is tolerant of servers using this encoding.
    70  			w.Write(resp[:5])
    71  			w.(http.Flusher).Flush()
    72  			w.Write(resp[5:])
    73  			w.(http.Flusher).Flush()
    74  		})
    75  		defer close()
    76  
    77  		givenHost := "localhost" + portStr
    78  		host, err := svchost.ForComparison(givenHost)
    79  		if err != nil {
    80  			t.Fatalf("test server hostname is invalid: %s", err)
    81  		}
    82  
    83  		d := NewDisco()
    84  		discovered := d.Discover(host)
    85  		gotURL := discovered.ServiceURL("wotsit.v2")
    86  		if gotURL == nil {
    87  			t.Fatalf("found no URL for wotsit.v2")
    88  		}
    89  		if got, want := gotURL.String(), "http://example.net/bar"; got != want {
    90  			t.Fatalf("wrong result %q; want %q", got, want)
    91  		}
    92  	})
    93  	t.Run("with credentials", func(t *testing.T) {
    94  		var authHeaderText string
    95  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
    96  			resp := []byte(`{}`)
    97  			authHeaderText = r.Header.Get("Authorization")
    98  			w.Header().Add("Content-Type", "application/json")
    99  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
   100  			w.Write(resp)
   101  		})
   102  		defer close()
   103  
   104  		givenHost := "localhost" + portStr
   105  		host, err := svchost.ForComparison(givenHost)
   106  		if err != nil {
   107  			t.Fatalf("test server hostname is invalid: %s", err)
   108  		}
   109  
   110  		d := NewDisco()
   111  		d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
   112  			host: map[string]interface{}{
   113  				"token": "abc123",
   114  			},
   115  		}))
   116  		d.Discover(host)
   117  		if got, want := authHeaderText, "Bearer abc123"; got != want {
   118  			t.Fatalf("wrong Authorization header\ngot:  %s\nwant: %s", got, want)
   119  		}
   120  	})
   121  	t.Run("forced services override", func(t *testing.T) {
   122  		forced := map[string]interface{}{
   123  			"thingy.v1": "http://example.net/foo",
   124  			"wotsit.v2": "/foo",
   125  		}
   126  
   127  		d := NewDisco()
   128  		d.ForceHostServices(svchost.Hostname("example.com"), forced)
   129  
   130  		givenHost := "example.com"
   131  		host, err := svchost.ForComparison(givenHost)
   132  		if err != nil {
   133  			t.Fatalf("test server hostname is invalid: %s", err)
   134  		}
   135  
   136  		discovered := d.Discover(host)
   137  		{
   138  			gotURL := discovered.ServiceURL("thingy.v1")
   139  			if gotURL == nil {
   140  				t.Fatalf("found no URL for thingy.v1")
   141  			}
   142  			if got, want := gotURL.String(), "http://example.net/foo"; got != want {
   143  				t.Fatalf("wrong result %q; want %q", got, want)
   144  			}
   145  		}
   146  		{
   147  			gotURL := discovered.ServiceURL("wotsit.v2")
   148  			if gotURL == nil {
   149  				t.Fatalf("found no URL for wotsit.v2")
   150  			}
   151  			if got, want := gotURL.String(), "https://example.com/foo"; got != want {
   152  				t.Fatalf("wrong result %q; want %q", got, want)
   153  			}
   154  		}
   155  	})
   156  	t.Run("not JSON", func(t *testing.T) {
   157  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   158  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   159  			w.Header().Add("Content-Type", "application/octet-stream")
   160  			w.Write(resp)
   161  		})
   162  		defer close()
   163  
   164  		givenHost := "localhost" + portStr
   165  		host, err := svchost.ForComparison(givenHost)
   166  		if err != nil {
   167  			t.Fatalf("test server hostname is invalid: %s", err)
   168  		}
   169  
   170  		d := NewDisco()
   171  		discovered := d.Discover(host)
   172  
   173  		// result should be empty, which we can verify only by reaching into
   174  		// its internals.
   175  		if discovered.services != nil {
   176  			t.Errorf("response not empty; should be")
   177  		}
   178  	})
   179  	t.Run("malformed JSON", func(t *testing.T) {
   180  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   181  			resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
   182  			w.Header().Add("Content-Type", "application/json")
   183  			w.Write(resp)
   184  		})
   185  		defer close()
   186  
   187  		givenHost := "localhost" + portStr
   188  		host, err := svchost.ForComparison(givenHost)
   189  		if err != nil {
   190  			t.Fatalf("test server hostname is invalid: %s", err)
   191  		}
   192  
   193  		d := NewDisco()
   194  		discovered := d.Discover(host)
   195  
   196  		// result should be empty, which we can verify only by reaching into
   197  		// its internals.
   198  		if discovered.services != nil {
   199  			t.Errorf("response not empty; should be")
   200  		}
   201  	})
   202  	t.Run("JSON with redundant charset", func(t *testing.T) {
   203  		// The JSON RFC defines no parameters for the application/json
   204  		// MIME type, but some servers have a weird tendency to just add
   205  		// "charset" to everything, so we'll make sure we ignore it successfully.
   206  		// (JSON uses content sniffing for encoding detection, not media type params.)
   207  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   208  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   209  			w.Header().Add("Content-Type", "application/json; charset=latin-1")
   210  			w.Write(resp)
   211  		})
   212  		defer close()
   213  
   214  		givenHost := "localhost" + portStr
   215  		host, err := svchost.ForComparison(givenHost)
   216  		if err != nil {
   217  			t.Fatalf("test server hostname is invalid: %s", err)
   218  		}
   219  
   220  		d := NewDisco()
   221  		discovered := d.Discover(host)
   222  
   223  		if discovered.services == nil {
   224  			t.Errorf("response is empty; shouldn't be")
   225  		}
   226  	})
   227  	t.Run("no discovery doc", func(t *testing.T) {
   228  		portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
   229  			w.WriteHeader(404)
   230  		})
   231  		defer close()
   232  
   233  		givenHost := "localhost" + portStr
   234  		host, err := svchost.ForComparison(givenHost)
   235  		if err != nil {
   236  			t.Fatalf("test server hostname is invalid: %s", err)
   237  		}
   238  
   239  		d := NewDisco()
   240  		discovered := d.Discover(host)
   241  
   242  		// result should be empty, which we can verify only by reaching into
   243  		// its internals.
   244  		if discovered.services != nil {
   245  			t.Errorf("response not empty; should be")
   246  		}
   247  	})
   248  	t.Run("redirect", func(t *testing.T) {
   249  		// For this test, we have two servers and one redirects to the other
   250  		portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
   251  			// This server is the one that returns a real response.
   252  			resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
   253  			w.Header().Add("Content-Type", "application/json")
   254  			w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
   255  			w.Write(resp)
   256  		})
   257  		portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
   258  			// This server is the one that redirects.
   259  			http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
   260  		})
   261  		defer close1()
   262  		defer close2()
   263  
   264  		givenHost := "localhost" + portStr2
   265  		host, err := svchost.ForComparison(givenHost)
   266  		if err != nil {
   267  			t.Fatalf("test server hostname is invalid: %s", err)
   268  		}
   269  
   270  		d := NewDisco()
   271  		discovered := d.Discover(host)
   272  
   273  		gotURL := discovered.ServiceURL("thingy.v1")
   274  		if gotURL == nil {
   275  			t.Fatalf("found no URL for thingy.v1")
   276  		}
   277  		if got, want := gotURL.String(), "http://example.com/foo"; got != want {
   278  			t.Fatalf("wrong result %q; want %q", got, want)
   279  		}
   280  
   281  		// The base URL for the host object should be the URL we redirected to,
   282  		// rather than the we redirected _from_.
   283  		gotBaseURL := discovered.discoURL.String()
   284  		wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
   285  		if gotBaseURL != wantBaseURL {
   286  			t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
   287  		}
   288  
   289  	})
   290  }
   291  
   292  func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
   293  	server := httptest.NewTLSServer(http.HandlerFunc(
   294  		func(w http.ResponseWriter, r *http.Request) {
   295  			// Test server always returns 404 if the URL isn't what we expect
   296  			if r.URL.Path != "/.well-known/terraform.json" {
   297  				w.WriteHeader(404)
   298  				w.Write([]byte("not found"))
   299  				return
   300  			}
   301  
   302  			// If the URL is correct then the given hander decides the response
   303  			h(w, r)
   304  		},
   305  	))
   306  
   307  	serverURL, _ := url.Parse(server.URL)
   308  
   309  	portStr = serverURL.Port()
   310  	if portStr != "" {
   311  		portStr = ":" + portStr
   312  	}
   313  
   314  	close = func() {
   315  		server.Close()
   316  	}
   317  
   318  	return
   319  }