github.com/cnotch/ipchub@v1.1.0/service/rtsp/pull_client.go (about)

     1  // Copyright (c) 2019,CAOHONGJU All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rtsp
     6  
     7  import (
     8  	"crypto/md5"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/url"
    15  	"runtime/debug"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/cnotch/ipchub/config"
    23  	"github.com/cnotch/ipchub/media"
    24  	"github.com/cnotch/ipchub/network/socket/buffered"
    25  	"github.com/cnotch/ipchub/stats"
    26  	"github.com/cnotch/ipchub/utils"
    27  	"github.com/cnotch/xlog"
    28  	"github.com/pixelbender/go-sdp/sdp"
    29  )
    30  
    31  const (
    32  	defaultUserAgent = config.Name + "-rstp-client/1.0"
    33  )
    34  
    35  // PullClient 负责拉流到服务器
    36  type PullClient struct {
    37  	// 打开前设置
    38  	closed      bool
    39  	url         *url.URL
    40  	userName    string
    41  	password    string
    42  	md5password string
    43  	path        string
    44  	rtpChannels [rtpChannelCount]int
    45  	logger      *xlog.Logger
    46  
    47  	// 添加到流媒体中心后设置
    48  	stream *media.Stream
    49  
    50  	// 打开连接后设置
    51  	conn     *buffered.Conn
    52  	lockW    sync.Mutex
    53  	realm    string
    54  	nonce    string
    55  	rsession string
    56  	seq      int64
    57  
    58  	rawSdp   string
    59  	sdp      *sdp.Session
    60  	aControl string
    61  	vControl string
    62  	aCodec   string
    63  	vCodec   string
    64  }
    65  
    66  // NewPullClient 创建拉流客户端
    67  func NewPullClient(localPath, remoteURL string) (*PullClient, error) {
    68  	// 检查远端路径
    69  	url, err := url.Parse(remoteURL)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	if strings.ToLower(url.Scheme) != "rtsp" {
    74  		return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL)
    75  	}
    76  	if strings.ToLower(url.Hostname()) == "" {
    77  		return nil, fmt.Errorf("RemoteURL '%s' is not RTSP url", remoteURL)
    78  	}
    79  	// 如果没有 port,补上默认端口
    80  	port := url.Port()
    81  	if len(port) == 0 {
    82  		url.Host = url.Hostname() + ":554"
    83  	}
    84  
    85  	// 提取用户名和密码
    86  	var userName, password string
    87  	if url.User != nil {
    88  		userName = url.User.Username()
    89  		password, _ = url.User.Password()
    90  		url.User = nil
    91  	}
    92  
    93  	// 检查发布路径
    94  	path := utils.CanonicalPath(localPath)
    95  
    96  	if path == "" {
    97  		path = utils.CanonicalPath(url.Path)
    98  	} else {
    99  		_, err := url.Parse("rtsp://localhost" + path)
   100  		if err != nil {
   101  			return nil, fmt.Errorf("Path '%s' 不合法", localPath)
   102  		}
   103  	}
   104  
   105  	client := &PullClient{
   106  		closed:   true,
   107  		url:      url,
   108  		userName: userName,
   109  		password: password,
   110  		path:     path,
   111  	}
   112  
   113  	for i := rtpChannelMin; i < rtpChannelCount; i++ {
   114  		client.rtpChannels[i] = int(i)
   115  	}
   116  
   117  	client.logger = xlog.L().With(xlog.Fields(
   118  		xlog.F("path", client.path),
   119  		xlog.F("rurl", client.url.String()),
   120  		xlog.F("type", "pull")))
   121  
   122  	return client, nil
   123  }
   124  
   125  // Ping 测试网络和服务器
   126  func (c *PullClient) Ping() error {
   127  	if !c.closed {
   128  		return nil
   129  	}
   130  
   131  	defer func() {
   132  		c.disconnect()
   133  		c.conn = nil
   134  		c.stream = nil
   135  	}()
   136  
   137  	err := c.connect()
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	// OPTIONS 尝试握手
   143  	err = c.requestHandshake()
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	// DESCRIBE 获取 sdp,看是否存在指定媒体
   149  	return c.requestSDP()
   150  }
   151  
   152  // Open 打开拉流客户端
   153  // 依次发生请求:OPTIONS、DESCRIBE、SETUP、PLAY
   154  // 全部成功,启动接收 RTP流 go routine
   155  func (c *PullClient) Open() (err error) {
   156  	if !c.closed {
   157  		return nil
   158  	}
   159  
   160  	defer func() {
   161  		if err != nil { // 出现任何错误执行断链操作
   162  			c.disconnect()
   163  			c.conn = nil
   164  			c.stream = nil
   165  		}
   166  	}()
   167  
   168  	// 连接
   169  	err = c.connect()
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	// 请求握手
   175  	err = c.requestHandshake()
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	// 获取流信息
   181  	err = c.requestSDP()
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	// 设置通讯通道
   187  	err = c.requestSetup()
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	// 请求播放
   193  	err = c.requestPlay()
   194  	if err != nil {
   195  		return err
   196  	}
   197  
   198  	return err
   199  }
   200  
   201  // Close 关闭客户端
   202  func (c *PullClient) Close() error {
   203  	c.disconnect()
   204  	return nil
   205  }
   206  
   207  func (c *PullClient) requestHandshake() (err error) {
   208  	// 使用 OPTIONS 尝试握手
   209  	r := c.newRequest(MethodOptions, c.url)
   210  	r.Header.Set(FieldRequire, "implicit-play")
   211  	_, err = c.requestWithResponse(r)
   212  	return err
   213  }
   214  
   215  func (c *PullClient) requestSDP() (err error) {
   216  	// DESCRIBE 获取 sdp
   217  	r := c.newRequest(MethodDescribe, c.url)
   218  	r.Header.Set(FieldAccept, "application/sdp")
   219  	resp, err := c.requestWithResponse(r)
   220  	if err != nil {
   221  		return err
   222  	}
   223  
   224  	// 解析
   225  	c.rawSdp = resp.Body
   226  	c.sdp, err = sdp.ParseString(c.rawSdp)
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	for _, media := range c.sdp.Media {
   232  		switch media.Type {
   233  		case "video":
   234  			c.vControl = media.Attributes.Get("control")
   235  			c.vCodec = media.Format[0].Name
   236  
   237  		case "audio":
   238  			c.aControl = media.Attributes.Get("control")
   239  			c.aCodec = media.Format[0].Name
   240  		}
   241  	}
   242  	return err
   243  }
   244  
   245  func (c *PullClient) requestSetup() (err error) {
   246  	var respVS, respAS *Response
   247  	// 视频通道设置
   248  	if len(c.vControl) > 0 {
   249  		var setupURL *url.URL
   250  		setupURL, err = c.getSetupURL(c.vControl)
   251  
   252  		r := c.newRequest(MethodSetup, setupURL)
   253  		r.Header.Set(FieldTransport,
   254  			fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelVideo], c.rtpChannels[ChannelVideoControl]))
   255  		respVS, err = c.requestWithResponse(r)
   256  		if err != nil {
   257  			return err
   258  		}
   259  	}
   260  
   261  	// 音频通道设置
   262  	if len(c.aControl) > 0 {
   263  		var setupURL *url.URL
   264  		setupURL, err = c.getSetupURL(c.aControl)
   265  
   266  		r := c.newRequest(MethodSetup, setupURL)
   267  		r.Header.Set(FieldTransport,
   268  			fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d", c.rtpChannels[ChannelAudio], c.rtpChannels[ChannelAudioControl]))
   269  
   270  		respAS, err = c.requestWithResponse(r)
   271  		if err != nil {
   272  			return err
   273  		}
   274  	}
   275  	_ = respVS
   276  	_ = respAS
   277  	return
   278  }
   279  
   280  func (c *PullClient) requestPlay() (err error) {
   281  	r := c.newRequest(MethodPlay, c.url)
   282  
   283  	resp, err := c.requestWithResponse(r)
   284  	if err != nil {
   285  		return err
   286  	}
   287  	_ = resp
   288  	mproxy := &multicastProxy{
   289  		path:        c.path,
   290  		bufferSize:  config.NetBufferSize(),
   291  		multicastIP: utils.Multicast.NextIP(), // 设置组播IP
   292  		ttl:         config.MulticastTTL(),
   293  		logger:      c.logger,
   294  	}
   295  
   296  	for i := rtpChannelMin; i < rtpChannelCount; i++ {
   297  		mproxy.ports[i] = utils.Multicast.NextPort()
   298  	}
   299  
   300  	c.stream = media.NewStream(c.path, c.rawSdp,
   301  		media.Attr("addr", c.url.String()),
   302  		media.Multicast(mproxy))
   303  	go c.playStream()
   304  
   305  	return nil
   306  }
   307  
   308  func (c *PullClient) playStream() {
   309  	defer func() {
   310  		if r := recover(); r != nil {
   311  			c.logger.Errorf("pull stream panic; %v \n %s", r, debug.Stack())
   312  		}
   313  
   314  		stats.RtspConns.Release() // 减少RTSP连接计数
   315  		media.Unregist(c.stream)  // 从媒体中心取消注册
   316  		c.disconnect()            // 确保网络关闭
   317  		c.conn = nil              // 通知GC,尽早释放资源
   318  		c.stream = nil
   319  		c.logger.Infof("close pull stream")
   320  	}()
   321  
   322  	c.logger.Infof("open pull stream")
   323  	media.Regist(c.stream) // 向媒体中心注册流
   324  	stats.RtspConns.Add()  // 增加一个 RTSP 连接计数
   325  
   326  	lastHeartbeat := time.Now()
   327  	reader := c.conn.Reader()
   328  	heartbeatInterval := config.NetHeartbeatInterval()
   329  	timeout := config.NetTimeout()
   330  
   331  	for !c.closed {
   332  		deadLine := time.Time{}
   333  		if timeout > 0 {
   334  			deadLine = time.Now().Add(timeout)
   335  		}
   336  		if err := c.conn.SetReadDeadline(deadLine); err != nil {
   337  			c.logger.Error(err.Error())
   338  			break
   339  		}
   340  
   341  		err := receive(c.logger, reader, c.rtpChannels[:], c)
   342  		if err != nil {
   343  			if err == io.EOF { // 如果对方断开
   344  				c.logger.Warn("The remote RTSP server is actively disconnected.")
   345  			} else if !c.closed { // 如果非主动关闭
   346  				c.logger.Error(err.Error())
   347  			}
   348  			break
   349  		}
   350  
   351  		if heartbeatInterval > 0 && time.Now().Sub(lastHeartbeat) > heartbeatInterval {
   352  			lastHeartbeat = time.Now()
   353  			// 心跳包
   354  			r := c.newRequest(MethodOptions, c.url)
   355  			err := c.request(r)
   356  			if err != nil {
   357  				c.logger.Error(err.Error())
   358  				break
   359  			}
   360  		}
   361  	}
   362  	reader = nil
   363  }
   364  
   365  func (c *PullClient) onPack(p *RTPPack) error {
   366  	return c.stream.WriteRtpPacket(p)
   367  }
   368  
   369  func (c *PullClient) onRequest(r *Request) (err error) {
   370  	// 只处理 Options 方法
   371  	switch r.Method {
   372  	case MethodOptions:
   373  		resp := &Response{
   374  			StatusCode: 200,
   375  			Header:     r.Header,
   376  		}
   377  		resp.Header.Del(FieldUserAgent)
   378  		resp.Header.Set(FieldPublic, MethodOptions)
   379  		err = c.response(resp)
   380  		if err != nil {
   381  			return err
   382  		}
   383  	default:
   384  		resp := &Response{
   385  			StatusCode: StatusMethodNotAllowed,
   386  			Header:     r.Header,
   387  		}
   388  		resp.Header.Del(FieldUserAgent)
   389  		err = c.response(resp)
   390  		if err != nil {
   391  			return err
   392  		}
   393  	}
   394  	return nil
   395  }
   396  
   397  func (c *PullClient) onResponse(resp *Response) (err error) {
   398  	// 忽略
   399  	return
   400  }
   401  
   402  func (c *PullClient) getSetupURL(ctrl string) (setupURL *url.URL, err error) {
   403  	if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) {
   404  		return url.Parse(ctrl)
   405  	}
   406  
   407  	setupURL = new(url.URL)
   408  	*setupURL = *c.url
   409  	if setupURL.Path[len(setupURL.Path)-1] == '/' {
   410  		setupURL.Path = setupURL.Path + ctrl
   411  	} else {
   412  		setupURL.Path = setupURL.Path + "/" + ctrl
   413  	}
   414  
   415  	return
   416  }
   417  
   418  func (c *PullClient) newRequest(method string, url *url.URL) *Request {
   419  	r := &Request{
   420  		Method: method,
   421  		Header: make(Header),
   422  	}
   423  
   424  	r.URL = url
   425  	if url == nil {
   426  		r.URL = c.url
   427  	}
   428  
   429  	r.Header.Set(FieldUserAgent, defaultUserAgent)
   430  	r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
   431  	if len(c.rsession) > 0 {
   432  		r.Header.Set(FieldSession, c.rsession)
   433  	}
   434  
   435  	// 和安全相关,已经收到安全作用域信息
   436  	if len(c.realm) > 0 {
   437  		pw := c.password
   438  		if len(c.md5password) > 0 {
   439  			pw = c.md5password
   440  		}
   441  
   442  		if len(c.nonce) > 0 {
   443  			// Digest 认证
   444  			r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
   445  		} else {
   446  			// Basic 认证
   447  			r.SetBasicAuth(c.userName, pw)
   448  		}
   449  	}
   450  
   451  	return r
   452  }
   453  
   454  func (c *PullClient) receiveResponse() (resp *Response, err error) {
   455  	resp, err = ReadResponse(c.conn.Reader())
   456  	if err != nil {
   457  		return nil, err
   458  	}
   459  
   460  	if c.logger.LevelEnabled(xlog.DebugLevel) {
   461  		c.logger.Debugf("<<<===\r\n%s", strings.TrimSpace(resp.String()))
   462  	}
   463  
   464  	return
   465  }
   466  
   467  func (c *PullClient) requestWithResponse(r *Request) (*Response, error) {
   468  	err := c.request(r)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	resp, err := c.receiveResponse()
   474  	if err != nil {
   475  		return nil, err
   476  	}
   477  
   478  	// 保存 session
   479  	c.rsession = resp.Header.Get(FieldSession)
   480  
   481  	// 如果需要安全信息,增加安全信息并再次请求
   482  	if resp.StatusCode == StatusUnauthorized {
   483  
   484  		if len(c.userName) == 0 {
   485  			return resp, errors.New("require username and password")
   486  		}
   487  
   488  		pw := c.password
   489  		auth := resp.Header.Get(FieldWWWAuthenticate)
   490  		if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) {
   491  			ok := false
   492  			c.realm, c.nonce, ok = resp.DigestAuth()
   493  			if !ok {
   494  				return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   495  			}
   496  
   497  			r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
   498  		} else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) {
   499  			ok := false
   500  			c.realm, ok = resp.BasicAuth()
   501  			if !ok {
   502  				return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   503  			}
   504  			r.SetBasicAuth(c.userName, pw)
   505  		} else {
   506  			return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   507  		}
   508  
   509  		// 修改请求序号
   510  		r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
   511  
   512  		err := c.request(r)
   513  		if err != nil {
   514  			return nil, err
   515  		}
   516  
   517  		resp, err = c.receiveResponse()
   518  		if err != nil {
   519  			return nil, err
   520  		}
   521  
   522  		// 保存 session
   523  		c.rsession = resp.Header.Get(FieldSession)
   524  
   525  		// TODO: 代码臃肿,需要优化
   526  		// 再试一次 password md5的情况
   527  		if resp.StatusCode == StatusUnauthorized {
   528  			md5Digest := md5.Sum([]byte(c.password))
   529  			c.md5password = hex.EncodeToString(md5Digest[:])
   530  
   531  			pw := c.md5password
   532  			auth := resp.Header.Get(FieldWWWAuthenticate)
   533  			if len(auth) > len(digestAuthPrefix) && strings.EqualFold(auth[:len(digestAuthPrefix)], digestAuthPrefix) {
   534  				ok := false
   535  				c.realm, c.nonce, ok = resp.DigestAuth()
   536  				if !ok {
   537  					return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   538  				}
   539  
   540  				r.SetDigestAuth(r.URL, c.realm, c.nonce, c.userName, pw)
   541  			} else if len(auth) > len(basicAuthPrefix) && strings.EqualFold(auth[:len(basicAuthPrefix)], basicAuthPrefix) {
   542  				ok := false
   543  				c.realm, ok = resp.BasicAuth()
   544  				if !ok {
   545  					return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   546  				}
   547  				r.SetBasicAuth(c.userName, pw)
   548  			} else {
   549  				return resp, fmt.Errorf("WWW-Authenticate, %s", auth)
   550  			}
   551  
   552  			// 修改请求序号
   553  			r.Header.Set(FieldCSeq, strconv.FormatInt(atomic.AddInt64(&c.seq, 1), 10))
   554  
   555  			err := c.request(r)
   556  			if err != nil {
   557  				return nil, err
   558  			}
   559  
   560  			resp, err = c.receiveResponse()
   561  			if err != nil {
   562  				return nil, err
   563  			}
   564  
   565  			// 保存 session
   566  			c.rsession = resp.Header.Get(FieldSession)
   567  		}
   568  	}
   569  
   570  	if !(resp.StatusCode >= 200 && resp.StatusCode <= 300) {
   571  		return resp, errors.New(resp.Status)
   572  	}
   573  
   574  	return resp, nil
   575  }
   576  
   577  func (c *PullClient) request(req *Request) error {
   578  	c.lockW.Lock()
   579  	err := req.Write(c.conn)
   580  	if err == nil {
   581  		_, err = c.conn.Flush()
   582  	}
   583  	c.lockW.Unlock()
   584  
   585  	if err != nil {
   586  		c.logger.Errorf("send request error = %v", err)
   587  		return err
   588  	}
   589  
   590  	if c.logger.LevelEnabled(xlog.DebugLevel) {
   591  		c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(req.String()))
   592  	}
   593  	return err
   594  }
   595  
   596  func (c *PullClient) response(resp *Response) error {
   597  	c.lockW.Lock()
   598  	err := resp.Write(c.conn)
   599  	if err == nil {
   600  		_, err = c.conn.Flush()
   601  	}
   602  	c.lockW.Unlock()
   603  
   604  	if err != nil {
   605  		c.logger.Errorf("send response error = %v", err)
   606  		return err
   607  	}
   608  
   609  	if c.logger.LevelEnabled(xlog.DebugLevel) {
   610  		c.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String()))
   611  	}
   612  	return nil
   613  }
   614  
   615  func (c *PullClient) connect() error {
   616  	// 连接超时要更短
   617  	timeout := time.Duration(int64(config.NetTimeout()) / 3)
   618  	conn, err := net.DialTimeout("tcp", c.url.Host, timeout)
   619  	if err != nil {
   620  		c.logger.Errorf("connet remote server fail,err = %v", err)
   621  		return err
   622  	}
   623  
   624  	c.closed = false // 已经连接
   625  	c.conn = buffered.NewConn(conn,
   626  		buffered.FlushRate(config.NetFlushRate()),
   627  		buffered.BufferSize(config.NetBufferSize()))
   628  
   629  	c.logger.Infof("connect remote server success")
   630  	return nil
   631  }
   632  
   633  func (c *PullClient) disconnect() {
   634  	if c.closed {
   635  		return
   636  	}
   637  
   638  	c.closed = true
   639  
   640  	c.logger.Info("disconnec from remote server")
   641  	if c.conn != nil {
   642  		c.conn.Close()
   643  	}
   644  
   645  	c.rsession = ""
   646  	atomic.StoreInt64(&c.seq, 0)
   647  	c.realm = ""
   648  	c.sdp = nil
   649  	c.aControl = ""
   650  	c.vControl = ""
   651  	c.aCodec = ""
   652  	c.vCodec = ""
   653  }