github.com/cnotch/ipchub@v1.1.0/service/rtsp/pull_client.go (about) 1 // Copyright (c) 2019,CAOHONGJU All rights reserved. 2 // Use of this source code is governed by a MIT-style 3 // license that can be found in the LICENSE file. 4 5 package rtsp 6 7 import ( 8 "crypto/md5" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "net/url" 15 "runtime/debug" 16 "strconv" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "github.com/cnotch/ipchub/config" 23 "github.com/cnotch/ipchub/media" 24 "github.com/cnotch/ipchub/network/socket/buffered" 25 "github.com/cnotch/ipchub/stats" 26 "github.com/cnotch/ipchub/utils" 27 "github.com/cnotch/xlog" 28 "github.com/pixelbender/go-sdp/sdp" 29 ) 30 31 const ( 32 defaultUserAgent = config.Name + "-rstp-client/1.0" 33 ) 34 35 // PullClient 负责拉流到服务器 36 type PullClient struct { 37 // 打开前设置 38 closed bool 39 url *url.URL 40 userName string 41 password string 42 md5password string 43 path string 44 rtpChannels [rtpChannelCount]int 45 logger *xlog.Logger 46 47 // 添加到流媒体中心后设置 48 stream *media.Stream 49 50 // 打开连接后设置 51 conn *buffered.Conn 52 lockW sync.Mutex 53 realm string 54 nonce string 55 rsession string 56 seq int64 57 58 rawSdp string 59 sdp *sdp.Session 60 aControl string 61 vControl string 62 aCodec string 63 vCodec string 64 } 65 66 // NewPullClient 创建拉流客户端 67 func NewPullClient(localPath, remoteURL string) (*PullClient, error) { 68 // 检查远端路径 69 url, err := url.Parse(remoteURL) 70 if err != nil { 71 return nil, err 72 } 73 if strings.ToLower(url.Scheme) != "rtsp" { 74 return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL) 75 } 76 if strings.ToLower(url.Hostname()) == "" { 77 return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL) 78 } 79 // 如果没有 port,补上默认端口 80 port := url.Port() 81 if len(port) == 0 { 82 url.Host = url.Hostname() + ":554" 83 } 84 85 // 提取用户名和密码 86 var userName, password string 87 if url.User != nil { 88 userName = url.User.Username() 89 password, _ = url.User.Password() 90 url.User = nil 91 } 92 93 // 检查发布路径 94 path := utils.CanonicalPath(localPath) 95 96 if path == "" { 97 path = utils.CanonicalPath(url.Path) 98 } else { 99 _, err := url.Parse("rtsp://localhost" + path) 100 if err != nil { 101 return nil, fmt.Errorf("Path '%s' 不合法", localPath) 102 } 103 } 104 105 client := &PullClient{ 106 closed: true, 107 url: url, 108 userName: userName, 109 password: password, 110 path: path, 111 } 112 113 for i := rtpChannelMin; i < rtpChannelCount; i++ { 114 client.rtpChannels[i] = int(i) 115 } 116 117 client.logger = xlog.L().With(xlog.Fields( 118 xlog.F("path", client.path), 119 xlog.F("rurl", client.url.String()), 120 xlog.F("type", "pull"))) 121 122 return client, nil 123 } 124 125 // Ping 测试网络和服务器 126 func (c *PullClient) Ping() error { 127 if !c.closed { 128 return nil 129 } 130 131 defer func() { 132 c.disconnect() 133 c.conn = nil 134 c.stream = nil 135 }() 136 137 err := c.connect() 138 if err != nil { 139 return err 140 } 141 142 // OPTIONS 尝试握手 143 err = c.requestHandshake() 144 if err != nil { 145 return err 146 } 147 148 // DESCRIBE 获取 sdp,看是否存在指定媒体 149 return c.requestSDP() 150 } 151 152 // Open 打开拉流客户端 153 // 依次发生请求:OPTIONS、DESCRIBE、SETUP、PLAY 154 // 全部成功,启动接收 RTP流 go routine 155 func (c *PullClient) Open() (err error) { 156 if !c.closed { 157 return nil 158 } 159 160 defer func() { 161 if err != nil { // 出现任何错误执行断链操作 162 c.disconnect() 163 c.conn = nil 164 c.stream = nil 165 } 166 }() 167 168 // 连接 169 err = c.connect() 170 if err != nil { 171 return err 172 } 173 174 // 请求握手 175 err = c.requestHandshake() 176 if err != nil { 177 return err 178 } 179 180 // 获取流信息 181 err = c.requestSDP() 182 if err != nil { 183 return err 184 } 185 186 // 设置通讯通道 187 err = c.requestSetup() 188 if err != nil { 189 return err 190 } 191 192 // 请求播放 193 err = c.requestPlay() 194 if err != nil { 195 return err 196 } 197 198 return err 199 } 200 201 // Close 关闭客户端 202 func (c *PullClient) Close() error { 203 c.disconnect() 204 return nil 205 } 206 207 func (c *PullClient) requestHandshake() (err error) { 208 // 使用 OPTIONS 尝试握手 209 r := c.newRequest(MethodOptions, c.url) 210 r.Header.Set(FieldRequire, "implicit-play") 211 _, err = c.requestWithResponse(r) 212 return err 213 } 214 215 func (c *PullClient) requestSDP() (err error) { 216 // DESCRIBE 获取 sdp 217 r := c.newRequest(MethodDescribe, c.url) 218 r.Header.Set(FieldAccept, "application/sdp") 219 resp, err := c.requestWithResponse(r) 220 if err != nil { 221 return err 222 } 223 224 // 解析 225 c.rawSdp = resp.Body 226 c.sdp, err = sdp.ParseString(c.rawSdp) 227 if err != nil { 228 return err 229 } 230 231 for _, media := range c.sdp.Media { 232 switch media.Type { 233 case "video": 234 c.vControl = media.Attributes.Get("control") 235 c.vCodec = media.Format[0].Name 236 237 case "audio": 238 c.aControl = media.Attributes.Get("control") 239 c.aCodec = media.Format[0].Name 240 } 241 } 242 return err 243 } 244 245 func (c *PullClient) requestSetup() (err error) { 246 var respVS, respAS *Response 247 // 视频通道设置 248 if len(c.vControl) > 0 { 249 var setupURL *url.URL 250 setupURL, err = c.getSetupURL(c.vControl) 251 252 r := c.newRequest(MethodSetup, setupURL) 253 r.Header.Set(FieldTransport, 254 fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelVideo], c.rtpChannels[ChannelVideoControl])) 255 respVS, err = c.requestWithResponse(r) 256 if err != nil { 257 return err 258 } 259 } 260 261 // 音频通道设置 262 if len(c.aControl) > 0 { 263 var setupURL *url.URL 264 setupURL, err = c.getSetupURL(c.aControl) 265 266 r := c.newRequest(MethodSetup, setupURL) 267 r.Header.Set(FieldTransport, 268 fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelAudio], c.rtpChannels[ChannelAudioControl])) 269 270 respAS, err = c.requestWithResponse(r) 271 if err != nil { 272 return err 273 } 274 } 275 _ = respVS 276 _ = respAS 277 return 278 } 279 280 func (c *PullClient) requestPlay() (err error) { 281 r := c.newRequest(MethodPlay, c.url) 282 283 resp, err := c.requestWithResponse(r) 284 if err != nil { 285 return err 286 } 287 _ = resp 288 mproxy := &multicastProxy{ 289 path: c.path, 290 bufferSize: config.NetBufferSize(), 291 multicastIP: utils.Multicast.NextIP(), // 设置组播IP 292 ttl: config.MulticastTTL(), 293 logger: c.logger, 294 } 295 296 for i := rtpChannelMin; i < rtpChannelCount; i++ { 297 mproxy.ports[i] = utils.Multicast.NextPort() 298 } 299 300 c.stream = media.NewStream(c.path, c.rawSdp, 301 media.Attr("addr", c.url.String()), 302 media.Multicast(mproxy)) 303 go c.playStream() 304 305 return nil 306 } 307 308 func (c *PullClient) playStream() { 309 defer func() { 310 if r := recover(); r != nil { 311 c.logger.Errorf("pull stream panic; %v \n %s", r, debug.Stack()) 312 } 313 314 stats.RtspConns.Release() // 减少RTSP连接计数 315 media.Unregist(c.stream) // 从媒体中心取消注册 316 c.disconnect() // 确保网络关闭 317 c.conn = nil // 通知GC,尽早释放资源 318 c.stream = nil 319 c.logger.Infof("close pull stream") 320 }() 321 322 c.logger.Infof("open pull stream") 323 media.Regist(c.stream) // 向媒体中心注册流 324 stats.RtspConns.Add() // 增加一个 RTSP 连接计数 325 326 lastHeartbeat := time.Now() 327 reader := c.conn.Reader() 328 heartbeatInterval := config.NetHeartbeatInterval() 329 timeout := config.NetTimeout() 330 331 for !c.closed { 332 deadLine := time.Time{} 333 if timeout > 0 { 334 deadLine = time.Now().Add(timeout) 335 } 336 if err := c.conn.SetReadDeadline(deadLine); err != nil { 337 c.logger.Error(err.Error()) 338 break 339 } 340 341 err := receive(c.logger, reader, c.rtpChannels[:], c) 342 if err != nil { 343 if err == io.EOF { // 如果对方断开 344 c.logger.Warn("The remote RTSP server is actively disconnected.") 345 } else if !c.closed { // 如果非主动关闭 346 c.logger.Error(err.Error()) 347 } 348 break 349 } 350 351 if heartbeatInterval > 0 && time.Now().Sub(lastHeartbeat) > heartbeatInterval { 352 lastHeartbeat = time.Now() 353 // 心跳包 354 r := c.newRequest(MethodOptions, c.url) 355 err := c.request(r) 356 if err != nil { 357 c.logger.Error(err.Error()) 358 break 359 } 360 } 361 } 362 reader = nil 363 } 364 365 func (c *PullClient) onPack(p *RTPPack) error { 366 return c.stream.WriteRtpPacket(p) 367 } 368 369 func (c *PullClient) onRequest(r *Request) (err error) { 370 // 只处理 Options 方法 371 switch r.Method { 372 case MethodOptions: 373 resp := &Response{ 374 StatusCode: 200, 375 Header: r.Header, 376 } 377 resp.Header.Del(FieldUserAgent) 378 resp.Header.Set(FieldPublic, MethodOptions) 379 err = c.response(resp) 380 if err != nil { 381 return err 382 } 383 default: 384 resp := &Response{ 385 StatusCode: StatusMethodNotAllowed, 386 Header: r.Header, 387 } 388 resp.Header.Del(FieldUserAgent) 389 err = c.response(resp) 390 if err != nil { 391 return err 392 } 393 } 394 return nil 395 } 396 397 func (c *PullClient) onResponse(resp *Response) (err error) { 398 // 忽略 399 return 400 } 401 402 func (c *PullClient) getSetupURL(ctrl string) (setupURL *url.URL, err error) { 403 if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) { 404 return url.Parse(ctrl) 405 } 406 407 setupURL = new(url.URL) 408 *setupURL = *c.url 409 if setupURL.Path[len(setupURL.Path)-1] == '/' { 410 setupURL.Path = setupURL.Path + ctrl 411 } else { 412 setupURL.Path = setupURL.Path + "/" + ctrl 413 } 414 415 return 416 } 417 418 func (c *PullClient) newRequest(method string, url *url.URL) *Request { 419 r := &Request{ 420 Method: method, 421 Header: make(Header), 422 } 423 424 r.URL = url 425 if url == nil { 426 r.URL = c.url 427 } 428 429 r.Header.Set(FieldUserAgent, defaultUserAgent) 430 r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10)) 431 if len(c.rsession) > 0 { 432 r.Header.Set(FieldSession, c.rsession) 433 } 434 435 // 和安全相关,已经收到安全作用域信息 436 if len(c.realm) > 0 { 437 pw := c.password 438 if len(c.md5password) > 0 { 439 pw = c.md5password 440 } 441 442 if len(c.nonce) > 0 { 443 // Digest 认证 444 r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw) 445 } else { 446 // Basic 认证 447 r.SetBasicAuth(c.userName, pw) 448 } 449 } 450 451 return r 452 } 453 454 func (c *PullClient) receiveResponse() (resp *Response, err error) { 455 resp, err = ReadResponse(c.conn.Reader()) 456 if err != nil { 457 return nil, err 458 } 459 460 if c.logger.LevelEnabled(xlog.DebugLevel) { 461 c.logger.Debugf("<<<===\r\n%s", strings.TrimSpace(resp.String())) 462 } 463 464 return 465 } 466 467 func (c *PullClient) requestWithResponse(r *Request) (*Response, error) { 468 err := c.request(r) 469 if err != nil { 470 return nil, err 471 } 472 473 resp, err := c.receiveResponse() 474 if err != nil { 475 return nil, err 476 } 477 478 // 保存 session 479 c.rsession = resp.Header.Get(FieldSession) 480 481 // 如果需要安全信息,增加安全信息并再次请求 482 if resp.StatusCode == StatusUnauthorized { 483 484 if len(c.userName) == 0 { 485 return resp, errors.New("require username and password") 486 } 487 488 pw := c.password 489 auth := resp.Header.Get(FieldWWWAuthenticate) 490 if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) { 491 ok := false 492 c.realm, c.nonce, ok = resp.DigestAuth() 493 if !ok { 494 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 495 } 496 497 r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw) 498 } else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) { 499 ok := false 500 c.realm, ok = resp.BasicAuth() 501 if !ok { 502 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 503 } 504 r.SetBasicAuth(c.userName, pw) 505 } else { 506 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 507 } 508 509 // 修改请求序号 510 r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10)) 511 512 err := c.request(r) 513 if err != nil { 514 return nil, err 515 } 516 517 resp, err = c.receiveResponse() 518 if err != nil { 519 return nil, err 520 } 521 522 // 保存 session 523 c.rsession = resp.Header.Get(FieldSession) 524 525 // TODO: 代码臃肿,需要优化 526 // 再试一次 password md5的情况 527 if resp.StatusCode == StatusUnauthorized { 528 md5Digest := md5.Sum([]byte(c.password)) 529 c.md5password = hex.EncodeToString(md5Digest[:]) 530 531 pw := c.md5password 532 auth := resp.Header.Get(FieldWWWAuthenticate) 533 if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) { 534 ok := false 535 c.realm, c.nonce, ok = resp.DigestAuth() 536 if !ok { 537 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 538 } 539 540 r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw) 541 } else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) { 542 ok := false 543 c.realm, ok = resp.BasicAuth() 544 if !ok { 545 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 546 } 547 r.SetBasicAuth(c.userName, pw) 548 } else { 549 return resp, fmt.Errorf("WWW-Authenticate, %s", auth) 550 } 551 552 // 修改请求序号 553 r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10)) 554 555 err := c.request(r) 556 if err != nil { 557 return nil, err 558 } 559 560 resp, err = c.receiveResponse() 561 if err != nil { 562 return nil, err 563 } 564 565 // 保存 session 566 c.rsession = resp.Header.Get(FieldSession) 567 } 568 } 569 570 if !(resp.StatusCode >= 200 && resp.StatusCode <= 300) { 571 return resp, errors.New(resp.Status) 572 } 573 574 return resp, nil 575 } 576 577 func (c *PullClient) request(req *Request) error { 578 c.lockW.Lock() 579 err := req.Write(c.conn) 580 if err == nil { 581 _, err = c.conn.Flush() 582 } 583 c.lockW.Unlock() 584 585 if err != nil { 586 c.logger.Errorf("send request error = %v", err) 587 return err 588 } 589 590 if c.logger.LevelEnabled(xlog.DebugLevel) { 591 c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(req.String())) 592 } 593 return err 594 } 595 596 func (c *PullClient) response(resp *Response) error { 597 c.lockW.Lock() 598 err := resp.Write(c.conn) 599 if err == nil { 600 _, err = c.conn.Flush() 601 } 602 c.lockW.Unlock() 603 604 if err != nil { 605 c.logger.Errorf("send response error = %v", err) 606 return err 607 } 608 609 if c.logger.LevelEnabled(xlog.DebugLevel) { 610 c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String())) 611 } 612 return nil 613 } 614 615 func (c *PullClient) connect() error { 616 // 连接超时要更短 617 timeout := time.Duration(int64(config.NetTimeout()) / 3) 618 conn, err := net.DialTimeout("tcp", c.url.Host, timeout) 619 if err != nil { 620 c.logger.Errorf("connet remote server fail,err = %v", err) 621 return err 622 } 623 624 c.closed = false // 已经连接 625 c.conn = buffered.NewConn(conn, 626 buffered.FlushRate(config.NetFlushRate()), 627 buffered.BufferSize(config.NetBufferSize())) 628 629 c.logger.Infof("connect remote server success") 630 return nil 631 } 632 633 func (c *PullClient) disconnect() { 634 if c.closed { 635 return 636 } 637 638 c.closed = true 639 640 c.logger.Info("disconnec from remote server") 641 if c.conn != nil { 642 c.conn.Close() 643 } 644 645 c.rsession = "" 646 atomic.StoreInt64(&c.seq, 0) 647 c.realm = "" 648 c.sdp = nil 649 c.aControl = "" 650 c.vControl = "" 651 c.aCodec = "" 652 c.vCodec = "" 653 }