github.com/kumasuke120/mockuma@v1.1.9/internal/server/executor.go (about)

     1  package server
     2  
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/kumasuke120/mockuma/internal/mckmaps"
	"github.com/kumasuke120/mockuma/internal/myhttp"
)
    21  
// policyExecutor carries out one matched mckmaps.Policy for a single request,
// writing the outcome to the response writer.
type policyExecutor struct {
	h      *mockHandler         // handler used to match a fresh executor for local forwards
	r      *http.Request        // incoming request being served
	w      *http.ResponseWriter // response writer, held by pointer and dereferenced on each use
	policy *mckmaps.Policy      // policy to run; its CmdType selects returns/redirects/forwards

	returnHead   bool // when set, Returns sends Content-Length but skips writing the body
	fromForwards bool // set when run on behalf of a local forward; executeReturns then skips its log line
	statusCode   int  // status code produced, read back by a forwarding executor for its log
}
    32  
    33  type forwardError struct {
    34  	err error
    35  }
    36  
    37  func (e *forwardError) Error() string {
    38  	return "fail to forward: " + e.err.Error()
    39  }
    40  
    41  func (e *policyExecutor) execute() error {
    42  	cmdType := e.policy.CmdType
    43  	switch cmdType {
    44  	case mckmaps.CmdTypeReturns:
    45  		fallthrough
    46  	case mckmaps.CmdTypeRedirects:
    47  		return e.executeReturns()
    48  	case mckmaps.CmdTypeForwards:
    49  		return e.executeForwards()
    50  	}
    51  
    52  	log.Printf("[executor] %-9s: unsupported command type\n", cmdType)
    53  	return errors.New("unsupported command type: " + string(cmdType))
    54  }
    55  
    56  func (e *policyExecutor) executeReturns() error {
    57  	returns := e.policy.Returns
    58  
    59  	if returns.Latency != nil {
    60  		waitBeforeReturns(returns.Latency)
    61  	}
    62  
    63  	err := e.writeResponseForReturns(returns)
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	e.statusCode = int(returns.StatusCode) // records statusCode for forwards
    69  	if !e.fromForwards {                   // forwards prints its log by itself
    70  		log.Printf("[executor] %-9s: (%d) %s %s\n", e.policy.CmdType,
    71  			e.statusCode, e.r.Method, e.r.URL)
    72  	}
    73  	return nil
    74  }
    75  
    76  func (e *policyExecutor) writeResponseForReturns(returns *mckmaps.Returns) error {
    77  	e.writeHeaders(returns.Headers)
    78  
    79  	if e.returnHead {
    80  		(*e.w).Header().Set(myhttp.HeaderContentLength, strconv.Itoa(len(returns.Body)))
    81  	}
    82  
    83  	// writes the statusCode, which must be written after headers
    84  	(*e.w).WriteHeader(int(returns.StatusCode))
    85  
    86  	if !e.returnHead {
    87  		err := e.writeBody(returns.Body)
    88  		if err != nil {
    89  			return err
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func (e *policyExecutor) writeHeaders(headers []*mckmaps.NameValuesPair) {
    97  	outHeader := (*e.w).Header()
    98  
    99  	// new headers overrides old ones
   100  	for _, pair := range headers {
   101  		if _, ok := outHeader[pair.Name]; ok {
   102  			outHeader.Del(pair.Name)
   103  		}
   104  
   105  		for _, value := range pair.Values {
   106  			outHeader.Add(pair.Name, value)
   107  		}
   108  	}
   109  }
   110  
   111  func (e *policyExecutor) writeBody(body []byte) error {
   112  	var err error
   113  	if body != nil && len(body) != 0 {
   114  		_, err = (*e.w).Write(body)
   115  	}
   116  	return err
   117  }
   118  
   119  func (e *policyExecutor) executeForwards() error {
   120  	forwards := e.policy.Forwards
   121  
   122  	if forwards.Latency != nil {
   123  		waitBeforeReturns(forwards.Latency)
   124  	}
   125  
   126  	fPath := forwards.Path
   127  	if strings.HasPrefix(fPath, "http://") || strings.HasPrefix(fPath, "https://") {
   128  		return e.forwardsRemote(fPath)
   129  	} else {
   130  		return e.forwardsLocal(fPath)
   131  	}
   132  }
   133  
   134  func (e *policyExecutor) forwardsRemote(fPath string) error {
   135  	newRequest, err := e.newForwardRemoteRequest(fPath)
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	httpClient := http.Client{}
   141  	resp, err := httpClient.Do(newRequest)
   142  	if err != nil {
   143  		return &forwardError{err: err}
   144  	}
   145  	defer func() {
   146  		if err := resp.Body.Close(); err != nil {
   147  			log.Println("[executor] error    : error encountered when forwarding: " + err.Error())
   148  		}
   149  	}()
   150  
   151  	err = e.writeResponseForForwardsRemote(resp)
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	e.statusCode = resp.StatusCode // records statusCode for forwards
   157  	log.Printf("[executor] %-9s: (%d) %s %s => %s\n", e.policy.CmdType,
   158  		resp.StatusCode, e.r.Method, e.r.URL, newRequest.URL)
   159  	return nil
   160  }
   161  
   162  func (e *policyExecutor) newForwardRemoteRequest(fPath string) (*http.Request, error) {
   163  	reqURL := e.r.URL
   164  
   165  	_url, err := url.Parse(fPath)
   166  	if err != nil {
   167  		return nil, &forwardError{err: err}
   168  	}
   169  	_url.RawQuery = reqURL.RawQuery
   170  
   171  	newRequest, err := e.newForwardRequest(_url.String())
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	newRequest.Header.Set(myhttp.HeaderXForwardedFor, e.r.RemoteAddr)
   176  
   177  	return newRequest, err
   178  }
   179  
   180  func (e *policyExecutor) writeResponseForForwardsRemote(resp *http.Response) error {
   181  	for key, values := range resp.Header {
   182  		(*e.w).Header()[key] = values
   183  	}
   184  	(*e.w).Header().Set(myhttp.HeaderXForwardedServer, HeaderValueServer)
   185  	(*e.w).WriteHeader(resp.StatusCode) // statusCode must be written after headers
   186  	_, err := io.Copy(*e.w, resp.Body)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	return nil
   191  }
   192  
   193  func (e *policyExecutor) forwardsLocal(fPath string) error {
   194  	newRequest, err := e.newForwardLocalRequest(fPath)
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	fe := e.h.matchNewExecutor(newRequest, *e.w)
   200  	fe.fromForwards = true
   201  	err = fe.execute() // executor writes response for forwards
   202  
   203  	if err == nil {
   204  		e.statusCode = fe.statusCode // records statusCode for forwards
   205  		log.Printf("[executor] %-9s: (%d) %s %s => %s\n", e.policy.CmdType,
   206  			fe.statusCode, e.r.Method, e.r.URL, newRequest.URL)
   207  	}
   208  	return err
   209  }
   210  
   211  func (e *policyExecutor) newForwardLocalRequest(fPath string) (*http.Request, error) {
   212  	reqURL := e.r.URL
   213  
   214  	if !strings.HasPrefix(fPath, "/") {
   215  		uri := reqURL.Path
   216  		fPath = path.Join(uri, "../"+fPath)
   217  	}
   218  
   219  	rawQuery := reqURL.RawQuery
   220  	var requestURI string
   221  	if len(rawQuery) == 0 {
   222  		requestURI = fPath
   223  	} else {
   224  		requestURI = fmt.Sprintf("%s?%s", fPath, rawQuery)
   225  	}
   226  	newRequest, err := e.newForwardRequest(requestURI)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  	return newRequest, err
   231  }
   232  
   233  func (e *policyExecutor) newForwardRequest(url string) (*http.Request, error) {
   234  	method := e.r.Method
   235  	body, err := ioutil.ReadAll(e.r.Body)
   236  	if err != nil {
   237  		return nil, &forwardError{err: err}
   238  	}
   239  	newRequest, err := http.NewRequest(method, url, bytes.NewReader(body))
   240  	if err != nil {
   241  		return nil, &forwardError{err: err}
   242  	}
   243  	newRequest.Header = e.r.Header
   244  	return newRequest, nil
   245  }
   246  
   247  func waitBeforeReturns(latency *mckmaps.Interval) {
   248  	diff := latency.Max - latency.Min
   249  	if diff > 0 {
   250  		d := rand.Int63n(diff) + latency.Min
   251  		time.Sleep(time.Duration(d * int64(time.Millisecond)))
   252  	} else if latency.Min > 0 {
   253  		time.Sleep(time.Duration(latency.Min * int64(time.Millisecond)))
   254  	}
   255  }