github.com/keysonzzz/kmg@v0.0.0-20151121023212-05317bfd7d39/kmgNet/kmgHttp/context.go (about)

     1  package kmgHttp
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"mime"
    10  	"mime/multipart"
    11  	"net/http"
    12  	"strconv"
    13  	"unicode/utf8"
    14  
    15  	"github.com/bronze1man/kmg/encoding/kmgBase64"
    16  	"github.com/bronze1man/kmg/encoding/kmgJson"
    17  	"github.com/bronze1man/kmg/kmgCrypto"
    18  	"github.com/bronze1man/kmg/kmgErr"
    19  	"net"
    20  )
    21  
// Context holds the parsed input of one HTTP request together with the
// response that is being built for it.
// Its methods must not be called concurrently.
type Context struct {
	method         string                           // request method, e.g. "GET" or "POST"
	requestUrl     string                           // request URL as a string
	inMap          map[string]string                // input parameters (query + form); only the first value per key is kept
	DataMap        map[string]string                // arbitrary string data that can be carried along with the context
	requestFile    map[string]*multipart.FileHeader // uploaded files, first file per multipart form field
	responseBuffer bytes.Buffer                     // buffered response body, flushed by WriteToResponseWriter
	redirectUrl    string                           // redirect target; non-empty means a redirect is sent instead of the body
	responseCode   int                              // HTTP status code to send
	req            *http.Request                    // the underlying request
	responseHeader map[string]string                // response headers to set; created on first SetResponseHeader
	sessionMap     map[string]string                // decoded session values; nil until sessionInit runs
	sessionHasSet  bool                             // true when the session cookie must be rewritten
}
    37  
    38  const (
    39  	defaultMaxMemory = 32 << 20 // 32 MB
    40  )
    41  
    42  var SessionCookieName = "kmgSession"
    43  var SessionPsk = [32]byte{0xd8, 0x51, 0xea, 0x81, 0xb9, 0xe, 0xf, 0x2f, 0x8c, 0x85, 0x5f, 0xb6, 0x14, 0xb2}
    44  
    45  func NewContextFromHttp(w http.ResponseWriter, req *http.Request) *Context {
    46  	context := &Context{
    47  		method:      req.Method,
    48  		inMap:       map[string]string{},
    49  		requestFile: map[string]*multipart.FileHeader{},
    50  		requestUrl:  req.URL.String(),
    51  		//Session:      kmgSession.GetSession(w, req),
    52  		responseCode: 200,
    53  		req:          req,
    54  	}
    55  	//绕开支付宝请求bug
    56  	if req.Header.Get("Content-Type") == "application/x-www-form-urlencoded; text/html; charset=utf-8" {
    57  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
    58  	}
    59  	err := req.ParseForm()
    60  	if err != nil {
    61  		panic(err)
    62  	}
    63  	for key, value := range req.Form {
    64  		context.inMap[key] = value[0] //TODO 这里没有处理同一个 key 多个 value 的情况
    65  	}
    66  	originContentType := req.Header.Get("Content-Type")
    67  	if originContentType == "" {
    68  		return context
    69  	}
    70  	contentType, _, err := mime.ParseMediaType(originContentType)
    71  	if err != nil {
    72  		panic(fmt.Errorf("[NewContextFromHttp] %s %s", originContentType, err.Error()))
    73  	}
    74  	if contentType != "multipart/form-data" {
    75  		return context
    76  	}
    77  	err = req.ParseMultipartForm(defaultMaxMemory)
    78  	if err != nil {
    79  		panic(err)
    80  	}
    81  	for key, value := range req.MultipartForm.File {
    82  		context.requestFile[key] = value[0]
    83  	}
    84  	for key, value := range req.MultipartForm.Value {
    85  		context.inMap[key] = value[0]
    86  	}
    87  	return context
    88  }
    89  
    90  //返回一个新的测试上下文,这个上下文的所有参数都是空的
    91  func NewTestContext() *Context {
    92  	//调用 ctx 上的函数是不会更新这里的 buf 的
    93  	buf := []byte("test")
    94  	req, err := http.NewRequest("GET", "/testContext", bytes.NewReader(buf))
    95  	if err != nil {
    96  		panic(err)
    97  	}
    98  	return &Context{
    99  		requestUrl:   "/testContext",
   100  		inMap:        map[string]string{},
   101  		requestFile:  map[string]*multipart.FileHeader{},
   102  		responseCode: 200,
   103  		sessionMap:   map[string]string{},
   104  		method:       "GET",
   105  		req:          req,
   106  	}
   107  }
   108  
   109  //根据key返回输入参数,包括post和url的query的数据,如果没有,或者不是整数返回0 返回类型为int
   110  func (c *Context) InNum(key string) int {
   111  	value, ok := c.inMap[key]
   112  	if !ok {
   113  		return 0
   114  	}
   115  	num, err := strconv.Atoi(value)
   116  	if err != nil {
   117  		return 0
   118  	}
   119  	return num
   120  }
   121  
   122  //根据key返回输入参数,包括post和url的query的数据,如果没有返回"" 类型为string
   123  func (c *Context) InStr(key string) string {
   124  	return c.inMap[key]
   125  }
   126  
   127  func (c *Context) InStrDefault(key string, def string) string {
   128  	out := c.inMap[key]
   129  	if out == "" {
   130  		return def
   131  	}
   132  	return out
   133  }
   134  
   135  func (c *Context) MustPost() {
   136  	if !c.IsPost() {
   137  		panic(errors.New("Need post"))
   138  	}
   139  }
   140  
   141  func (c *Context) IsGet() bool {
   142  	return c.method == "GET"
   143  }
   144  func (c *Context) IsPost() bool {
   145  	return c.method == "POST"
   146  }
   147  
   148  func (c *Context) MustInNum(key string) int {
   149  	s := c.InNum(key)
   150  	if s == 0 {
   151  		panic(fmt.Errorf("Need %s parameter", key))
   152  	}
   153  	return s
   154  }
   155  
   156  func (c *Context) InHas(key string) bool {
   157  	return c.inMap[key] != ""
   158  }
   159  
   160  func (c *Context) MustInStr(key string) string {
   161  	s := c.InStr(key)
   162  	if s == "" {
   163  		panic(fmt.Errorf("Need %s parameter", key))
   164  	}
   165  	return s
   166  }
   167  
   168  func (c *Context) MustInJson(key string, obj interface{}) {
   169  	s := c.MustInStr(key)
   170  	err := json.Unmarshal([]byte(s), obj)
   171  	if err != nil {
   172  		panic(err)
   173  	}
   174  	return
   175  }
   176  
   177  func (c *Context) MustInFile(key string) *multipart.FileHeader {
   178  	file := c.requestFile[key]
   179  	if file == nil {
   180  		panic(fmt.Errorf("Need %s file", key))
   181  	}
   182  	return file
   183  }
   184  
   185  func (c *Context) MustFirstInFile() *multipart.FileHeader {
   186  	for _, file := range c.requestFile {
   187  		return file
   188  	}
   189  	panic(fmt.Errorf("Need a upload file"))
   190  }
   191  
   192  func (c *Context) SetInStr(key string, value string) *Context {
   193  	c.inMap[key] = value
   194  	return c
   195  }
   196  
   197  func (c *Context) DeleteInMap(key string) {
   198  	delete(c.inMap, key)
   199  }
   200  
   201  func (c *Context) SetInMap(data map[string]string) *Context {
   202  	c.inMap = data
   203  	return c
   204  }
   205  
   206  func (c *Context) GetInMap() map[string]string {
   207  	return c.inMap
   208  }
   209  
   210  func (c *Context) SetPost() *Context {
   211  	c.method = "POST"
   212  	return c
   213  }
   214  
   215  func (c *Context) GetDataStr(key string) string {
   216  	if c.DataMap == nil {
   217  		return ""
   218  	}
   219  	return c.DataMap[key]
   220  }
   221  func (c *Context) SetDataStr(key string, value string) {
   222  	if c.DataMap == nil {
   223  		c.DataMap = map[string]string{}
   224  	}
   225  	c.DataMap[key] = value
   226  }
   227  
   228  func (c *Context) sessionInit() {
   229  	if c.sessionMap != nil {
   230  		return
   231  	}
   232  	cookie, err := c.req.Cookie(SessionCookieName)
   233  	if err != nil {
   234  		//kmgErr.LogErrorWithStack(err)
   235  		// 这个地方没有cookie是正常情况
   236  		c.sessionMap = map[string]string{}
   237  		//没有Cooke
   238  		return
   239  	}
   240  	output, err := kmgCrypto.CompressAndEncryptBase64Decode(&SessionPsk, cookie.Value)
   241  	if err != nil {
   242  		kmgErr.LogErrorWithStack(err)
   243  		c.sessionMap = map[string]string{}
   244  		return
   245  	}
   246  	err = json.Unmarshal(output, &c.sessionMap)
   247  	if err != nil {
   248  		kmgErr.LogErrorWithStack(err)
   249  		c.sessionMap = map[string]string{}
   250  		return
   251  	}
   252  }
   253  
   254  //向Session里面设置一个字符串
   255  func (c *Context) SessionSetStr(key string, value string) *Context {
   256  	c.sessionInit()
   257  	c.sessionHasSet = true
   258  	c.sessionMap[key] = value
   259  	return c
   260  }
   261  
   262  //从Session里面获取一个字符串
   263  func (c *Context) SessionGetStr(key string) string {
   264  	c.sessionInit()
   265  	return c.sessionMap[key]
   266  }
   267  
   268  func (c *Context) SessionSetJson(key string, value interface{}) *Context {
   269  	json, err := json.Marshal(value)
   270  	if err != nil {
   271  		panic(err) //不能Marshal一定是代码的问题
   272  	}
   273  	c.SessionSetStr(key, string(json))
   274  	return c
   275  }
   276  
   277  func (c *Context) SessionGetJson(key string, obj interface{}) (err error) {
   278  	out := c.SessionGetStr(key)
   279  	if out == "" {
   280  		return errors.New("Session Empty")
   281  	}
   282  	err = json.Unmarshal([]byte(out), obj)
   283  	return err
   284  }
   285  
   286  //清除Session里面的内容.
   287  //更换Session的Id.
   288  func (c *Context) SessionClear() *Context {
   289  	c.sessionInit()
   290  	c.sessionHasSet = len(c.sessionMap) > 0
   291  	c.sessionMap = map[string]string{}
   292  	return c
   293  }
   294  
   295  //仅把Session传递过去的上下文,其他东西都恢复默认值
   296  func (c *Context) NewTestContextWithSession() *Context {
   297  	nc := NewTestContext()
   298  	nc.sessionMap = c.sessionMap
   299  	return nc
   300  }
   301  
   302  func (c *Context) SetResponseCode(code int) {
   303  	c.responseCode = code
   304  }
   305  
   306  func (c *Context) Redirect(url string) {
   307  	c.redirectUrl = url
   308  	c.responseCode = 302 // 302 是无缓存跳转,请不要修改为其他code
   309  }
   310  
   311  func (c *Context) NotFound(msg string) {
   312  	c.responseBuffer.WriteString(msg)
   313  	c.responseCode = 404
   314  }
   315  
   316  func (c *Context) Error(err error) {
   317  	c.responseBuffer.WriteString(err.Error())
   318  	c.responseCode = 500
   319  }
   320  
   321  func (c *Context) WriteString(s string) {
   322  	c.responseBuffer.WriteString(s)
   323  }
   324  
   325  func (c *Context) WriteByte(s []byte) {
   326  	c.responseBuffer.Write(s)
   327  }
   328  
   329  func (c *Context) WriteAttachmentFile(b []byte, fileName string) {
   330  	c.responseBuffer.Write(b)
   331  	c.SetResponseHeader("Content-Disposition", "attachment;filename="+fileName)
   332  }
   333  
   334  func (c *Context) WriteJson(obj interface{}) {
   335  	c.responseBuffer.Write(kmgJson.MustMarshal(obj))
   336  }
   337  
   338  func (c *Context) WriteToResponseWriter(w http.ResponseWriter, req *http.Request) {
   339  	for key, value := range c.responseHeader {
   340  		w.Header().Set(key, value)
   341  	}
   342  	if c.sessionMap != nil && c.sessionHasSet {
   343  		http.SetCookie(w, &http.Cookie{
   344  			Name:  SessionCookieName,
   345  			Value: kmgCrypto.CompressAndEncryptBase64Encode(&SessionPsk, kmgJson.MustMarshal(c.sessionMap)),
   346  		})
   347  	}
   348  	if c.redirectUrl != "" {
   349  		http.Redirect(w, req, c.redirectUrl, c.responseCode)
   350  		return
   351  	}
   352  	w.WriteHeader(c.responseCode)
   353  	if c.responseBuffer.Len() > 0 {
   354  		w.Write(c.responseBuffer.Bytes())
   355  	}
   356  }
   357  
   358  func (c *Context) SetResponseHeader(key string, value string) {
   359  	if c.responseHeader == nil {
   360  		c.responseHeader = map[string]string{}
   361  	}
   362  	c.responseHeader[key] = value
   363  }
   364  
   365  func (c *Context) GetResponseHeader(key string) string {
   366  	return c.responseHeader[key]
   367  }
   368  
   369  func (c *Context) GetResponseWriter() io.Writer {
   370  	return &c.responseBuffer
   371  }
   372  
   373  func (c *Context) GetRequest() *http.Request {
   374  	return c.req //调用者可以拿去干一些高级的事情
   375  }
   376  
   377  func (c *Context) GetRequestUrl() string {
   378  	return c.requestUrl
   379  }
   380  
   381  func (c *Context) SetRequestUrl(url string) *Context {
   382  	c.requestUrl = url
   383  	return c
   384  }
   385  
   386  func (c *Context) GetRedirectUrl() string {
   387  	return c.redirectUrl
   388  }
   389  
   390  func (c *Context) GetResponseCode() int {
   391  	return c.responseCode
   392  }
   393  
   394  func (c *Context) GetResponseByteList() []byte {
   395  	return c.responseBuffer.Bytes()
   396  }
   397  
   398  func (c *Context) GetResponseString() string {
   399  	return c.responseBuffer.String()
   400  }
   401  
   402  func (c *Context) MustGetClientIp() net.IP {
   403  	if c.req == nil {
   404  		return nil
   405  	}
   406  	if c.req.RemoteAddr == "" {
   407  		return nil
   408  	}
   409  	host, _, err := net.SplitHostPort(c.req.RemoteAddr)
   410  	if err != nil {
   411  		panic(err)
   412  	}
   413  	return net.ParseIP(host)
   414  }
   415  
   416  func (c *Context) GetClientIpStringIgnoreError() string {
   417  	if c.req == nil {
   418  		return ""
   419  	}
   420  	if c.req.RemoteAddr == "" {
   421  		return ""
   422  	}
   423  	host, _, err := net.SplitHostPort(c.req.RemoteAddr)
   424  	if err != nil {
   425  		return ""
   426  	}
   427  	ip := net.ParseIP(host)
   428  	if ip == nil {
   429  		return ""
   430  	}
   431  	return ip.String()
   432  }
   433  
   434  type ContextLog struct {
   435  	Method          string            `json:",omitempty"`
   436  	ResponseCode    int               `json:",omitempty"`
   437  	Url             string            `json:",omitempty"`
   438  	RemoteAddr      string            `json:",omitempty"`
   439  	UA              string            `json:",omitempty"`
   440  	Refer           string            `json:",omitempty"`
   441  	RedirectUrl     string            `json:",omitempty"`
   442  	InMap           map[string]string `json:",omitempty"`
   443  	ProcessTime     string            `json:",omitempty"`
   444  	RequestSize     int               `json:",omitempty"`
   445  	ResponseSize    int               `json:",omitempty"`
   446  	ResponseContent string            `json:",omitempty"`
   447  }
   448  
   449  func (c *Context) Log() *ContextLog {
   450  	out := &ContextLog{
   451  		Method:       c.method,
   452  		ResponseCode: c.responseCode,
   453  		Url:          c.requestUrl,
   454  		RedirectUrl:  c.redirectUrl,
   455  		InMap:        c.inMap,
   456  		ResponseSize: c.responseBuffer.Len(),
   457  	}
   458  	if c.req != nil {
   459  		out.RemoteAddr = c.req.RemoteAddr
   460  		out.UA = c.req.UserAgent()
   461  		out.Refer = c.req.Referer()
   462  		out.RequestSize = int(c.req.ContentLength)
   463  	}
   464  	//小于64个字节,并且都是utf8,就输出到log里面
   465  	if out.ResponseSize <= 64 {
   466  		out.ResponseContent = c.responseBuffer.String()
   467  		if !utf8.ValidString(out.ResponseContent) {
   468  			out.ResponseContent = ""
   469  		}
   470  	}
   471  	return out
   472  }
   473  
   474  /*
   475  //这个返回类型可能有问题
   476  func (c *Context)InArray(key string)[]string{
   477      return nil
   478  }
   479  */
   480  
   481  func init() {
   482  	kmgCrypto.RegisterPskChangeCallback(pskChange)
   483  }
   484  func pskChange() {
   485  	psk1 := kmgCrypto.GetPskFromDefaultPsk(6, "kmgHttp.SessionCookieName")
   486  	SessionCookieName = "kmgSession" + kmgBase64.Base64EncodeByteToString(psk1)
   487  
   488  	psk2 := kmgCrypto.GetPskFromDefaultPsk(32, "kmgHttp.SessionPsk")
   489  	copy(SessionPsk[:], psk2)
   490  }