github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/transport_http/transport.go (about)

     1  package transport_http
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  
    10  	"github.com/google/uuid"
    11  	"github.com/gorilla/websocket"
    12  	"github.com/julienschmidt/httprouter"
    13  	"github.com/sirupsen/logrus"
    14  
    15  	"github.com/artisanhe/tools/courier"
    16  	"github.com/artisanhe/tools/courier/httpx"
    17  	"github.com/artisanhe/tools/duration"
    18  	logContext "github.com/artisanhe/tools/log/context"
    19  )
    20  
    21  func CreateHttpHandler(s *ServeHTTP, ops ...courier.IOperator) httprouter.Handle {
    22  	operatorMetas := courier.ToOperatorMetaList(ops...)
    23  
    24  	return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
    25  		var err error
    26  		ctx := r.Context()
    27  		ctx = ContextWithServiceName(ctx, s.Name)
    28  		ctx = ContextWithOperators(ctx, ops...)
    29  		ctx = ContextWithRequest(ctx, r)
    30  
    31  		codeWriter := &ResponseWriter{
    32  			ResponseWriter: w,
    33  			Code:           0,
    34  			written:        -1,
    35  		}
    36  
    37  		w.Header().Set("X-Reversion", ProjectRef)
    38  
    39  		reqID := r.Header.Get(httpx.HeaderRequestID)
    40  
    41  		if reqID == "" {
    42  			reqID = uuid.New().String()
    43  		}
    44  
    45  		logContext.SetLogID(reqID)
    46  		defer logContext.Close()
    47  
    48  		d := duration.NewDuration()
    49  
    50  		defer func() {
    51  
    52  			if r.URL.Path == "/healthz" {
    53  				return
    54  			}
    55  			fields := logrus.Fields{
    56  				"tag":        "access",
    57  				"log_id":     reqID,
    58  				"remote_ip":  GetClientIP(r),
    59  				"method":     r.Method,
    60  				"pathname":   r.URL.Path,
    61  				"user_agent": r.Header.Get(httpx.HeaderUserAgent),
    62  			}
    63  
    64  			fields["status"] = codeWriter.Code
    65  			fields["request_time"] = d.Get()
    66  
    67  			logger := logrus.WithFields(fields)
    68  
    69  			if err != nil {
    70  				if codeWriter.Code >= http.StatusInternalServerError {
    71  					logger.Errorf(err.Error())
    72  				} else {
    73  					logger.Warnf(err.Error())
    74  				}
    75  			} else {
    76  				logger.Infof("")
    77  			}
    78  		}()
    79  
    80  		opDecode := createHttpRequestDecoder(r, &params)
    81  
    82  		for _, opMeta := range operatorMetas {
    83  			op, decodeErr := courier.NewOperatorBy(opMeta.Type, opMeta.Operator, opDecode)
    84  			if decodeErr != nil {
    85  				err = encodeHttpError(ctx, codeWriter, r, decodeErr)
    86  				return
    87  			}
    88  
    89  			response, endpointErr := op.Output(ctx)
    90  			if canCookie, ok := op.(httpx.ICookie); ok {
    91  				cookie := canCookie.Cookies()
    92  				if cookie != nil {
    93  					http.SetCookie(w, cookie)
    94  				}
    95  			}
    96  
    97  			if endpointErr != nil {
    98  				err = encodeHttpError(ctx, codeWriter, r, endpointErr)
    99  				return
   100  			}
   101  
   102  			if !opMeta.IsLast {
   103  				// set result in context with key of operator name
   104  				ctx = context.WithValue(ctx, opMeta.ContextKey, response)
   105  				continue
   106  			}
   107  
   108  			if ws, ok := response.(IWebSocket); ok {
   109  				conn, errForUpgrade := (&websocket.Upgrader{}).Upgrade(codeWriter, r, nil)
   110  				if errForUpgrade != nil {
   111  					err = errForUpgrade
   112  					return
   113  				}
   114  				ws.SubscribeOn(conn)
   115  				return
   116  			}
   117  
   118  			encodeErr := encodeHttpResponse(ctx, codeWriter, r, response)
   119  			if encodeErr != nil {
   120  				err = encodeHttpError(ctx, codeWriter, r, encodeErr)
   121  				return
   122  			}
   123  		}
   124  	}
   125  }
   126  
   127  var ProjectRef = os.Getenv("PROJECT_REF")
   128  
   129  var (
   130  	ContextKeyServerName = uuid.New().String()
   131  	ContextKeyRequest    = uuid.New().String()
   132  	ContextKeyOperators  = uuid.New().String()
   133  )
   134  
   135  func ContextWithServiceName(ctx context.Context, serverName string) context.Context {
   136  	return context.WithValue(ctx, ContextKeyServerName, serverName)
   137  }
   138  
   139  func ContextWithOperators(ctx context.Context, ops ...courier.IOperator) context.Context {
   140  	return context.WithValue(ctx, ContextKeyOperators, ops)
   141  }
   142  
   143  func GetOperators(ctx context.Context) []courier.IOperator {
   144  	return ctx.Value(ContextKeyOperators).([]courier.IOperator)
   145  }
   146  
   147  func ContextWithRequest(ctx context.Context, req *http.Request) context.Context {
   148  	return context.WithValue(ctx, ContextKeyRequest, req)
   149  }
   150  
   151  func GetRequest(ctx context.Context) *http.Request {
   152  	return ctx.Value(ContextKeyRequest).(*http.Request)
   153  }
   154  
   155  type ResponseWriter struct {
   156  	http.ResponseWriter
   157  	Code    int
   158  	written int64
   159  }
   160  
   161  func (w *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   162  	return w.ResponseWriter.(http.Hijacker).Hijack()
   163  }
   164  
   165  func (w *ResponseWriter) WriteHeader(code int) {
   166  	w.Code = code
   167  	w.ResponseWriter.WriteHeader(code)
   168  }
   169  
   170  func (w *ResponseWriter) Write(p []byte) (int, error) {
   171  	n, err := w.ResponseWriter.Write(p)
   172  	w.written += int64(n)
   173  	return n, err
   174  }