github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/apiserver/server_test.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package apiserver
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"crypto/tls"
    25  	"crypto/x509"
    26  	"crypto/x509/pkix"
    27  	"encoding/json"
    28  	"encoding/pem"
    29  	"fmt"
    30  	"io/ioutil"
    31  	"math/big"
    32  	"net"
    33  	"net/http"
    34  	"net/http/httptest"
    35  	"os"
    36  	"sync"
    37  	"testing"
    38  	"time"
    39  
    40  	"github.com/getkin/kin-openapi/openapi3"
    41  	"github.com/gorilla/mux"
    42  	"github.com/kaleido-io/firefly/internal/config"
    43  	"github.com/kaleido-io/firefly/internal/i18n"
    44  	"github.com/kaleido-io/firefly/internal/oapispec"
    45  	"github.com/kaleido-io/firefly/mocks/orchestratormocks"
    46  	"github.com/stretchr/testify/assert"
    47  )
    48  
// configDir is the relative path to the test configuration directory; the
// tests in this file use it to locate fixture files such as firefly.core.yaml.
const configDir = "../../test/data/config"
    50  
    51  func TestStartStopServer(t *testing.T) {
    52  	config.Reset()
    53  	config.Set(config.HTTPPort, 0)
    54  	config.Set(config.UIPath, "test")
    55  	ctx, cancel := context.WithCancel(context.Background())
    56  	cancel() // server will immediately shut down
    57  	err := Serve(ctx, &orchestratormocks.Orchestrator{})
    58  	assert.NoError(t, err)
    59  }
    60  
    61  func TestInvalidListener(t *testing.T) {
    62  	config.Reset()
    63  	config.Set(config.HTTPAddress, "...")
    64  	_, err := createListener(context.Background())
    65  	assert.Error(t, err)
    66  }
    67  
    68  func TestServeFail(t *testing.T) {
    69  	l, _ := net.Listen("tcp", "127.0.0.1:0")
    70  	l.Close() // So server will fail
    71  	s := &http.Server{}
    72  	err := serveHTTP(context.Background(), l, s)
    73  	assert.Error(t, err)
    74  }
    75  
    76  func TestMissingCAFile(t *testing.T) {
    77  	config.Reset()
    78  	config.Set(config.HTTPTLSCAFile, "badness")
    79  	r := mux.NewRouter()
    80  	_, err := createServer(context.Background(), r)
    81  	assert.Regexp(t, "FF10105", err)
    82  }
    83  
    84  func TestBadCAFile(t *testing.T) {
    85  	config.Reset()
    86  	config.Set(config.HTTPTLSCAFile, configDir+"/firefly.core.yaml")
    87  	r := mux.NewRouter()
    88  	_, err := createServer(context.Background(), r)
    89  	assert.Regexp(t, "FF10106", err)
    90  }
    91  
    92  func TestTLSServerSelfSignedWithClientAuth(t *testing.T) {
    93  
    94  	// Create an X509 certificate pair
    95  	privatekey, _ := rsa.GenerateKey(rand.Reader, 2048)
    96  	publickey := &privatekey.PublicKey
    97  	var privateKeyBytes []byte = x509.MarshalPKCS1PrivateKey(privatekey)
    98  	privateKeyFile, _ := ioutil.TempFile("", "key.pem")
    99  	defer os.Remove(privateKeyFile.Name())
   100  	privateKeyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}
   101  	pem.Encode(privateKeyFile, privateKeyBlock)
   102  	serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
   103  	x509Template := &x509.Certificate{
   104  		SerialNumber: serialNumber,
   105  		Subject: pkix.Name{
   106  			Organization: []string{"Unit Tests"},
   107  		},
   108  		NotBefore:             time.Now(),
   109  		NotAfter:              time.Now().Add(100 * time.Second),
   110  		KeyUsage:              x509.KeyUsageDigitalSignature,
   111  		BasicConstraintsValid: true,
   112  		IPAddresses:           []net.IP{net.IPv4(127, 0, 0, 1)},
   113  	}
   114  	derBytes, err := x509.CreateCertificate(rand.Reader, x509Template, x509Template, publickey, privatekey)
   115  	assert.NoError(t, err)
   116  	publicKeyFile, _ := ioutil.TempFile("", "cert.pem")
   117  	defer os.Remove(publicKeyFile.Name())
   118  	pem.Encode(publicKeyFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
   119  
   120  	// Start up a listener configured for TLS Mutual auth
   121  	config.Reset()
   122  	config.Set(config.HTTPTLSEnabled, true)
   123  	config.Set(config.HTTPTLSClientAuth, true)
   124  	config.Set(config.HTTPTLSKeyFile, privateKeyFile.Name())
   125  	config.Set(config.HTTPTLSCertFile, publicKeyFile.Name())
   126  	config.Set(config.HTTPTLSCAFile, publicKeyFile.Name())
   127  	config.Set(config.HTTPPort, 0)
   128  	ctx, cancelCtx := context.WithCancel(context.Background())
   129  	l, err := createListener(ctx)
   130  	assert.NoError(t, err)
   131  	r := mux.NewRouter()
   132  	r.HandleFunc("/test", func(res http.ResponseWriter, req *http.Request) {
   133  		res.WriteHeader(200)
   134  		json.NewEncoder(res).Encode(map[string]interface{}{"hello": "world"})
   135  	})
   136  	s, err := createServer(ctx, r)
   137  	assert.NoError(t, err)
   138  
   139  	wg := sync.WaitGroup{}
   140  	wg.Add(1)
   141  	go func() {
   142  		err := serveHTTP(ctx, l, s)
   143  		assert.NoError(t, err)
   144  		wg.Done()
   145  	}()
   146  
   147  	// Attempt a request, with a client certificate
   148  	rootCAs := x509.NewCertPool()
   149  	caPEM, _ := ioutil.ReadFile(publicKeyFile.Name())
   150  	ok := rootCAs.AppendCertsFromPEM(caPEM)
   151  	assert.True(t, ok)
   152  	c := http.Client{
   153  		Transport: &http.Transport{
   154  			TLSClientConfig: &tls.Config{
   155  				GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   156  					clientKeyPair, err := tls.LoadX509KeyPair(publicKeyFile.Name(), privateKeyFile.Name())
   157  					return &clientKeyPair, err
   158  				},
   159  				RootCAs: rootCAs,
   160  			},
   161  		},
   162  	}
   163  	httpsAddr := fmt.Sprintf("https://%s/test", l.Addr().String())
   164  	res, err := c.Get(httpsAddr)
   165  	assert.NoError(t, err)
   166  	if res != nil {
   167  		assert.Equal(t, 200, res.StatusCode)
   168  		var resBody map[string]interface{}
   169  		json.NewDecoder(res.Body).Decode(&resBody)
   170  		assert.Equal(t, "world", resBody["hello"])
   171  	}
   172  
   173  	// Close down the server and wait for it to complete
   174  	cancelCtx()
   175  	wg.Wait()
   176  }
   177  
   178  func TestJSONHTTPServePOST201(t *testing.T) {
   179  	mo := &orchestratormocks.Orchestrator{}
   180  	handler := routeHandler(mo, &oapispec.Route{
   181  		Name:            "testRoute",
   182  		Path:            "/test",
   183  		Method:          "POST",
   184  		JSONInputValue:  func() interface{} { return make(map[string]interface{}) },
   185  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   186  		JSONOutputCode:  201,
   187  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   188  			assert.Equal(t, "value1", r.Input.(map[string]interface{})["input1"])
   189  			return map[string]interface{}{"output1": "value2"}, nil
   190  		},
   191  	})
   192  	s := httptest.NewServer(http.HandlerFunc(handler))
   193  	defer s.Close()
   194  
   195  	b, _ := json.Marshal(map[string]interface{}{"input1": "value1"})
   196  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/json", bytes.NewReader(b))
   197  	assert.NoError(t, err)
   198  	assert.Equal(t, 201, res.StatusCode)
   199  	var resJSON map[string]interface{}
   200  	json.NewDecoder(res.Body).Decode(&resJSON)
   201  	assert.Equal(t, "value2", resJSON["output1"])
   202  }
   203  
   204  func TestJSONHTTPResponseEncodeFail(t *testing.T) {
   205  	mo := &orchestratormocks.Orchestrator{}
   206  	handler := routeHandler(mo, &oapispec.Route{
   207  		Name:            "testRoute",
   208  		Path:            "/test",
   209  		Method:          "GET",
   210  		JSONInputValue:  nil,
   211  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   212  		JSONOutputCode:  200,
   213  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   214  			v := map[string]interface{}{"unserializable": map[bool]interface{}{true: "not in JSON"}}
   215  			return v, nil
   216  		},
   217  	})
   218  	s := httptest.NewServer(http.HandlerFunc(handler))
   219  	defer s.Close()
   220  
   221  	b, _ := json.Marshal(map[string]interface{}{"input1": "value1"})
   222  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/json", bytes.NewReader(b))
   223  	assert.NoError(t, err)
   224  	var resJSON map[string]interface{}
   225  	json.NewDecoder(res.Body).Decode(&resJSON)
   226  	assert.Regexp(t, "FF10107", resJSON["error"])
   227  }
   228  
   229  func TestJSONHTTPNilResponseNon204(t *testing.T) {
   230  	mo := &orchestratormocks.Orchestrator{}
   231  	handler := routeHandler(mo, &oapispec.Route{
   232  		Name:            "testRoute",
   233  		Path:            "/test",
   234  		Method:          "GET",
   235  		JSONInputValue:  nil,
   236  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   237  		JSONOutputCode:  200,
   238  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   239  			return nil, nil
   240  		},
   241  	})
   242  	s := httptest.NewServer(http.HandlerFunc(handler))
   243  	defer s.Close()
   244  
   245  	b, _ := json.Marshal(map[string]interface{}{"input1": "value1"})
   246  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/json", bytes.NewReader(b))
   247  	assert.NoError(t, err)
   248  	assert.Equal(t, 404, res.StatusCode)
   249  	var resJSON map[string]interface{}
   250  	json.NewDecoder(res.Body).Decode(&resJSON)
   251  	assert.Regexp(t, "FF10143", resJSON["error"])
   252  }
   253  
   254  func TestJSONHTTPDefault500Error(t *testing.T) {
   255  	mo := &orchestratormocks.Orchestrator{}
   256  	handler := routeHandler(mo, &oapispec.Route{
   257  		Name:            "testRoute",
   258  		Path:            "/test",
   259  		Method:          "GET",
   260  		JSONInputValue:  nil,
   261  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   262  		JSONOutputCode:  200,
   263  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   264  			return nil, fmt.Errorf("pop")
   265  		},
   266  	})
   267  	s := httptest.NewServer(http.HandlerFunc(handler))
   268  	defer s.Close()
   269  
   270  	b, _ := json.Marshal(map[string]interface{}{"input1": "value1"})
   271  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/json", bytes.NewReader(b))
   272  	assert.NoError(t, err)
   273  	assert.Equal(t, 500, res.StatusCode)
   274  	var resJSON map[string]interface{}
   275  	json.NewDecoder(res.Body).Decode(&resJSON)
   276  	assert.Regexp(t, "pop", resJSON["error"])
   277  }
   278  
   279  func TestStatusCodeHintMapping(t *testing.T) {
   280  	mo := &orchestratormocks.Orchestrator{}
   281  	handler := routeHandler(mo, &oapispec.Route{
   282  		Name:            "testRoute",
   283  		Path:            "/test",
   284  		Method:          "GET",
   285  		JSONInputValue:  nil,
   286  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   287  		JSONOutputCode:  200,
   288  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   289  			return nil, i18n.NewError(r.Ctx, i18n.MsgResponseMarshalError)
   290  		},
   291  	})
   292  	s := httptest.NewServer(http.HandlerFunc(handler))
   293  	defer s.Close()
   294  
   295  	b, _ := json.Marshal(map[string]interface{}{"input1": "value1"})
   296  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/json", bytes.NewReader(b))
   297  	assert.NoError(t, err)
   298  	assert.Equal(t, 400, res.StatusCode)
   299  	var resJSON map[string]interface{}
   300  	json.NewDecoder(res.Body).Decode(&resJSON)
   301  	assert.Regexp(t, "FF10107", resJSON["error"])
   302  }
   303  
   304  func TestStatusInvalidContentType(t *testing.T) {
   305  	mo := &orchestratormocks.Orchestrator{}
   306  	handler := routeHandler(mo, &oapispec.Route{
   307  		Name:            "testRoute",
   308  		Path:            "/test",
   309  		Method:          "POST",
   310  		JSONInputValue:  nil,
   311  		JSONOutputValue: func() interface{} { return make(map[string]interface{}) },
   312  		JSONOutputCode:  204,
   313  		JSONHandler: func(r oapispec.APIRequest) (output interface{}, err error) {
   314  			return nil, nil
   315  		},
   316  	})
   317  	s := httptest.NewServer(http.HandlerFunc(handler))
   318  	defer s.Close()
   319  
   320  	res, err := http.Post(fmt.Sprintf("http://%s/test", s.Listener.Addr()), "application/text", bytes.NewReader([]byte{}))
   321  	assert.NoError(t, err)
   322  	assert.Equal(t, 415, res.StatusCode)
   323  	var resJSON map[string]interface{}
   324  	json.NewDecoder(res.Body).Decode(&resJSON)
   325  	assert.Regexp(t, "FF10130", resJSON["error"])
   326  }
   327  
   328  func TestNotFound(t *testing.T) {
   329  	handler := apiWrapper(notFoundHandler)
   330  	s := httptest.NewServer(http.HandlerFunc(handler))
   331  	defer s.Close()
   332  
   333  	res, err := http.Get(fmt.Sprintf("http://%s/test", s.Listener.Addr()))
   334  	assert.NoError(t, err)
   335  	assert.Equal(t, 404, res.StatusCode)
   336  	var resJSON map[string]interface{}
   337  	json.NewDecoder(res.Body).Decode(&resJSON)
   338  	assert.Regexp(t, "FF10109", resJSON["error"])
   339  }
   340  
   341  func TestSwaggerUI(t *testing.T) {
   342  	handler := apiWrapper(swaggerUIHandler)
   343  	s := httptest.NewServer(http.HandlerFunc(handler))
   344  	defer s.Close()
   345  
   346  	res, err := http.Get(fmt.Sprintf("http://%s/api", s.Listener.Addr()))
   347  	assert.NoError(t, err)
   348  	assert.Equal(t, 200, res.StatusCode)
   349  	b, _ := ioutil.ReadAll(res.Body)
   350  	assert.Regexp(t, "html", string(b))
   351  }
   352  
   353  func TestSwaggerYAML(t *testing.T) {
   354  	handler := apiWrapper(swaggerHandler)
   355  	s := httptest.NewServer(http.HandlerFunc(handler))
   356  	defer s.Close()
   357  
   358  	res, err := http.Get(fmt.Sprintf("http://%s/api/swagger.yaml", s.Listener.Addr()))
   359  	assert.NoError(t, err)
   360  	assert.Equal(t, 200, res.StatusCode)
   361  	b, _ := ioutil.ReadAll(res.Body)
   362  	doc, err := openapi3.NewLoader().LoadFromData(b)
   363  	assert.NoError(t, err)
   364  	err = doc.Validate(context.Background())
   365  	assert.NoError(t, err)
   366  }
   367  
   368  func TestSwaggerJSON(t *testing.T) {
   369  	mo := &orchestratormocks.Orchestrator{}
   370  	r := createMuxRouter(mo)
   371  	s := httptest.NewServer(r)
   372  	defer s.Close()
   373  
   374  	res, err := http.Get(fmt.Sprintf("http://%s/api/swagger.json", s.Listener.Addr()))
   375  	assert.NoError(t, err)
   376  	assert.Equal(t, 200, res.StatusCode)
   377  	b, _ := ioutil.ReadAll(res.Body)
   378  	err = json.Unmarshal(b, &openapi3.T{})
   379  	assert.NoError(t, err)
   380  }