github.com/dbernstein1/tyk@v2.9.0-beta9-dl-apic+incompatible/gateway/batch_requests.go (about)

     1  package gateway
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/TykTechnologies/tyk/config"
    13  )
    14  
    15  // RequestDefinition defines a batch request
    16  type RequestDefinition struct {
    17  	Method      string            `json:"method"`
    18  	Headers     map[string]string `json:"headers"`
    19  	Body        string            `json:"body"`
    20  	RelativeURL string            `json:"relative_url"`
    21  }
    22  
    23  // BatchRequestStructure defines a batch request order
    24  type BatchRequestStructure struct {
    25  	Requests                  []RequestDefinition `json:"requests"`
    26  	SuppressParallelExecution bool                `json:"suppress_parallel_execution"`
    27  }
    28  
    29  // BatchReplyUnit encodes a request suitable for replying to a batch request
    30  type BatchReplyUnit struct {
    31  	RelativeURL string      `json:"relative_url"`
    32  	Code        int         `json:"code"`
    33  	Headers     http.Header `json:"headers"`
    34  	Body        string      `json:"body"`
    35  }
    36  
// BatchRequestHandler handles batch requests on /tyk/batch for any API Definition that has the feature enabled.
//
// API is the definition the batch belongs to. Its Proxy.ListenPath is used to
// build safe sub-request URLs. It is also passed to getUpstreamCertificate,
// dialTLSPinnedCheck and proxyFromAPI when configuring the outbound client.
type BatchRequestHandler struct {
	API *APISpec
}
    41  
    42  // doRequest will make the same request but return a BatchReplyUnit
    43  func (b *BatchRequestHandler) doRequest(req *http.Request, relURL string) BatchReplyUnit {
    44  	tr := &http.Transport{TLSClientConfig: &tls.Config{}}
    45  
    46  	if cert := getUpstreamCertificate(req.Host, b.API); cert != nil {
    47  		tr.TLSClientConfig.Certificates = []tls.Certificate{*cert}
    48  	}
    49  
    50  	tr.TLSClientConfig.InsecureSkipVerify = config.Global().ProxySSLInsecureSkipVerify
    51  
    52  	tr.DialTLS = dialTLSPinnedCheck(b.API, tr.TLSClientConfig)
    53  
    54  	tr.Proxy = proxyFromAPI(b.API)
    55  
    56  	client := &http.Client{Transport: tr}
    57  
    58  	resp, err := client.Do(req)
    59  	if err != nil {
    60  		log.Error("Webhook request failed: ", err)
    61  		return BatchReplyUnit{}
    62  	}
    63  
    64  	defer resp.Body.Close()
    65  	content, err := ioutil.ReadAll(resp.Body)
    66  	if err != nil {
    67  		log.Warning("Body read failure! ", err)
    68  		return BatchReplyUnit{}
    69  	}
    70  
    71  	return BatchReplyUnit{
    72  		RelativeURL: relURL,
    73  		Code:        resp.StatusCode,
    74  		Headers:     resp.Header,
    75  		Body:        string(content),
    76  	}
    77  }
    78  
    79  func (b *BatchRequestHandler) DecodeBatchRequest(r *http.Request) (BatchRequestStructure, error) {
    80  	var batchRequest BatchRequestStructure
    81  	err := json.NewDecoder(r.Body).Decode(&batchRequest)
    82  	return batchRequest, err
    83  }
    84  
    85  func (b *BatchRequestHandler) ConstructRequests(batchRequest BatchRequestStructure, unsafe bool) ([]*http.Request, error) {
    86  	requestSet := []*http.Request{}
    87  
    88  	for i, requestDef := range batchRequest.Requests {
    89  		// We re-build the URL to ensure that the requested URL is actually for the API in question
    90  		// URLs need to be built absolute so they go through the rate limiting and request limiting machinery
    91  		var absURL string
    92  		if !unsafe {
    93  			absUrlHeader := "http://localhost:" + strconv.Itoa(config.Global().ListenPort)
    94  			absURL = strings.Join([]string{absUrlHeader, strings.Trim(b.API.Proxy.ListenPath, "/"), requestDef.RelativeURL}, "/")
    95  		} else {
    96  			absURL = requestDef.RelativeURL
    97  		}
    98  
    99  		request, err := http.NewRequest(requestDef.Method, absURL, strings.NewReader(requestDef.Body))
   100  		if err != nil {
   101  			log.Error("Failure generating batch request for request spec index: ", i)
   102  			return nil, err
   103  		}
   104  
   105  		// Add headers
   106  		for k, v := range requestDef.Headers {
   107  			request.Header.Set(k, v)
   108  		}
   109  
   110  		requestSet = append(requestSet, request)
   111  	}
   112  
   113  	return requestSet, nil
   114  }
   115  
   116  func (b *BatchRequestHandler) MakeRequests(batchRequest BatchRequestStructure, requestSet []*http.Request) []BatchReplyUnit {
   117  	replySet := []BatchReplyUnit{}
   118  
   119  	if len(batchRequest.Requests) != len(requestSet) {
   120  		log.Error("Something went wrong creating requests, they are of mismatched lengths!", len(batchRequest.Requests), len(requestSet))
   121  	}
   122  
   123  	if !batchRequest.SuppressParallelExecution {
   124  		replies := make(chan BatchReplyUnit)
   125  		for i, req := range requestSet {
   126  			go func(i int, req *http.Request) {
   127  				reply := b.doRequest(req, batchRequest.Requests[i].RelativeURL)
   128  				replies <- reply
   129  			}(i, req)
   130  		}
   131  
   132  		for range batchRequest.Requests {
   133  			replySet = append(replySet, <-replies)
   134  		}
   135  	} else {
   136  		for i, req := range requestSet {
   137  			reply := b.doRequest(req, batchRequest.Requests[i].RelativeURL)
   138  			replySet = append(replySet, reply)
   139  		}
   140  	}
   141  
   142  	return replySet
   143  }
   144  
   145  // HandleBatchRequest is the actual http handler for a batch request on an API definition
   146  func (b *BatchRequestHandler) HandleBatchRequest(w http.ResponseWriter, r *http.Request) {
   147  	if r.Method != "POST" {
   148  		return
   149  	}
   150  
   151  	// Decode request
   152  	batchRequest, err := b.DecodeBatchRequest(r)
   153  	if err != nil {
   154  		log.Error("Could not decode batch request, decoding failed: ", err)
   155  		doJSONWrite(w, http.StatusBadRequest, apiError("Batch request malformed"))
   156  		return
   157  	}
   158  
   159  	// Construct the requests
   160  	requestSet, err := b.ConstructRequests(batchRequest, false)
   161  	if err != nil {
   162  		doJSONWrite(w, http.StatusBadRequest, apiError(fmt.Sprintf("Batch request creation failed , request structure malformed")))
   163  		return
   164  	}
   165  
   166  	// Run requests and collate responses
   167  	replySet := b.MakeRequests(batchRequest, requestSet)
   168  
   169  	// Respond
   170  	doJSONWrite(w, http.StatusOK, replySet)
   171  }
   172  
   173  // HandleBatchRequest is the actual http handler for a batch request on an API definition
   174  func (b *BatchRequestHandler) ManualBatchRequest(requestObject []byte) ([]byte, error) {
   175  	// Decode request
   176  	var batchRequest BatchRequestStructure
   177  	if err := json.Unmarshal(requestObject, &batchRequest); err != nil {
   178  		return nil, fmt.Errorf("Could not decode batch request, decoding failed: %v", err)
   179  	}
   180  
   181  	// Construct the unsafe requests
   182  	requestSet, err := b.ConstructRequests(batchRequest, true)
   183  	if err != nil {
   184  		return nil, fmt.Errorf("Batch request creation failed , request structure malformed: %v", err)
   185  	}
   186  
   187  	// Run requests and collate responses
   188  	replySet := b.MakeRequests(batchRequest, requestSet)
   189  
   190  	// Encode responses
   191  	replyMessage, err := json.Marshal(&replySet)
   192  	if err != nil {
   193  		return nil, fmt.Errorf("Couldn't encode response to string: %v", err)
   194  	}
   195  
   196  	return replyMessage, nil
   197  }