github.com/cnotch/ipchub@v1.1.0/service/rtsp/session.go (about) 1 // Copyright (c) 2019,CAOHONGJU All rights reserved. 2 // Use of this source code is governed by a MIT-style 3 // license that can be found in the LICENSE file. 4 5 package rtsp 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "io" 12 "net" 13 "net/url" 14 "runtime/debug" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/cnotch/ipchub/config" 20 "github.com/cnotch/ipchub/media" 21 "github.com/cnotch/ipchub/network/socket/buffered" 22 "github.com/cnotch/ipchub/network/websocket" 23 "github.com/cnotch/ipchub/provider/auth" 24 "github.com/cnotch/ipchub/provider/security" 25 "github.com/cnotch/ipchub/stats" 26 "github.com/cnotch/ipchub/utils" 27 "github.com/cnotch/xlog" 28 "github.com/pixelbender/go-sdp/sdp" 29 ) 30 31 const ( 32 realm = config.Name 33 ) 34 35 const ( 36 statusInit = iota 37 statusReady 38 statusPlaying 39 statusRecording 40 ) 41 42 var buffers = sync.Pool{ 43 New: func() interface{} { 44 return bytes.NewBuffer(make([]byte, 0, 1024*2)) 45 }, 46 } 47 48 // Session RTSP 会话 49 type Session struct { 50 // 创建时设置 51 svr *Server 52 logger *xlog.Logger 53 closed bool 54 lsession string // 本地会话标识 55 timeout time.Duration 56 conn *buffered.Conn 57 lockW sync.Mutex 58 59 wsconn websocket.Conn 60 61 authMode auth.Mode 62 nonce string 63 user *auth.User 64 65 // DESCRIBE,或 ANNOUNCE 后设置 66 url *url.URL 67 path string 68 rawSdp string 69 sdp *sdp.Session 70 aControl string 71 vControl string 72 aCodec string 73 vCodec string 74 mode SessionMode 75 76 // Setup 后设置 77 transport RTPTransport 78 79 // 启动流媒体传输后设置 80 status int // session状态 81 stream mediaStream // 媒体流 82 consumer media.Consumer // 消费者 83 } 84 85 func newSession(svr *Server, conn net.Conn) *Session { 86 87 session := &Session{ 88 svr: svr, 89 lsession: security.NewID().Base64(), 90 timeout: config.NetTimeout(), 91 conn: buffered.NewConn(conn, 92 buffered.FlushRate(config.NetFlushRate()), 93 buffered.BufferSize(config.NetBufferSize())), 94 mode: UnknownSession, 95 transport: RTPTransport{ 96 Mode: PlaySession, // 默认为播放 97 Type: RTPUnknownTrans, 98 }, 99 authMode: config.RtspAuthMode(), 100 nonce: security.NewID().MD5(), 101 status: statusInit, 102 stream: defaultStream, 103 consumer: defaultConsumer, 104 } 105 106 if wsc, ok := conn.(websocket.Conn); ok { // 如果是WebSocket,有http进行验证 107 session.authMode = auth.NoneAuth 108 session.wsconn = wsc 109 session.path = wsc.Path() 110 session.user = auth.Get(wsc.Username()) 111 } 112 113 // ipaddr, _ := address.Parse(conn.RemoteAddr().String(), 80) 114 // // 如果是本机IP,不验证;以便ffmpeg本机rtsp->rtmp 115 // if network.IsLocalhostIP(ipaddr.IP) { 116 // session.authMode = auth.NoneAuth 117 // } 118 119 for i := rtpChannelMin; i < rtpChannelCount; i++ { 120 session.transport.Channels[i] = -1 121 session.transport.ClientPorts[i] = -1 122 } 123 session.logger = svr.logger.With(xlog.Fields( 124 xlog.F("session", session.lsession))) 125 126 return session 127 } 128 129 // Addr Session地址 130 func (s *Session) Addr() string { 131 return s.conn.RemoteAddr().String() 132 } 133 134 // Consume 消费媒体包 135 func (s *Session) Consume(p Pack) { 136 s.consumer.Consume(p) 137 } 138 139 // Close 关闭会话 140 func (s *Session) Close() error { 141 if s.closed { 142 return nil 143 } 144 145 s.closed = true 146 s.conn.Close() 147 return nil 148 } 149 150 func (s *Session) process() { 151 defer func() { 152 if r := recover(); r != nil { 153 s.logger.Errorf("session panic; %v \n %s", r, debug.Stack()) 154 } 155 156 stats.RtspConns.Release() 157 s.Close() 158 s.consumer.Close() 159 s.stream.Close() 160 161 // 重置到初始状态 162 s.conn = nil 163 s.status = statusInit 164 s.stream = defaultStream 165 s.consumer = defaultConsumer 166 s.logger.Infof("close rtsp session") 167 }() 168 169 s.logger.Infof("open rtsp session") 170 stats.RtspConns.Add() // 增加一个 RTSP 连接计数 171 reader := s.conn.Reader() 172 173 for !s.closed { 174 deadLine := time.Time{} 175 if s.timeout > 0 { 176 deadLine = time.Now().Add(s.timeout) 177 } 178 if err := s.conn.SetReadDeadline(deadLine); err != nil { 179 s.logger.Error(err.Error()) 180 break 181 } 182 183 err := receive(s.logger, reader, s.transport.Channels[:], s) 184 if err != nil { 185 if err == io.EOF { // 如果客户端断开提醒 186 s.logger.Warn("The client actively disconnects") 187 } else if !s.closed { // 如果主动关闭,不提示 188 s.logger.Error(err.Error()) 189 } 190 break 191 } 192 } 193 } 194 195 // receiveHandler.onPack 196 func (s *Session) onPack(pack *RTPPack) (err error) { 197 return s.stream.WritePacket(pack) 198 } 199 200 // receiveHandler.onResponse 201 func (s *Session) onResponse(resp *Response) (err error) { 202 // 忽略,服务器不会主动发起请求 203 return 204 } 205 206 // receiveHandler.onRequest 207 func (s *Session) onRequest(req *Request) (err error) { 208 resp := s.newResponse(StatusOK, req) 209 // 预处理 210 continueProcess, err := s.onPreprocess(resp, req) 211 if !continueProcess { 212 return err 213 } 214 215 switch req.Method { 216 case MethodDescribe: 217 s.onDescribe(resp, req) 218 case MethodAnnounce: 219 s.onAnnounce(resp, req) 220 case MethodSetup: 221 s.onSetup(resp, req) 222 case MethodRecord: 223 s.onRecord(resp, req) 224 case MethodPlay: 225 return s.onPlay(resp, req) // play 发送流媒体不在当前 routine,需要先回复 226 default: 227 // 状态不支持的方法 228 resp.StatusCode = StatusMethodNotValidInThisState 229 } 230 231 // 发送响应 232 err = s.response(resp) 233 return err 234 } 235 236 func (s *Session) onDescribe(resp *Response, req *Request) { 237 238 // TODO: 检查 accept 中的类型是否包含 sdp 239 s.url = req.URL 240 if s.wsconn == nil { // websocket访问的路径有ws://路径表示 241 s.path = utils.CanonicalPath(req.URL.Path) 242 } 243 244 stream := media.GetOrCreate(s.path) 245 if stream == nil { 246 resp.StatusCode = StatusNotFound 247 return 248 } 249 250 if !s.checkPermission(auth.PullRight) { 251 resp.StatusCode = StatusForbidden 252 return 253 } 254 255 // 从流中取 sdp 256 sdpRaw := stream.Sdp() 257 if len(sdpRaw) == 0 { 258 resp.StatusCode = StatusNotFound 259 return 260 } 261 err := s.parseSdp(sdpRaw) 262 if err != nil { // TODO:需要更好的处理方式 263 resp.StatusCode = StatusNotFound 264 return 265 } 266 267 resp.Header.Set(FieldContentType, "application/sdp") 268 resp.Body = s.rawSdp 269 s.mode = PlaySession // 标记为播放会话 270 } 271 272 func (s *Session) onAnnounce(resp *Response, req *Request) { 273 274 // 检查 Content-Type: application/sdp 275 if req.Header.Get(FieldContentType) != "application/sdp" { 276 resp.StatusCode = StatusBadRequest // TODO:更合适的代码 277 return 278 } 279 280 s.url = req.URL 281 s.path = utils.CanonicalPath(req.URL.Path) 282 283 if !s.checkPermission(auth.PushRight) { 284 resp.StatusCode = StatusForbidden 285 return 286 } 287 288 // 从流中取 sdp 289 err := s.parseSdp(req.Body) 290 if err != nil { 291 resp.StatusCode = StatusBadRequest 292 return 293 } 294 295 s.mode = RecordSession // 标记为录像会话 296 } 297 298 func (s *Session) onSetup(resp *Response, req *Request) { 299 // a=control:streamid=1 300 // a=control:rtsp://192.168.1.165/trackID=1 301 // a=control:?ctype=video 302 setupURL := &url.URL{} 303 *setupURL = *req.URL 304 if setupURL.Port() == "" { 305 setupURL.Host = fmt.Sprintf("%s:554", setupURL.Host) 306 } 307 setupPath := setupURL.String() 308 309 //setupPath = setupPath[strings.LastIndex(setupPath, "/")+1:] 310 vPath, err := getControlPath(s.vControl) 311 if err != nil { 312 resp.StatusCode = StatusInternalServerError 313 resp.Status = "Invalid VControl" 314 return 315 } 316 317 aPath, err := getControlPath(s.aControl) 318 if err != nil { 319 resp.StatusCode = StatusInternalServerError 320 resp.Status = "Invalid AControl" 321 return 322 } 323 324 ts := req.Header.Get(FieldTransport) 325 resp.Header.Set(FieldTransport, ts) // 先回写transport 326 327 // 检查控制路径 328 chindex := -1 329 if setupPath == aPath || (aPath != "" && strings.LastIndex(setupPath, aPath) == len(setupPath)-len(aPath)) { 330 chindex = int(ChannelAudio) 331 } else if setupPath == vPath || (vPath != "" && strings.LastIndex(setupPath, vPath) == len(setupPath)-len(vPath)) { 332 chindex = int(ChannelVideo) 333 } else { // 找不到被 Setup 的资源 334 resp.StatusCode = StatusInternalServerError 335 resp.Status = fmt.Sprintf("SETUP Unkown control:%s", setupPath) 336 return 337 } 338 339 err = s.transport.ParseTransport(chindex, ts) 340 if err != nil { 341 resp.StatusCode = StatusInvalidParameter 342 resp.Status = err.Error() 343 return 344 } 345 346 // 检查和以前的命令是否一致 347 if s.mode == UnknownSession { 348 s.mode = s.transport.Mode 349 } 350 351 if s.mode != s.transport.Mode { 352 resp.StatusCode = StatusInvalidParameter 353 if s.mode == PlaySession { 354 resp.Status = "Current state can't setup as record" 355 } else { 356 resp.Status = "Current state can't setup as play" 357 } 358 return 359 } 360 361 // record 只支持 TCP 单播 362 if s.mode == RecordSession { 363 // 检查用户权限 364 if !s.checkPermission(auth.PushRight) { 365 resp.StatusCode = StatusForbidden 366 return 367 } 368 369 if s.transport.Type != RTPTCPUnicast { 370 resp.StatusCode = StatusUnsupportedTransport 371 resp.Status = "when mode = record,only support tcp unicast" 372 } else { 373 if s.status < statusReady { // 初始状态切换到Ready 374 s.status = statusReady 375 } 376 } 377 return 378 } 379 380 // 检查用户权限,播放 381 if !s.checkPermission(auth.PullRight) { 382 resp.StatusCode = StatusForbidden 383 return 384 } 385 386 if s.transport.Type == RTPMulticast { // 需要修改回复的transport 387 st := media.GetOrCreate(s.path) 388 if st == nil { // 没有找到源 389 resp.StatusCode = StatusNotFound 390 return 391 } 392 ma := st.Multicastable() 393 if ma == nil { // 不支持组播 394 resp.StatusCode = StatusUnsupportedTransport 395 return 396 } 397 398 ts = fmt.Sprintf("%s;destination=%s;port=%d-%d;source=%s;ttl=%d", 399 ts, ma.MulticastIP(), 400 ma.Port(chindex), ma.Port(chindex+1), 401 ma.SourceIP(), ma.TTL()) 402 resp.Header.Set(FieldTransport, ts) 403 } 404 405 if s.status < statusReady { // 初始状态切换到Ready 406 s.status = statusReady 407 } 408 } 409 410 func (s *Session) onRecord(resp *Response, req *Request) { 411 if s.status == statusRecording { 412 return 413 } 414 415 // 传输模式、会话模式判断 416 if s.mode != RecordSession || s.transport.Type != RTPTCPUnicast { 417 resp.StatusCode = StatusMethodNotValidInThisState 418 return 419 } 420 421 if !s.checkPermission(auth.PushRight) { 422 resp.StatusCode = StatusForbidden 423 return 424 } 425 426 s.asTCPPusher() 427 s.status = statusRecording 428 } 429 430 func (s *Session) onPlay(resp *Response, req *Request) (err error) { 431 if s.status == statusPlaying { 432 return 433 } 434 435 // 传输模式、会话模式判断 436 if s.mode != PlaySession || s.transport.Type == RTPUnknownTrans { 437 resp.StatusCode = StatusMethodNotValidInThisState 438 return s.response(resp) 439 } 440 441 stream := media.GetOrCreate(s.path) 442 if stream == nil { 443 resp.StatusCode = StatusNotFound 444 return s.response(resp) 445 } 446 447 if !s.checkPermission(auth.PullRight) { 448 resp.StatusCode = StatusForbidden 449 return s.response(resp) 450 } 451 452 resp.Header.Set(FieldRange, req.Header.Get(FieldRange)) 453 switch s.transport.Type { 454 case RTPTCPUnicast: 455 err = s.asTCPConsumer(stream, resp) 456 case RTPUDPUnicast: 457 err = s.asUDPConsumer(stream, resp) 458 default: 459 err = s.asMulticastConsumer(stream, resp) 460 } 461 462 if err == nil { 463 s.status = statusPlaying 464 } 465 return 466 } 467 468 func (s *Session) checkPermission(right auth.AccessRight) bool { 469 if s.authMode == auth.NoneAuth { 470 return true 471 } 472 473 if s.user == nil { 474 return false 475 } 476 477 return s.user.ValidatePermission(s.path, right) 478 } 479 480 func (s *Session) checkAuth(r *Request) (user *auth.User, err error) { 481 switch s.authMode { 482 case auth.BasicAuth: 483 username, password, has := r.BasicAuth() 484 if !has { 485 return nil, errors.New("require legal Authorization field") 486 } 487 user := auth.Get(username) 488 if user == nil { 489 return nil, errors.New("user not exist") 490 } 491 err = user.ValidatePassword(password) 492 if err != nil { 493 return nil, err 494 } 495 return user, nil 496 497 case auth.DigestAuth: 498 username, response, has := r.DigestAuth() 499 if !has { 500 return nil, errors.New("require legal Authorization field") 501 } 502 user := auth.Get(username) 503 if user == nil { 504 return nil, errors.New("user not exist") 505 } 506 resp2 := formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.Password) 507 if resp2 == response { 508 return user, nil 509 } 510 resp2 = formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.PasswordMD5()) 511 if resp2 == response { 512 return user, nil 513 } 514 s.nonce = security.NewID().MD5() 515 return nil, errors.New("require legal Authorization field") 516 default: // 无需验证 517 return nil, nil 518 } 519 } 520 521 func (s *Session) onPreprocess(resp *Response, req *Request) (continueProcess bool, err error) { 522 // Options 方法无需验证,直接回复 523 if req.Method == MethodOptions { 524 resp.Header.Set(FieldPublic, "DESCRIBE, SETUP, TEARDOWN, PLAY, OPTIONS, ANNOUNCE, RECORD") 525 err = s.response(resp) 526 return false, err 527 } 528 529 // 关闭请求 530 if req.Method == MethodTeardown { 531 // 发送响应 532 err = s.response(resp) 533 s.Close() 534 return false, err 535 } 536 537 // 检查状态下的方法 538 switch s.status { 539 case statusReady: 540 continueProcess = req.Method == MethodSetup || 541 req.Method == MethodPlay || req.Method == MethodRecord 542 case statusPlaying: 543 continueProcess = req.Method == MethodPlay 544 case statusRecording: 545 continueProcess = req.Method == MethodRecord 546 default: 547 continueProcess = !(req.Method == MethodPlay || req.Method == MethodRecord) 548 } 549 if !continueProcess { 550 resp.StatusCode = StatusMethodNotValidInThisState 551 err = s.response(resp) 552 return false, err 553 } 554 555 // 检查认证 556 user, err2 := s.checkAuth(req) 557 if err2 != nil { 558 resp.StatusCode = StatusUnauthorized 559 if err2 != nil { 560 resp.Status = err2.Error() 561 } 562 err = s.response(resp) 563 return false, err 564 } 565 566 s.user = user 567 return true, nil 568 } 569 570 func (s *Session) response(resp *Response) error { 571 s.lockW.Lock() 572 573 var err error 574 575 if s.wsconn != nil { // websocket 客户端 576 buf := buffers.Get().(*bytes.Buffer) 577 buf.Reset() 578 defer buffers.Put(buf) 579 580 err = resp.Write(buf) // 保证写入包的完整性,简化前端分包 581 _, err = s.wsconn.Write(buf.Bytes()) 582 } else { 583 err = resp.Write(s.conn) 584 if err == nil { 585 _, err = s.conn.Flush() 586 } 587 } 588 589 s.lockW.Unlock() 590 591 if err != nil { 592 s.logger.Errorf("send response error = %v", err) 593 return err 594 } 595 596 if s.logger.LevelEnabled(xlog.DebugLevel) { 597 s.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String())) 598 } 599 600 return nil 601 } 602 603 func (s *Session) newResponse(code int, req *Request) *Response { 604 resp := &Response{ 605 StatusCode: code, 606 Header: make(Header), 607 Request: req, 608 } 609 610 resp.Header.Set(FieldCSeq, req.Header.Get(FieldCSeq)) 611 resp.Header.Set(FieldSession, s.lsession) 612 613 // 根据认证模式增加认证所需的字段 614 switch s.authMode { 615 case auth.BasicAuth: 616 resp.SetBasicAuth(realm) 617 case auth.DigestAuth: 618 resp.SetDigestAuth(realm, s.nonce) 619 } 620 return resp 621 } 622 623 func (s *Session) parseSdp(rawSdp string) (err error) { 624 // 从流中取 sdp 625 s.rawSdp = rawSdp 626 // 解析 627 s.sdp, err = sdp.ParseString(s.rawSdp) 628 if err != nil { 629 return 630 } 631 632 for _, media := range s.sdp.Media { 633 switch media.Type { 634 case "video": 635 s.vControl = media.Attributes.Get("control") 636 s.vCodec = media.Format[0].Name 637 case "audio": 638 s.aControl = media.Attributes.Get("control") 639 s.aCodec = media.Format[0].Name 640 } 641 } 642 return 643 } 644 645 func getControlPath(ctrl string) (path string, err error) { 646 if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) { 647 var ctrlURL *url.URL 648 ctrlURL, err = url.Parse(ctrl) 649 if err != nil { 650 return "", err 651 } 652 if ctrlURL.Port() == "" { 653 ctrlURL.Host = fmt.Sprintf("%s:554", ctrlURL.Hostname()) 654 } 655 return ctrlURL.String(), nil 656 } 657 return ctrl, nil 658 }