github.com/erda-project/erda-infra@v1.0.9/providers/legacy/httpendpoints/endpoints.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httpendpoints
    16  
    17  import (
    18  	"context"
    19  	"encoding/base64"
    20  	"encoding/json"
    21  	"io"
    22  	"io/ioutil"
    23  	"net/http"
    24  	"net/http/httputil"
    25  	"strings"
    26  	"time"
    27  
    28  	"github.com/erda-project/erda-infra/providers/legacy/httpendpoints/i18n"
    29  	"github.com/erda-project/erda-infra/providers/legacy/httpendpoints/ierror"
    30  	"github.com/gofrs/uuid"
    31  	"github.com/gorilla/mux"
    32  )
    33  
    34  const (
    35  	// ContentTypeJSON Content Type
    36  	ContentTypeJSON = "application/json"
    37  	// ResponseWriter context value key
    38  	ResponseWriter = ContextKey("responseWriter")
    39  	// Base64EncodedRequestBody .
    40  	Base64EncodedRequestBody = "base64-encoded-request-body"
    41  	// TraceID .
    42  	TraceID = ContextKey("dice-trace-id")
    43  )
    44  
    45  // ContextKey ...
    46  type ContextKey string
    47  
    48  // RegisterEndpoints match URL path to corresponding handler
    49  func (p *provider) RegisterEndpoints(endpoints []Endpoint) {
    50  	for _, ep := range endpoints {
    51  		if ep.WriterHandler != nil {
    52  			p.router.Path(ep.Path).Methods(ep.Method).HandlerFunc(p.internalWriterHandler(ep.WriterHandler))
    53  		} else if ep.ReverseHandler != nil {
    54  			p.router.Path(ep.Path).Methods(ep.Method).Handler(p.internalReverseHandler(ep.ReverseHandler))
    55  		} else {
    56  			p.router.Path(ep.Path).Methods(ep.Method).HandlerFunc(p.internal(ep.Handler))
    57  		}
    58  		p.L.Infof("Added endpoint: %s %s", ep.Method, ep.Path)
    59  	}
    60  }
    61  
    62  func (p *provider) internal(handler func(context.Context, *http.Request, map[string]string) (Responser, error)) http.HandlerFunc {
    63  	pctx := context.Background()
    64  	pctx = injectTraceID(pctx)
    65  
    66  	return func(w http.ResponseWriter, r *http.Request) {
    67  		start := time.Now()
    68  		p.L.Debugf("start %s %s", r.Method, r.URL.String())
    69  
    70  		ctx, cancel := context.WithCancel(pctx)
    71  		defer func() {
    72  			cancel()
    73  			p.L.Debugf("finished handle request %s %s (took %v)", r.Method, r.URL.String(), time.Since(start))
    74  		}()
    75  		ctx = context.WithValue(ctx, ResponseWriter, w)
    76  
    77  		handleRequest(r)
    78  
    79  		langs := i18n.Language(r)
    80  		locale := i18n.WrapLocaleResource(p.t, langs)
    81  		response, err := handler(ctx, r, mux.Vars(r))
    82  		if err == nil {
    83  			response = response.GetLocaledResp(locale)
    84  		}
    85  		if err != nil {
    86  			apiError, isAPIError := err.(ierror.IAPIError)
    87  			if isAPIError {
    88  				response = HTTPResponse{
    89  					Status: apiError.HTTPCode(),
    90  					Content: Resp{
    91  						Success: false,
    92  						Err: ErrorResponse{
    93  							Code: apiError.Code(),
    94  							Msg:  apiError.Render(locale),
    95  						},
    96  					},
    97  				}
    98  			} else {
    99  				p.L.Errorf("failed to handle request: %s (%v)", r.URL.String(), err)
   100  
   101  				statusCode := http.StatusInternalServerError
   102  				if response != nil {
   103  					statusCode = response.GetStatus()
   104  				}
   105  				w.WriteHeader(statusCode)
   106  				io.WriteString(w, err.Error())
   107  				return
   108  			}
   109  		}
   110  
   111  		w.Header().Set("Content-Type", ContentTypeJSON)
   112  		w.WriteHeader(response.GetStatus())
   113  
   114  		encoder := json.NewEncoder(w)
   115  		vals := r.URL.Query()
   116  		pretty, ok := vals["pretty"]
   117  		if ok && strings.Compare(pretty[0], "true") == 0 {
   118  			encoder.SetIndent("", "    ")
   119  		}
   120  
   121  		if err := encoder.Encode(response.GetContent()); err != nil {
   122  			p.L.Errorf("failed to send response: %s (%v)", r.URL.String(), err)
   123  			return
   124  		}
   125  	}
   126  }
   127  
   128  func (p *provider) internalWriterHandler(handler func(context.Context, http.ResponseWriter, *http.Request, map[string]string) error) http.HandlerFunc {
   129  	pctx := context.Background()
   130  	pctx = injectTraceID(pctx)
   131  
   132  	return func(w http.ResponseWriter, r *http.Request) {
   133  		start := time.Now()
   134  		p.L.Debugf("start %s %s", r.Method, r.URL.String())
   135  
   136  		ctx, cancel := context.WithCancel(pctx)
   137  		defer func() {
   138  			cancel()
   139  			p.L.Debugf("finished handle request %s %s (took %v)", r.Method, r.URL.String(), time.Since(start))
   140  		}()
   141  
   142  		handleRequest(r)
   143  
   144  		err := handler(ctx, w, r, mux.Vars(r))
   145  		if err != nil {
   146  			p.L.Errorf("failed to handle request: %s (%v)", r.URL.String(), err)
   147  
   148  			statusCode := http.StatusInternalServerError
   149  			w.WriteHeader(statusCode)
   150  			io.WriteString(w, err.Error())
   151  		}
   152  	}
   153  }
   154  
   155  // internalReverseHandler .
   156  func (p *provider) internalReverseHandler(handler func(context.Context, *http.Request, map[string]string) error) http.Handler {
   157  	pctx := context.Background()
   158  	pctx = injectTraceID(pctx)
   159  
   160  	return &httputil.ReverseProxy{
   161  		Director: func(r *http.Request) {
   162  			start := time.Now()
   163  			p.L.Debugf("start %s %s", r.Method, r.URL.String())
   164  
   165  			ctx, cancel := context.WithCancel(pctx)
   166  			defer func() {
   167  				cancel()
   168  				p.L.Debugf("finished handle request %s %s (took %v)", r.Method, r.URL.String(), time.Since(start))
   169  			}()
   170  
   171  			handleRequest(r)
   172  
   173  			err := handler(ctx, r, mux.Vars(r))
   174  			if err != nil {
   175  				p.L.Errorf("failed to handle request: %s (%v)", r.URL.String(), err)
   176  				return
   177  			}
   178  		},
   179  		FlushInterval: -1,
   180  	}
   181  }
   182  
   183  func handleRequest(r *http.Request) {
   184  	// base64 decode request body if declared in header
   185  	if strings.EqualFold(r.Header.Get(Base64EncodedRequestBody), "true") {
   186  		r.Body = ioutil.NopCloser(base64.NewDecoder(base64.StdEncoding, r.Body))
   187  	}
   188  }
   189  
   190  func injectTraceID(ctx context.Context) context.Context {
   191  	id, _ := uuid.NewV4()
   192  	return context.WithValue(ctx, TraceID, id.String())
   193  }