github.com/kubevela/workflow@v0.6.0/pkg/providers/http/do.go (about)

     1  /*
     2  Copyright 2022 The KubeVela Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package http
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"strings"
    28  	"time"
    29  
    30  	"cuelang.org/go/cue"
    31  	"github.com/pkg/errors"
    32  	v1 "k8s.io/api/core/v1"
    33  	"sigs.k8s.io/controller-runtime/pkg/client"
    34  
    35  	monitorContext "github.com/kubevela/pkg/monitor/context"
    36  
    37  	wfContext "github.com/kubevela/workflow/pkg/context"
    38  	"github.com/kubevela/workflow/pkg/cue/model/value"
    39  	"github.com/kubevela/workflow/pkg/providers/http/ratelimiter"
    40  	"github.com/kubevela/workflow/pkg/types"
    41  )
    42  
    43  const (
    44  	// ProviderName is provider name for install.
    45  	ProviderName = "http"
    46  )
    47  
    48  var (
    49  	rateLimiter *ratelimiter.RateLimiter
    50  )
    51  
    52  func init() {
    53  	rateLimiter = ratelimiter.NewRateLimiter(128)
    54  }
    55  
    56  type provider struct {
    57  	cli client.Client
    58  	ns  string
    59  }
    60  
    61  // Do process http request.
    62  func (h *provider) Do(ctx monitorContext.Context, wfCtx wfContext.Context, v *value.Value, act types.Action) error {
    63  	resp, err := h.runHTTP(ctx, v)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	return v.FillObject(resp, "response")
    68  }
    69  
    70  func (h *provider) runHTTP(ctx monitorContext.Context, v *value.Value) (interface{}, error) {
    71  	var (
    72  		err             error
    73  		method, u       string
    74  		header, trailer http.Header
    75  		r               io.Reader
    76  	)
    77  	defaultClient := &http.Client{
    78  		Transport: http.DefaultTransport,
    79  		Timeout:   time.Second * 3,
    80  	}
    81  	if timeout, err := v.GetString("request", "timeout"); err == nil && timeout != "" {
    82  		duration, err := time.ParseDuration(timeout)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  		defaultClient.Timeout = duration
    87  	}
    88  	if method, err = v.GetString("method"); err != nil {
    89  		return nil, err
    90  	}
    91  	if u, err = v.GetString("url"); err != nil {
    92  		return nil, err
    93  	}
    94  	if rl, err := v.LookupValue("request", "ratelimiter"); err == nil {
    95  		limit, err := rl.GetInt64("limit")
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  		period, err := rl.GetString("period")
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		duration, err := time.ParseDuration(period)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		if !rateLimiter.Allow(fmt.Sprintf("%s-%s", method, strings.Split(u, "?")[0]), int(limit), duration) {
   108  			return nil, errors.New("request exceeds the rate limiter")
   109  		}
   110  	}
   111  	if body, err := v.LookupValue("request", "body"); err == nil {
   112  		r, err = body.CueValue().Reader()
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  	}
   117  	if header, err = parseHeaders(v.CueValue(), "header"); err != nil {
   118  		return nil, err
   119  	}
   120  	if trailer, err = parseHeaders(v.CueValue(), "trailer"); err != nil {
   121  		return nil, err
   122  	}
   123  	if header == nil {
   124  		header = map[string][]string{}
   125  		header.Set("Content-Type", "application/json")
   126  	}
   127  
   128  	req, err := http.NewRequestWithContext(context.Background(), method, u, r)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	req.Header = header
   133  	req.Trailer = trailer
   134  
   135  	if tr, err := h.getTransport(ctx, v); err == nil && tr != nil {
   136  		defaultClient.Transport = tr
   137  	}
   138  
   139  	resp, err := defaultClient.Do(req)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	//nolint:errcheck
   144  	defer resp.Body.Close()
   145  	b, err := io.ReadAll(resp.Body)
   146  	// parse response body and headers
   147  	return map[string]interface{}{
   148  		"body":       string(b),
   149  		"header":     resp.Header,
   150  		"trailer":    resp.Trailer,
   151  		"statusCode": resp.StatusCode,
   152  	}, err
   153  }
   154  
   155  func (h *provider) getTransport(ctx monitorContext.Context, v *value.Value) (http.RoundTripper, error) {
   156  	tlsConfig, err := v.LookupValue("tls_config")
   157  	if err != nil {
   158  		return nil, nil
   159  	}
   160  	tr := &http.Transport{
   161  		TLSClientConfig: &tls.Config{
   162  			NextProtos: []string{"http/1.1"},
   163  		},
   164  	}
   165  
   166  	secretName, err := tlsConfig.GetString("secret")
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	objectKey := client.ObjectKey{
   171  		Namespace: h.ns,
   172  		Name:      secretName,
   173  	}
   174  	index := strings.Index(secretName, "/")
   175  	if index > 0 {
   176  		objectKey.Namespace = secretName[:index-1]
   177  		objectKey.Name = secretName[index:]
   178  	}
   179  	secret := new(v1.Secret)
   180  	if err := h.cli.Get(ctx, objectKey, secret); err != nil {
   181  		return nil, err
   182  	}
   183  	if ca, ok := secret.Data["ca.crt"]; ok {
   184  		caData, err := base64.StdEncoding.DecodeString(string(ca))
   185  		if err != nil {
   186  			return nil, err
   187  		}
   188  		pool := x509.NewCertPool()
   189  		pool.AppendCertsFromPEM(caData)
   190  		tr.TLSClientConfig.RootCAs = pool
   191  	}
   192  	var certData, keyData []byte
   193  	if clientCert, ok := secret.Data["client.crt"]; ok {
   194  		certData, err = base64.StdEncoding.DecodeString(string(clientCert))
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  	}
   199  	if clientKey, ok := secret.Data["client.key"]; ok {
   200  		keyData, err = base64.StdEncoding.DecodeString(string(clientKey))
   201  		if err != nil {
   202  			return nil, err
   203  		}
   204  	}
   205  	cliCrt, err := tls.X509KeyPair(certData, keyData)
   206  	if err != nil {
   207  		return nil, errors.WithMessage(err, "parse client keypair")
   208  	}
   209  	tr.TLSClientConfig.Certificates = []tls.Certificate{cliCrt}
   210  	return tr, nil
   211  }
   212  
   213  func parseHeaders(obj cue.Value, label string) (http.Header, error) {
   214  	m := obj.LookupPath(value.FieldPath("request", label))
   215  	if !m.Exists() {
   216  		return nil, nil
   217  	}
   218  	iter, err := m.Fields()
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	h := http.Header{}
   223  	for iter.Next() {
   224  		str, err := iter.Value().String()
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  		h.Add(iter.Label(), str)
   229  	}
   230  	return h, nil
   231  }
   232  
   233  // Install register handlers to provider discover.
   234  func Install(p types.Providers, cli client.Client, ns string) {
   235  	prd := &provider{
   236  		cli: cli,
   237  		ns:  ns,
   238  	}
   239  	p.Register(ProviderName, map[string]types.Handler{
   240  		"do": prd.Do,
   241  	})
   242  }