github.com/keysonzzz/kmg@v0.0.0-20151121023212-05317bfd7d39/kmgNet/kmgHttp/context.go (about) 1 package kmgHttp 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "mime" 10 "mime/multipart" 11 "net/http" 12 "strconv" 13 "unicode/utf8" 14 15 "github.com/bronze1man/kmg/encoding/kmgBase64" 16 "github.com/bronze1man/kmg/encoding/kmgJson" 17 "github.com/bronze1man/kmg/kmgCrypto" 18 "github.com/bronze1man/kmg/kmgErr" 19 "net" 20 ) 21 22 //该对象上的方法不应该被并发调用. 23 type Context struct { 24 method string 25 requestUrl string 26 inMap map[string]string 27 DataMap map[string]string //上下文里面可以携带一些信息 28 requestFile map[string]*multipart.FileHeader 29 responseBuffer bytes.Buffer 30 redirectUrl string 31 responseCode int 32 req *http.Request 33 responseHeader map[string]string 34 sessionMap map[string]string 35 sessionHasSet bool 36 } 37 38 const ( 39 defaultMaxMemory = 32 << 20 // 32 MB 40 ) 41 42 var SessionCookieName = "kmgSession" 43 var SessionPsk = [32]byte{0xd8, 0x51, 0xea, 0x81, 0xb9, 0xe, 0xf, 0x2f, 0x8c, 0x85, 0x5f, 0xb6, 0x14, 0xb2} 44 45 func NewContextFromHttp(w http.ResponseWriter, req *http.Request) *Context { 46 context := &Context{ 47 method: req.Method, 48 inMap: map[string]string{}, 49 requestFile: map[string]*multipart.FileHeader{}, 50 requestUrl: req.URL.String(), 51 //Session: kmgSession.GetSession(w, req), 52 responseCode: 200, 53 req: req, 54 } 55 //绕开支付宝请求bug 56 if req.Header.Get("Content-Type") == "application/x-www-form-urlencoded; text/html; charset=utf-8" { 57 req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") 58 } 59 err := req.ParseForm() 60 if err != nil { 61 panic(err) 62 } 63 for key, value := range req.Form { 64 context.inMap[key] = value[0] //TODO 这里没有处理同一个 key 多个 value 的情况 65 } 66 originContentType := req.Header.Get("Content-Type") 67 if originContentType == "" { 68 return context 69 } 70 contentType, _, err := mime.ParseMediaType(originContentType) 71 if err != nil { 72 panic(fmt.Errorf("[NewContextFromHttp] %s %s", originContentType, err.Error())) 73 } 74 if contentType != "multipart/form-data" { 75 return context 76 } 77 err = req.ParseMultipartForm(defaultMaxMemory) 78 if err != nil { 79 panic(err) 80 } 81 for key, value := range req.MultipartForm.File { 82 context.requestFile[key] = value[0] 83 } 84 for key, value := range req.MultipartForm.Value { 85 context.inMap[key] = value[0] 86 } 87 return context 88 } 89 90 //返回一个新的测试上下文,这个上下文的所有参数都是空的 91 func NewTestContext() *Context { 92 //调用 ctx 上的函数是不会更新这里的 buf 的 93 buf := []byte("test") 94 req, err := http.NewRequest("GET", "/testContext", bytes.NewReader(buf)) 95 if err != nil { 96 panic(err) 97 } 98 return &Context{ 99 requestUrl: "/testContext", 100 inMap: map[string]string{}, 101 requestFile: map[string]*multipart.FileHeader{}, 102 responseCode: 200, 103 sessionMap: map[string]string{}, 104 method: "GET", 105 req: req, 106 } 107 } 108 109 //根据key返回输入参数,包括post和url的query的数据,如果没有,或者不是整数返回0 返回类型为int 110 func (c *Context) InNum(key string) int { 111 value, ok := c.inMap[key] 112 if !ok { 113 return 0 114 } 115 num, err := strconv.Atoi(value) 116 if err != nil { 117 return 0 118 } 119 return num 120 } 121 122 //根据key返回输入参数,包括post和url的query的数据,如果没有返回"" 类型为string 123 func (c *Context) InStr(key string) string { 124 return c.inMap[key] 125 } 126 127 func (c *Context) InStrDefault(key string, def string) string { 128 out := c.inMap[key] 129 if out == "" { 130 return def 131 } 132 return out 133 } 134 135 func (c *Context) MustPost() { 136 if !c.IsPost() { 137 panic(errors.New("Need post")) 138 } 139 } 140 141 func (c *Context) IsGet() bool { 142 return c.method == "GET" 143 } 144 func (c *Context) IsPost() bool { 145 return c.method == "POST" 146 } 147 148 func (c *Context) MustInNum(key string) int { 149 s := c.InNum(key) 150 if s == 0 { 151 panic(fmt.Errorf("Need %s parameter", key)) 152 } 153 return s 154 } 155 156 func (c *Context) InHas(key string) bool { 157 return c.inMap[key] != "" 158 } 159 160 func (c *Context) MustInStr(key string) string { 161 s := c.InStr(key) 162 if s == "" { 163 panic(fmt.Errorf("Need %s parameter", key)) 164 } 165 return s 166 } 167 168 func (c *Context) MustInJson(key string, obj interface{}) { 169 s := c.MustInStr(key) 170 err := json.Unmarshal([]byte(s), obj) 171 if err != nil { 172 panic(err) 173 } 174 return 175 } 176 177 func (c *Context) MustInFile(key string) *multipart.FileHeader { 178 file := c.requestFile[key] 179 if file == nil { 180 panic(fmt.Errorf("Need %s file", key)) 181 } 182 return file 183 } 184 185 func (c *Context) MustFirstInFile() *multipart.FileHeader { 186 for _, file := range c.requestFile { 187 return file 188 } 189 panic(fmt.Errorf("Need a upload file")) 190 } 191 192 func (c *Context) SetInStr(key string, value string) *Context { 193 c.inMap[key] = value 194 return c 195 } 196 197 func (c *Context) DeleteInMap(key string) { 198 delete(c.inMap, key) 199 } 200 201 func (c *Context) SetInMap(data map[string]string) *Context { 202 c.inMap = data 203 return c 204 } 205 206 func (c *Context) GetInMap() map[string]string { 207 return c.inMap 208 } 209 210 func (c *Context) SetPost() *Context { 211 c.method = "POST" 212 return c 213 } 214 215 func (c *Context) GetDataStr(key string) string { 216 if c.DataMap == nil { 217 return "" 218 } 219 return c.DataMap[key] 220 } 221 func (c *Context) SetDataStr(key string, value string) { 222 if c.DataMap == nil { 223 c.DataMap = map[string]string{} 224 } 225 c.DataMap[key] = value 226 } 227 228 func (c *Context) sessionInit() { 229 if c.sessionMap != nil { 230 return 231 } 232 cookie, err := c.req.Cookie(SessionCookieName) 233 if err != nil { 234 //kmgErr.LogErrorWithStack(err) 235 // 这个地方没有cookie是正常情况 236 c.sessionMap = map[string]string{} 237 //没有Cooke 238 return 239 } 240 output, err := kmgCrypto.CompressAndEncryptBase64Decode(&SessionPsk, cookie.Value) 241 if err != nil { 242 kmgErr.LogErrorWithStack(err) 243 c.sessionMap = map[string]string{} 244 return 245 } 246 err = json.Unmarshal(output, &c.sessionMap) 247 if err != nil { 248 kmgErr.LogErrorWithStack(err) 249 c.sessionMap = map[string]string{} 250 return 251 } 252 } 253 254 //向Session里面设置一个字符串 255 func (c *Context) SessionSetStr(key string, value string) *Context { 256 c.sessionInit() 257 c.sessionHasSet = true 258 c.sessionMap[key] = value 259 return c 260 } 261 262 //从Session里面获取一个字符串 263 func (c *Context) SessionGetStr(key string) string { 264 c.sessionInit() 265 return c.sessionMap[key] 266 } 267 268 func (c *Context) SessionSetJson(key string, value interface{}) *Context { 269 json, err := json.Marshal(value) 270 if err != nil { 271 panic(err) //不能Marshal一定是代码的问题 272 } 273 c.SessionSetStr(key, string(json)) 274 return c 275 } 276 277 func (c *Context) SessionGetJson(key string, obj interface{}) (err error) { 278 out := c.SessionGetStr(key) 279 if out == "" { 280 return errors.New("Session Empty") 281 } 282 err = json.Unmarshal([]byte(out), obj) 283 return err 284 } 285 286 //清除Session里面的内容. 287 //更换Session的Id. 288 func (c *Context) SessionClear() *Context { 289 c.sessionInit() 290 c.sessionHasSet = len(c.sessionMap) > 0 291 c.sessionMap = map[string]string{} 292 return c 293 } 294 295 //仅把Session传递过去的上下文,其他东西都恢复默认值 296 func (c *Context) NewTestContextWithSession() *Context { 297 nc := NewTestContext() 298 nc.sessionMap = c.sessionMap 299 return nc 300 } 301 302 func (c *Context) SetResponseCode(code int) { 303 c.responseCode = code 304 } 305 306 func (c *Context) Redirect(url string) { 307 c.redirectUrl = url 308 c.responseCode = 302 // 302 是无缓存跳转,请不要修改为其他code 309 } 310 311 func (c *Context) NotFound(msg string) { 312 c.responseBuffer.WriteString(msg) 313 c.responseCode = 404 314 } 315 316 func (c *Context) Error(err error) { 317 c.responseBuffer.WriteString(err.Error()) 318 c.responseCode = 500 319 } 320 321 func (c *Context) WriteString(s string) { 322 c.responseBuffer.WriteString(s) 323 } 324 325 func (c *Context) WriteByte(s []byte) { 326 c.responseBuffer.Write(s) 327 } 328 329 func (c *Context) WriteAttachmentFile(b []byte, fileName string) { 330 c.responseBuffer.Write(b) 331 c.SetResponseHeader("Content-Disposition", "attachment;filename="+fileName) 332 } 333 334 func (c *Context) WriteJson(obj interface{}) { 335 c.responseBuffer.Write(kmgJson.MustMarshal(obj)) 336 } 337 338 func (c *Context) WriteToResponseWriter(w http.ResponseWriter, req *http.Request) { 339 for key, value := range c.responseHeader { 340 w.Header().Set(key, value) 341 } 342 if c.sessionMap != nil && c.sessionHasSet { 343 http.SetCookie(w, &http.Cookie{ 344 Name: SessionCookieName, 345 Value: kmgCrypto.CompressAndEncryptBase64Encode(&SessionPsk, kmgJson.MustMarshal(c.sessionMap)), 346 }) 347 } 348 if c.redirectUrl != "" { 349 http.Redirect(w, req, c.redirectUrl, c.responseCode) 350 return 351 } 352 w.WriteHeader(c.responseCode) 353 if c.responseBuffer.Len() > 0 { 354 w.Write(c.responseBuffer.Bytes()) 355 } 356 } 357 358 func (c *Context) SetResponseHeader(key string, value string) { 359 if c.responseHeader == nil { 360 c.responseHeader = map[string]string{} 361 } 362 c.responseHeader[key] = value 363 } 364 365 func (c *Context) GetResponseHeader(key string) string { 366 return c.responseHeader[key] 367 } 368 369 func (c *Context) GetResponseWriter() io.Writer { 370 return &c.responseBuffer 371 } 372 373 func (c *Context) GetRequest() *http.Request { 374 return c.req //调用者可以拿去干一些高级的事情 375 } 376 377 func (c *Context) GetRequestUrl() string { 378 return c.requestUrl 379 } 380 381 func (c *Context) SetRequestUrl(url string) *Context { 382 c.requestUrl = url 383 return c 384 } 385 386 func (c *Context) GetRedirectUrl() string { 387 return c.redirectUrl 388 } 389 390 func (c *Context) GetResponseCode() int { 391 return c.responseCode 392 } 393 394 func (c *Context) GetResponseByteList() []byte { 395 return c.responseBuffer.Bytes() 396 } 397 398 func (c *Context) GetResponseString() string { 399 return c.responseBuffer.String() 400 } 401 402 func (c *Context) MustGetClientIp() net.IP { 403 if c.req == nil { 404 return nil 405 } 406 if c.req.RemoteAddr == "" { 407 return nil 408 } 409 host, _, err := net.SplitHostPort(c.req.RemoteAddr) 410 if err != nil { 411 panic(err) 412 } 413 return net.ParseIP(host) 414 } 415 416 func (c *Context) GetClientIpStringIgnoreError() string { 417 if c.req == nil { 418 return "" 419 } 420 if c.req.RemoteAddr == "" { 421 return "" 422 } 423 host, _, err := net.SplitHostPort(c.req.RemoteAddr) 424 if err != nil { 425 return "" 426 } 427 ip := net.ParseIP(host) 428 if ip == nil { 429 return "" 430 } 431 return ip.String() 432 } 433 434 type ContextLog struct { 435 Method string `json:",omitempty"` 436 ResponseCode int `json:",omitempty"` 437 Url string `json:",omitempty"` 438 RemoteAddr string `json:",omitempty"` 439 UA string `json:",omitempty"` 440 Refer string `json:",omitempty"` 441 RedirectUrl string `json:",omitempty"` 442 InMap map[string]string `json:",omitempty"` 443 ProcessTime string `json:",omitempty"` 444 RequestSize int `json:",omitempty"` 445 ResponseSize int `json:",omitempty"` 446 ResponseContent string `json:",omitempty"` 447 } 448 449 func (c *Context) Log() *ContextLog { 450 out := &ContextLog{ 451 Method: c.method, 452 ResponseCode: c.responseCode, 453 Url: c.requestUrl, 454 RedirectUrl: c.redirectUrl, 455 InMap: c.inMap, 456 ResponseSize: c.responseBuffer.Len(), 457 } 458 if c.req != nil { 459 out.RemoteAddr = c.req.RemoteAddr 460 out.UA = c.req.UserAgent() 461 out.Refer = c.req.Referer() 462 out.RequestSize = int(c.req.ContentLength) 463 } 464 //小于64个字节,并且都是utf8,就输出到log里面 465 if out.ResponseSize <= 64 { 466 out.ResponseContent = c.responseBuffer.String() 467 if !utf8.ValidString(out.ResponseContent) { 468 out.ResponseContent = "" 469 } 470 } 471 return out 472 } 473 474 /* 475 //这个返回类型可能有问题 476 func (c *Context)InArray(key string)[]string{ 477 return nil 478 } 479 */ 480 481 func init() { 482 kmgCrypto.RegisterPskChangeCallback(pskChange) 483 } 484 func pskChange() { 485 psk1 := kmgCrypto.GetPskFromDefaultPsk(6, "kmgHttp.SessionCookieName") 486 SessionCookieName = "kmgSession" + kmgBase64.Base64EncodeByteToString(psk1) 487 488 psk2 := kmgCrypto.GetPskFromDefaultPsk(32, "kmgHttp.SessionPsk") 489 copy(SessionPsk[:], psk2) 490 }