github.com/kumasuke120/mockuma@v1.1.9/internal/server/executor.go (about) 1 package server 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "log" 10 "math/rand" 11 "net/http" 12 "net/url" 13 "path" 14 "strconv" 15 "strings" 16 "time" 17 18 "github.com/kumasuke120/mockuma/internal/mckmaps" 19 "github.com/kumasuke120/mockuma/internal/myhttp" 20 ) 21 22 type policyExecutor struct { 23 h *mockHandler 24 r *http.Request 25 w *http.ResponseWriter 26 policy *mckmaps.Policy 27 28 returnHead bool 29 fromForwards bool 30 statusCode int 31 } 32 33 type forwardError struct { 34 err error 35 } 36 37 func (e *forwardError) Error() string { 38 return "fail to forward: " + e.err.Error() 39 } 40 41 func (e *policyExecutor) execute() error { 42 cmdType := e.policy.CmdType 43 switch cmdType { 44 case mckmaps.CmdTypeReturns: 45 fallthrough 46 case mckmaps.CmdTypeRedirects: 47 return e.executeReturns() 48 case mckmaps.CmdTypeForwards: 49 return e.executeForwards() 50 } 51 52 log.Printf("[executor] %-9s: unsupported command type\n", cmdType) 53 return errors.New("unsupported command type: " + string(cmdType)) 54 } 55 56 func (e *policyExecutor) executeReturns() error { 57 returns := e.policy.Returns 58 59 if returns.Latency != nil { 60 waitBeforeReturns(returns.Latency) 61 } 62 63 err := e.writeResponseForReturns(returns) 64 if err != nil { 65 return err 66 } 67 68 e.statusCode = int(returns.StatusCode) // records statusCode for forwards 69 if !e.fromForwards { // forwards prints its log by itself 70 log.Printf("[executor] %-9s: (%d) %s %s\n", e.policy.CmdType, 71 e.statusCode, e.r.Method, e.r.URL) 72 } 73 return nil 74 } 75 76 func (e *policyExecutor) writeResponseForReturns(returns *mckmaps.Returns) error { 77 e.writeHeaders(returns.Headers) 78 79 if e.returnHead { 80 (*e.w).Header().Set(myhttp.HeaderContentLength, strconv.Itoa(len(returns.Body))) 81 } 82 83 // writes the statusCode, which must be written after headers 84 (*e.w).WriteHeader(int(returns.StatusCode)) 85 86 if !e.returnHead { 87 err := e.writeBody(returns.Body) 88 if err != nil { 89 return err 90 } 91 } 92 93 return nil 94 } 95 96 func (e *policyExecutor) writeHeaders(headers []*mckmaps.NameValuesPair) { 97 outHeader := (*e.w).Header() 98 99 // new headers overrides old ones 100 for _, pair := range headers { 101 if _, ok := outHeader[pair.Name]; ok { 102 outHeader.Del(pair.Name) 103 } 104 105 for _, value := range pair.Values { 106 outHeader.Add(pair.Name, value) 107 } 108 } 109 } 110 111 func (e *policyExecutor) writeBody(body []byte) error { 112 var err error 113 if body != nil && len(body) != 0 { 114 _, err = (*e.w).Write(body) 115 } 116 return err 117 } 118 119 func (e *policyExecutor) executeForwards() error { 120 forwards := e.policy.Forwards 121 122 if forwards.Latency != nil { 123 waitBeforeReturns(forwards.Latency) 124 } 125 126 fPath := forwards.Path 127 if strings.HasPrefix(fPath, "http://") || strings.HasPrefix(fPath, "https://") { 128 return e.forwardsRemote(fPath) 129 } else { 130 return e.forwardsLocal(fPath) 131 } 132 } 133 134 func (e *policyExecutor) forwardsRemote(fPath string) error { 135 newRequest, err := e.newForwardRemoteRequest(fPath) 136 if err != nil { 137 return err 138 } 139 140 httpClient := http.Client{} 141 resp, err := httpClient.Do(newRequest) 142 if err != nil { 143 return &forwardError{err: err} 144 } 145 defer func() { 146 if err := resp.Body.Close(); err != nil { 147 log.Println("[executor] error : error encountered when forwarding: " + err.Error()) 148 } 149 }() 150 151 err = e.writeResponseForForwardsRemote(resp) 152 if err != nil { 153 return err 154 } 155 156 e.statusCode = resp.StatusCode // records statusCode for forwards 157 log.Printf("[executor] %-9s: (%d) %s %s => %s\n", e.policy.CmdType, 158 resp.StatusCode, e.r.Method, e.r.URL, newRequest.URL) 159 return nil 160 } 161 162 func (e *policyExecutor) newForwardRemoteRequest(fPath string) (*http.Request, error) { 163 reqURL := e.r.URL 164 165 _url, err := url.Parse(fPath) 166 if err != nil { 167 return nil, &forwardError{err: err} 168 } 169 _url.RawQuery = reqURL.RawQuery 170 171 newRequest, err := e.newForwardRequest(_url.String()) 172 if err != nil { 173 return nil, err 174 } 175 newRequest.Header.Set(myhttp.HeaderXForwardedFor, e.r.RemoteAddr) 176 177 return newRequest, err 178 } 179 180 func (e *policyExecutor) writeResponseForForwardsRemote(resp *http.Response) error { 181 for key, values := range resp.Header { 182 (*e.w).Header()[key] = values 183 } 184 (*e.w).Header().Set(myhttp.HeaderXForwardedServer, HeaderValueServer) 185 (*e.w).WriteHeader(resp.StatusCode) // statusCode must be written after headers 186 _, err := io.Copy(*e.w, resp.Body) 187 if err != nil { 188 return err 189 } 190 return nil 191 } 192 193 func (e *policyExecutor) forwardsLocal(fPath string) error { 194 newRequest, err := e.newForwardLocalRequest(fPath) 195 if err != nil { 196 return err 197 } 198 199 fe := e.h.matchNewExecutor(newRequest, *e.w) 200 fe.fromForwards = true 201 err = fe.execute() // executor writes response for forwards 202 203 if err == nil { 204 e.statusCode = fe.statusCode // records statusCode for forwards 205 log.Printf("[executor] %-9s: (%d) %s %s => %s\n", e.policy.CmdType, 206 fe.statusCode, e.r.Method, e.r.URL, newRequest.URL) 207 } 208 return err 209 } 210 211 func (e *policyExecutor) newForwardLocalRequest(fPath string) (*http.Request, error) { 212 reqURL := e.r.URL 213 214 if !strings.HasPrefix(fPath, "/") { 215 uri := reqURL.Path 216 fPath = path.Join(uri, "../"+fPath) 217 } 218 219 rawQuery := reqURL.RawQuery 220 var requestURI string 221 if len(rawQuery) == 0 { 222 requestURI = fPath 223 } else { 224 requestURI = fmt.Sprintf("%s?%s", fPath, rawQuery) 225 } 226 newRequest, err := e.newForwardRequest(requestURI) 227 if err != nil { 228 return nil, err 229 } 230 return newRequest, err 231 } 232 233 func (e *policyExecutor) newForwardRequest(url string) (*http.Request, error) { 234 method := e.r.Method 235 body, err := ioutil.ReadAll(e.r.Body) 236 if err != nil { 237 return nil, &forwardError{err: err} 238 } 239 newRequest, err := http.NewRequest(method, url, bytes.NewReader(body)) 240 if err != nil { 241 return nil, &forwardError{err: err} 242 } 243 newRequest.Header = e.r.Header 244 return newRequest, nil 245 } 246 247 func waitBeforeReturns(latency *mckmaps.Interval) { 248 diff := latency.Max - latency.Min 249 if diff > 0 { 250 d := rand.Int63n(diff) + latency.Min 251 time.Sleep(time.Duration(d * int64(time.Millisecond))) 252 } else if latency.Min > 0 { 253 time.Sleep(time.Duration(latency.Min * int64(time.Millisecond))) 254 } 255 }