github.com/kubevela/workflow@v0.6.0/pkg/providers/http/do_test.go (about)

     1  /*
     2  Copyright 2022 The KubeVela Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package http
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/crossplane/crossplane-runtime/pkg/test"
    34  	"github.com/stretchr/testify/require"
    35  	v1 "k8s.io/api/core/v1"
    36  	"sigs.k8s.io/controller-runtime/pkg/client"
    37  
    38  	monitorContext "github.com/kubevela/pkg/monitor/context"
    39  
    40  	"github.com/kubevela/workflow/pkg/cue/model/value"
    41  	"github.com/kubevela/workflow/pkg/providers"
    42  	"github.com/kubevela/workflow/pkg/providers/http/ratelimiter"
    43  	"github.com/kubevela/workflow/pkg/providers/http/testdata"
    44  )
    45  
    46  func TestHttpDo(t *testing.T) {
    47  	shutdown := make(chan struct{})
    48  	runMockServer(shutdown)
    49  	defer func() {
    50  		close(shutdown)
    51  	}()
    52  	ctx := monitorContext.NewTraceContext(context.Background(), "")
    53  	baseTemplate := `
    54  		url: string
    55  		request?: close({
    56  			timeout?: string
    57  			body?:    string
    58  			header?:  [string]: string
    59  			trailer?: [string]: string
    60  			ratelimiter?: {
    61  				limit: int
    62  				period: string
    63  			}
    64  		})
    65  		response: close({
    66  			body: string
    67  			header?:  [string]: [...string]
    68  			trailer?: [string]: [...string]
    69  			statusCode: int
    70  		})
    71  `
    72  	testCases := map[string]struct {
    73  		request      string
    74  		expectedBody string
    75  		expectedErr  string
    76  		statusCode   int
    77  	}{
    78  		"hello": {
    79  			request: baseTemplate + `
    80  method: "GET"
    81  url: "http://127.0.0.1:1229/hello"
    82  request: {
    83  	timeout: "2s"
    84  }`,
    85  			expectedBody: `hello`,
    86  			statusCode:   200,
    87  		},
    88  
    89  		"echo": {
    90  			request: baseTemplate + `
    91  method: "POST"
    92  url: "http://127.0.0.1:1229/echo"
    93  request:{ 
    94     body: "I am vela" 
    95     header: "Content-Type": "text/plain; charset=utf-8"
    96  }`,
    97  			expectedBody: `I am vela`,
    98  			statusCode:   200,
    99  		},
   100  		"json": {
   101  			request: `
   102  import ("encoding/json")
   103  foo: {
   104  	name: "foo"
   105  	score: 100
   106  }
   107  
   108  method: "POST"
   109  url: "http://127.0.0.1:1229/echo"
   110  request:{ 
   111     body: json.Marshal(foo)
   112     header: "Content-Type": "application/json; charset=utf-8"
   113  }` + baseTemplate,
   114  			expectedBody: `{"name":"foo","score":100}`,
   115  			statusCode:   200,
   116  		},
   117  		"timeout": {
   118  			request: baseTemplate + `
   119  method: "GET"
   120  url: "http://127.0.0.1:1229/timeout"
   121  request: {
   122  	timeout: "1s"
   123  }`,
   124  			expectedErr: "context deadline exceeded",
   125  		},
   126  		"not-timeout": {
   127  			request: baseTemplate + `
   128  method: "GET"
   129  url: "http://127.0.0.1:1229/timeout"
   130  request: {
   131  	timeout: "3s"
   132  }`,
   133  			expectedBody: `hello`,
   134  			statusCode:   200,
   135  		},
   136  		"invalid-timeout": {
   137  			request: baseTemplate + `
   138  method: "GET"
   139  url: "http://127.0.0.1:1229/timeout"
   140  request: {
   141  	timeout: "test"
   142  }`,
   143  			expectedErr: "invalid duration",
   144  		},
   145  		"notfound": {
   146  			request: baseTemplate + `
   147  method: "GET"
   148  url: "http://127.0.0.1:1229/notfound"
   149  `,
   150  			statusCode: 404,
   151  		},
   152  	}
   153  
   154  	for tName, tCase := range testCases {
   155  		r := require.New(t)
   156  		v, err := value.NewValue(tCase.request, nil, "")
   157  		r.NoError(err, tName)
   158  		prd := &provider{}
   159  		err = prd.Do(ctx, nil, v, nil)
   160  		if tCase.expectedErr != "" {
   161  			r.Error(err)
   162  			r.Contains(err.Error(), tCase.expectedErr)
   163  			continue
   164  		}
   165  		r.NoError(err, tName)
   166  		body, err := v.LookupValue("response", "body")
   167  		r.NoError(err, tName)
   168  		ret, err := body.CueValue().String()
   169  		r.NoError(err)
   170  		r.Equal(ret, tCase.expectedBody, tName)
   171  		code, err := v.LookupValue("response", "statusCode")
   172  		r.NoError(err, tName)
   173  		sc, err := code.CueValue().Int64()
   174  		r.NoError(err)
   175  		r.Equal(tCase.statusCode, int(sc), tName)
   176  	}
   177  
   178  	// test ratelimiter
   179  	rateLimiter = ratelimiter.NewRateLimiter(1)
   180  	limiterTestCases := []struct {
   181  		request     string
   182  		expectedErr string
   183  	}{
   184  		{
   185  			request: baseTemplate + `
   186  method: "GET"
   187  url: "http://127.0.0.1:1229/hello"
   188  request: {
   189  	ratelimiter: {
   190  		limit: 1
   191  		period: "1m"
   192  	}
   193  }`},
   194  		{
   195  			request: baseTemplate + `
   196  method: "GET"
   197  url: "http://127.0.0.1:1229/hello?query=1"
   198  request: {
   199  	ratelimiter: {
   200  		limit: 1
   201  		period: "1m"
   202  	}
   203  }`,
   204  			expectedErr: "request exceeds the rate limiter",
   205  		},
   206  		{
   207  			request: baseTemplate + `
   208  method: "GET"
   209  url: "http://127.0.0.1:1229/echo"
   210  request: {
   211  	ratelimiter: {
   212  		limit: 1
   213  		period: "1m"
   214  	}
   215  }`,
   216  		},
   217  		{
   218  			request: baseTemplate + `
   219  method: "GET"
   220  url: "http://127.0.0.1:1229/hello?query=2"
   221  request: {
   222  	ratelimiter: {
   223  		limit: 1
   224  		period: "1m"
   225  	}
   226  }`,
   227  		},
   228  	}
   229  
   230  	for tName, tCase := range limiterTestCases {
   231  		r := require.New(t)
   232  		v, err := value.NewValue(tCase.request, nil, "")
   233  		r.NoError(err, tName)
   234  		prd := &provider{}
   235  		err = prd.Do(ctx, nil, v, nil)
   236  		if tCase.expectedErr != "" {
   237  			r.Error(err)
   238  			r.Contains(err.Error(), tCase.expectedErr)
   239  			continue
   240  		}
   241  		r.NoError(err, tName)
   242  	}
   243  }
   244  
   245  func TestInstall(t *testing.T) {
   246  	r := require.New(t)
   247  	p := providers.NewProviders()
   248  	Install(p, nil, "")
   249  	h, ok := p.GetHandler("http", "do")
   250  	r.Equal(ok, true)
   251  	r.Equal(h != nil, true)
   252  }
   253  
   254  func runMockServer(shutdown chan struct{}) {
   255  	http.HandleFunc("/timeout", func(w http.ResponseWriter, req *http.Request) {
   256  		time.Sleep(time.Second * 2)
   257  		_, _ = w.Write([]byte("hello"))
   258  	})
   259  	http.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
   260  		_, _ = w.Write([]byte("hello"))
   261  	})
   262  	http.HandleFunc("/echo", func(w http.ResponseWriter, req *http.Request) {
   263  		bt, _ := io.ReadAll(req.Body)
   264  		_, _ = w.Write(bt)
   265  	})
   266  	http.HandleFunc("/notfound", func(w http.ResponseWriter, req *http.Request) {
   267  		w.WriteHeader(404)
   268  	})
   269  	srv := &http.Server{Addr: ":1229"}
   270  	go srv.ListenAndServe() //nolint
   271  	go func() {
   272  		<-shutdown
   273  		srv.Close()
   274  	}()
   275  
   276  	client := &http.Client{}
   277  	// wait server started.
   278  	for {
   279  		time.Sleep(time.Millisecond * 300)
   280  		req, _ := http.NewRequest("GET", "http://127.0.0.1:1229/hello", nil)
   281  		_, err := client.Do(req)
   282  		if err == nil {
   283  			break
   284  		}
   285  	}
   286  }
   287  
   288  func TestHTTPSDo(t *testing.T) {
   289  	ctx := monitorContext.NewTraceContext(context.Background(), "")
   290  	s := newMockHttpsServer()
   291  	defer s.Close()
   292  	cli := &test.MockClient{
   293  		MockGet: func(ctx context.Context, key client.ObjectKey, obj client.Object) error {
   294  			secret := obj.(*v1.Secret)
   295  			*secret = v1.Secret{
   296  				Data: map[string][]byte{
   297  					"ca.crt":     []byte(testdata.MockCerts.Ca),
   298  					"client.crt": []byte(testdata.MockCerts.ClientCrt),
   299  					"client.key": []byte(testdata.MockCerts.ClientKey),
   300  				},
   301  			}
   302  			return nil
   303  		},
   304  	}
   305  	r := require.New(t)
   306  	v, err := value.NewValue(`
   307  method: "GET"
   308  url: "https://127.0.0.1:8443/api/v1/token?val=test-token"
   309  `, nil, "")
   310  	r.NoError(err)
   311  	r.NoError(v.FillObject("certs", "tls_config", "secret"))
   312  	prd := &provider{cli, "default"}
   313  	err = prd.Do(ctx, nil, v, nil)
   314  	r.NoError(err)
   315  }
   316  
   317  func newMockHttpsServer() *httptest.Server {
   318  	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   319  		if r.Method != "GET" {
   320  			fmt.Printf("Expected 'GET' request, got '%s'", r.Method)
   321  		}
   322  		if r.URL.EscapedPath() != "/api/v1/token" {
   323  			fmt.Printf("Expected request to '/person', got '%s'", r.URL.EscapedPath())
   324  		}
   325  		_ = r.ParseForm()
   326  		token := r.Form.Get("val")
   327  		tokenBytes, _ := json.Marshal(map[string]interface{}{"token": token})
   328  
   329  		w.WriteHeader(http.StatusOK)
   330  		_, _ = w.Write(tokenBytes)
   331  	}))
   332  	l, _ := net.Listen("tcp", "127.0.0.1:8443")
   333  	ts.Listener.Close()
   334  	ts.Listener = l
   335  
   336  	decode := func(in string) []byte {
   337  		out, _ := base64.StdEncoding.DecodeString(in)
   338  		return out
   339  	}
   340  
   341  	pool := x509.NewCertPool()
   342  	pool.AppendCertsFromPEM(decode(testdata.MockCerts.Ca))
   343  	cert, _ := tls.X509KeyPair(decode(testdata.MockCerts.ServerCrt), decode(testdata.MockCerts.ServerKey))
   344  	ts.TLS = &tls.Config{
   345  		ClientCAs:    pool,
   346  		ClientAuth:   tls.RequireAndVerifyClientCert,
   347  		Certificates: []tls.Certificate{cert},
   348  		NextProtos:   []string{"http/1.1"},
   349  	}
   350  	ts.StartTLS()
   351  	return ts
   352  }