github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/coprocess_bundle_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"crypto/md5"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"path/filepath"
     9  	"runtime"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/TykTechnologies/tyk/config"
    14  	"github.com/TykTechnologies/tyk/test"
    15  )
    16  
    17  var (
    18  	testBundlesPath = filepath.Join(testMiddlewarePath, "bundles")
    19  )
    20  
    21  var pkgPath string
    22  
    23  func init() {
    24  	_, filename, _, _ := runtime.Caller(0)
    25  	pkgPath = filepath.Dir(filename) + "./.."
    26  }
    27  
    28  var grpcBundleWithAuthCheck = map[string]string{
    29  	"manifest.json": `
    30  		{
    31  		    "file_list": [],
    32  		    "custom_middleware": {
    33  		        "driver": "grpc",
    34  		        "auth_check": {
    35  		            "name": "MyAuthHook"
    36  		        }
    37  		    }
    38  		}
    39  	`,
    40  }
    41  
    42  func TestBundleLoader(t *testing.T) {
    43  	bundleID := RegisterBundle("grpc_with_auth_check", grpcBundleWithAuthCheck)
    44  
    45  	t.Run("Nonexistent bundle", func(t *testing.T) {
    46  		specs := BuildAndLoadAPI(func(spec *APISpec) {
    47  			spec.CustomMiddlewareBundle = "nonexistent.zip"
    48  		})
    49  		err := loadBundle(specs[0])
    50  		if err == nil {
    51  			t.Fatal("Fetching a nonexistent bundle, expected an error")
    52  		}
    53  	})
    54  
    55  	t.Run("Existing bundle with auth check", func(t *testing.T) {
    56  		specs := BuildAndLoadAPI(func(spec *APISpec) {
    57  			spec.CustomMiddlewareBundle = bundleID
    58  		})
    59  		spec := specs[0]
    60  		err := loadBundle(spec)
    61  		if err != nil {
    62  			t.Fatalf("Bundle not found: %s\n", bundleID)
    63  		}
    64  
    65  		bundleNameHash := md5.New()
    66  		io.WriteString(bundleNameHash, spec.CustomMiddlewareBundle)
    67  		bundleDir := fmt.Sprintf("%s_%x", spec.APIID, bundleNameHash.Sum(nil))
    68  		savedBundlePath := filepath.Join(testBundlesPath, bundleDir)
    69  		if _, err = os.Stat(savedBundlePath); os.IsNotExist(err) {
    70  			t.Fatalf("Bundle wasn't saved to disk: %s", err.Error())
    71  		}
    72  
    73  		// Check bundle contents:
    74  		if spec.CustomMiddleware.AuthCheck.Name != "MyAuthHook" {
    75  			t.Fatalf("Auth check function doesn't match: got %s, expected %s\n", spec.CustomMiddleware.AuthCheck.Name, "MyAuthHook")
    76  		}
    77  		if string(spec.CustomMiddleware.Driver) != "grpc" {
    78  			t.Fatalf("Driver doesn't match: got %s, expected %s\n", spec.CustomMiddleware.Driver, "grpc")
    79  		}
    80  	})
    81  }
    82  
    83  func TestBundleFetcher(t *testing.T) {
    84  	bundleID := "testbundle"
    85  	defer ResetTestConfig()
    86  
    87  	t.Run("Simple bundle base URL", func(t *testing.T) {
    88  		globalConf := config.Global()
    89  		globalConf.BundleBaseURL = "mock://somepath"
    90  		globalConf.BundleInsecureSkipVerify = false
    91  		config.SetGlobal(globalConf)
    92  		specs := BuildAndLoadAPI(func(spec *APISpec) {
    93  			spec.CustomMiddlewareBundle = bundleID
    94  		})
    95  		spec := specs[0]
    96  		bundle, err := fetchBundle(spec)
    97  		if err != nil {
    98  			t.Fatalf("Couldn't fetch bundle: %s", err.Error())
    99  		}
   100  
   101  		if string(bundle.Data) != "bundle" {
   102  			t.Errorf("Wrong bundle data: %s", bundle.Data)
   103  		}
   104  		if bundle.Name != bundleID {
   105  			t.Errorf("Wrong bundle name: %s", bundle.Name)
   106  		}
   107  	})
   108  
   109  	t.Run("Bundle base URL with querystring", func(t *testing.T) {
   110  		globalConf := config.Global()
   111  		globalConf.BundleBaseURL = "mock://somepath?api_key=supersecret"
   112  		globalConf.BundleInsecureSkipVerify = true
   113  		config.SetGlobal(globalConf)
   114  		specs := BuildAndLoadAPI(func(spec *APISpec) {
   115  			spec.CustomMiddlewareBundle = bundleID
   116  		})
   117  		spec := specs[0]
   118  		bundle, err := fetchBundle(spec)
   119  		if err != nil {
   120  			t.Fatalf("Couldn't fetch bundle: %s", err.Error())
   121  		}
   122  
   123  		if string(bundle.Data) != "bundle-insecure" {
   124  			t.Errorf("Wrong bundle data: %s", bundle.Data)
   125  		}
   126  		if bundle.Name != bundleID {
   127  			t.Errorf("Wrong bundle name: %s", bundle.Name)
   128  		}
   129  	})
   130  }
   131  
   132  var overrideResponsePython = map[string]string{
   133  	"manifest.json": `
   134  		{
   135  		    "file_list": [
   136  		        "middleware.py"
   137  		    ],
   138  		    "custom_middleware": {
   139  		        "driver": "python",
   140  		        "pre": [{
   141  		            "name": "MyRequestHook"
   142  		        }]
   143  		    }
   144  		}
   145  	`,
   146  	"middleware.py": `
   147  from tyk.decorators import *
   148  from gateway import TykGateway as tyk
   149  
   150  @Hook
   151  def MyRequestHook(request, session, spec):
   152  	request.object.return_overrides.headers['X-Foo'] = 'Bar'
   153  	request.object.return_overrides.response_code = int(request.object.params["status"])
   154  
   155  	if request.object.params["response_body"] == "true":
   156  		request.object.return_overrides.response_body = "foobar"
   157  	else:
   158  		request.object.return_overrides.response_error = "{\"foo\": \"bar\"}"
   159  
   160  	if request.object.params["override"]:
   161  		request.object.return_overrides.override_error = True
   162  
   163  	return request, session
   164  `,
   165  }
   166  
   167  var overrideResponseJSVM = map[string]string{
   168  	"manifest.json": `
   169  {
   170      "file_list": [],
   171      "custom_middleware": {
   172          "driver": "otto",
   173          "pre": [{
   174              "name": "pre",
   175              "path": "pre.js"
   176          }]
   177      }
   178  }
   179  `,
   180  	"pre.js": `
   181  var pre = new TykJS.TykMiddleware.NewMiddleware({});
   182  
   183  pre.NewProcessRequest(function(request, session) {
   184  	if (request.Params["response_body"]) {
   185  		request.ReturnOverrides.ResponseBody = 'foobar'
   186  	} else {
   187  		request.ReturnOverrides.ResponseError = '{"foo": "bar"}'
   188  	}
   189  
   190  	request.ReturnOverrides.ResponseCode = parseInt(request.Params["status"])
   191  	request.ReturnOverrides.ResponseHeaders = {"X-Foo": "Bar"}
   192  
   193  	if (request.Params["override"]) {
   194  		request.ReturnOverrides.OverrideError = true
   195  	}
   196  	return pre.ReturnData(request, {});
   197  });
   198  `,
   199  }
   200  
   201  func TestResponseOverride(t *testing.T) {
   202  	ts := StartTest(TestConfig{
   203  		CoprocessConfig: config.CoProcessConfig{
   204  			EnableCoProcess:  true,
   205  			PythonPathPrefix: pkgPath,
   206  		}})
   207  	defer ts.Close()
   208  
   209  	customHeader := map[string]string{"X-Foo": "Bar"}
   210  	customError := `{"foo": "bar"}`
   211  	customBody := `foobar`
   212  
   213  	testOverride := func(t *testing.T, bundle string) {
   214  		BuildAndLoadAPI(func(spec *APISpec) {
   215  			spec.Proxy.ListenPath = "/test/"
   216  			spec.UseKeylessAccess = true
   217  			spec.CustomMiddlewareBundle = bundle
   218  		})
   219  
   220  		time.Sleep(1 * time.Second)
   221  
   222  		ts.Run(t, []test.TestCase{
   223  			{Path: "/test/?status=200", Code: 200, BodyMatch: customError, HeadersMatch: customHeader},
   224  			{Path: "/test/?status=200&response_body=true", Code: 200, BodyMatch: customBody, HeadersMatch: customHeader},
   225  			{Path: "/test/?status=400", Code: 400, BodyMatch: `"error": "`, HeadersMatch: customHeader},
   226  			{Path: "/test/?status=400&response_body=true", Code: 400, BodyMatch: `"error": "foobar"`, HeadersMatch: customHeader},
   227  			{Path: "/test/?status=401", Code: 401, BodyMatch: `"error": "`, HeadersMatch: customHeader},
   228  			{Path: "/test/?status=400&override=true", Code: 400, BodyMatch: customError, HeadersMatch: customHeader},
   229  			{Path: "/test/?status=400&override=true&response_body=true", Code: 400, BodyMatch: customBody, HeadersMatch: customHeader},
   230  			{Path: "/test/?status=401&override=true", Code: 401, BodyMatch: customError, HeadersMatch: customHeader},
   231  		}...)
   232  	}
   233  	t.Run("Python", func(t *testing.T) {
   234  		testOverride(t, RegisterBundle("python_override", overrideResponsePython))
   235  	})
   236  
   237  	t.Run("JSVM", func(t *testing.T) {
   238  		testOverride(t, RegisterBundle("jsvm_override", overrideResponseJSVM))
   239  	})
   240  }