github.com/cnotch/ipchub@v1.1.0/service/rtsp/session.go (about)

     1  // Copyright (c) 2019,CAOHONGJU All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rtsp
     6  
import (
	"bytes"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"net"
	"net/url"
	"runtime/debug"
	"strings"
	"sync"
	"time"

	"github.com/cnotch/ipchub/config"
	"github.com/cnotch/ipchub/media"
	"github.com/cnotch/ipchub/network/socket/buffered"
	"github.com/cnotch/ipchub/network/websocket"
	"github.com/cnotch/ipchub/provider/auth"
	"github.com/cnotch/ipchub/provider/security"
	"github.com/cnotch/ipchub/stats"
	"github.com/cnotch/ipchub/utils"
	"github.com/cnotch/xlog"
	"github.com/pixelbender/go-sdp/sdp"
)
    30  
    31  const (
    32  	realm = config.Name
    33  )
    34  
    35  const (
    36  	statusInit = iota
    37  	statusReady
    38  	statusPlaying
    39  	statusRecording
    40  )
    41  
    42  var buffers = sync.Pool{
    43  	New: func() interface{} {
    44  		return bytes.NewBuffer(make([]byte, 0, 1024*2))
    45  	},
    46  }
    47  
// Session is an RTSP session bound to one client connection (plain TCP or
// WebSocket). It is created by newSession and driven by process.
type Session struct {
	// Set at creation time.
	svr      *Server
	logger   *xlog.Logger
	closed   bool          // set by Close; the read loop in process stops once it is true
	lsession string        // local session identifier, sent in every response's Session header
	timeout  time.Duration // per-read deadline used by process; <= 0 means no deadline
	conn     *buffered.Conn
	lockW    sync.Mutex // serializes response writes (see response)

	wsconn websocket.Conn // non-nil when the client connected over WebSocket

	authMode auth.Mode
	nonce    string // digest-auth nonce; regenerated after a failed digest check
	user     *auth.User

	// Set after DESCRIBE or ANNOUNCE.
	url      *url.URL
	path     string
	rawSdp   string
	sdp      *sdp.Session
	aControl string // a=control of the audio media section
	vControl string // a=control of the video media section
	aCodec   string
	vCodec   string
	mode     SessionMode

	// Set after SETUP.
	transport RTPTransport

	// Set once media transfer starts.
	status   int            // session state (statusInit .. statusRecording)
	stream   mediaStream    // media stream; incoming packs are written to it (see onPack)
	consumer media.Consumer // consumer that receives packs passed to Consume
}
    84  
    85  func newSession(svr *Server, conn net.Conn) *Session {
    86  
    87  	session := &Session{
    88  		svr:      svr,
    89  		lsession: security.NewID().Base64(),
    90  		timeout:  config.NetTimeout(),
    91  		conn: buffered.NewConn(conn,
    92  			buffered.FlushRate(config.NetFlushRate()),
    93  			buffered.BufferSize(config.NetBufferSize())),
    94  		mode: UnknownSession,
    95  		transport: RTPTransport{
    96  			Mode: PlaySession, // 默认为播放
    97  			Type: RTPUnknownTrans,
    98  		},
    99  		authMode: config.RtspAuthMode(),
   100  		nonce:    security.NewID().MD5(),
   101  		status:   statusInit,
   102  		stream:   defaultStream,
   103  		consumer: defaultConsumer,
   104  	}
   105  
   106  	if wsc, ok := conn.(websocket.Conn); ok { // 如果是WebSocket,有http进行验证
   107  		session.authMode = auth.NoneAuth
   108  		session.wsconn = wsc
   109  		session.path = wsc.Path()
   110  		session.user = auth.Get(wsc.Username())
   111  	}
   112  
   113  	// ipaddr, _ := address.Parse(conn.RemoteAddr().String(), 80)
   114  	// // 如果是本机IP,不验证;以便ffmpeg本机rtsp->rtmp
   115  	// if network.IsLocalhostIP(ipaddr.IP) {
   116  	// 	session.authMode = auth.NoneAuth
   117  	// }
   118  
   119  	for i := rtpChannelMin; i < rtpChannelCount; i++ {
   120  		session.transport.Channels[i] = -1
   121  		session.transport.ClientPorts[i] = -1
   122  	}
   123  	session.logger = svr.logger.With(xlog.Fields(
   124  		xlog.F("session", session.lsession)))
   125  
   126  	return session
   127  }
   128  
   129  // Addr Session地址
   130  func (s *Session) Addr() string {
   131  	return s.conn.RemoteAddr().String()
   132  }
   133  
   134  // Consume 消费媒体包
   135  func (s *Session) Consume(p Pack) {
   136  	s.consumer.Consume(p)
   137  }
   138  
   139  // Close 关闭会话
   140  func (s *Session) Close() error {
   141  	if s.closed {
   142  		return nil
   143  	}
   144  
   145  	s.closed = true
   146  	s.conn.Close()
   147  	return nil
   148  }
   149  
   150  func (s *Session) process() {
   151  	defer func() {
   152  		if r := recover(); r != nil {
   153  			s.logger.Errorf("session panic; %v \n %s", r, debug.Stack())
   154  		}
   155  
   156  		stats.RtspConns.Release()
   157  		s.Close()
   158  		s.consumer.Close()
   159  		s.stream.Close()
   160  
   161  		// 重置到初始状态
   162  		s.conn = nil
   163  		s.status = statusInit
   164  		s.stream = defaultStream
   165  		s.consumer = defaultConsumer
   166  		s.logger.Infof("close rtsp session")
   167  	}()
   168  
   169  	s.logger.Infof("open rtsp session")
   170  	stats.RtspConns.Add() // 增加一个 RTSP 连接计数
   171  	reader := s.conn.Reader()
   172  
   173  	for !s.closed {
   174  		deadLine := time.Time{}
   175  		if s.timeout > 0 {
   176  			deadLine = time.Now().Add(s.timeout)
   177  		}
   178  		if err := s.conn.SetReadDeadline(deadLine); err != nil {
   179  			s.logger.Error(err.Error())
   180  			break
   181  		}
   182  
   183  		err := receive(s.logger, reader, s.transport.Channels[:], s)
   184  		if err != nil {
   185  			if err == io.EOF { // 如果客户端断开提醒
   186  				s.logger.Warn("The client actively disconnects")
   187  			} else if !s.closed { // 如果主动关闭,不提示
   188  				s.logger.Error(err.Error())
   189  			}
   190  			break
   191  		}
   192  	}
   193  }
   194  
   195  // receiveHandler.onPack
   196  func (s *Session) onPack(pack *RTPPack) (err error) {
   197  	return s.stream.WritePacket(pack)
   198  }
   199  
   200  // receiveHandler.onResponse
   201  func (s *Session) onResponse(resp *Response) (err error) {
   202  	// 忽略,服务器不会主动发起请求
   203  	return
   204  }
   205  
   206  // receiveHandler.onRequest
   207  func (s *Session) onRequest(req *Request) (err error) {
   208  	resp := s.newResponse(StatusOK, req)
   209  	// 预处理
   210  	continueProcess, err := s.onPreprocess(resp, req)
   211  	if !continueProcess {
   212  		return err
   213  	}
   214  
   215  	switch req.Method {
   216  	case MethodDescribe:
   217  		s.onDescribe(resp, req)
   218  	case MethodAnnounce:
   219  		s.onAnnounce(resp, req)
   220  	case MethodSetup:
   221  		s.onSetup(resp, req)
   222  	case MethodRecord:
   223  		s.onRecord(resp, req)
   224  	case MethodPlay:
   225  		return s.onPlay(resp, req) // play 发送流媒体不在当前 routine,需要先回复
   226  	default:
   227  		// 状态不支持的方法
   228  		resp.StatusCode = StatusMethodNotValidInThisState
   229  	}
   230  
   231  	// 发送响应
   232  	err = s.response(resp)
   233  	return err
   234  }
   235  
   236  func (s *Session) onDescribe(resp *Response, req *Request) {
   237  
   238  	// TODO: 检查 accept 中的类型是否包含 sdp
   239  	s.url = req.URL
   240  	if s.wsconn == nil { // websocket访问的路径有ws://路径表示
   241  		s.path = utils.CanonicalPath(req.URL.Path)
   242  	}
   243  
   244  	stream := media.GetOrCreate(s.path)
   245  	if stream == nil {
   246  		resp.StatusCode = StatusNotFound
   247  		return
   248  	}
   249  
   250  	if !s.checkPermission(auth.PullRight) {
   251  		resp.StatusCode = StatusForbidden
   252  		return
   253  	}
   254  
   255  	// 从流中取 sdp
   256  	sdpRaw := stream.Sdp()
   257  	if len(sdpRaw) == 0 {
   258  		resp.StatusCode = StatusNotFound
   259  		return
   260  	}
   261  	err := s.parseSdp(sdpRaw)
   262  	if err != nil { // TODO:需要更好的处理方式
   263  		resp.StatusCode = StatusNotFound
   264  		return
   265  	}
   266  
   267  	resp.Header.Set(FieldContentType, "application/sdp")
   268  	resp.Body = s.rawSdp
   269  	s.mode = PlaySession // 标记为播放会话
   270  }
   271  
   272  func (s *Session) onAnnounce(resp *Response, req *Request) {
   273  
   274  	// 检查 Content-Type: application/sdp
   275  	if req.Header.Get(FieldContentType) != "application/sdp" {
   276  		resp.StatusCode = StatusBadRequest // TODO:更合适的代码
   277  		return
   278  	}
   279  
   280  	s.url = req.URL
   281  	s.path = utils.CanonicalPath(req.URL.Path)
   282  
   283  	if !s.checkPermission(auth.PushRight) {
   284  		resp.StatusCode = StatusForbidden
   285  		return
   286  	}
   287  
   288  	// 从流中取 sdp
   289  	err := s.parseSdp(req.Body)
   290  	if err != nil {
   291  		resp.StatusCode = StatusBadRequest
   292  		return
   293  	}
   294  
   295  	s.mode = RecordSession // 标记为录像会话
   296  }
   297  
   298  func (s *Session) onSetup(resp *Response, req *Request) {
   299  	// a=control:streamid=1
   300  	// a=control:rtsp://192.168.1.165/trackID=1
   301  	// a=control:?ctype=video
   302  	setupURL := &url.URL{}
   303  	*setupURL = *req.URL
   304  	if setupURL.Port() == "" {
   305  		setupURL.Host = fmt.Sprintf("%s:554", setupURL.Host)
   306  	}
   307  	setupPath := setupURL.String()
   308  
   309  	//setupPath = setupPath[strings.LastIndex(setupPath, "/")+1:]
   310  	vPath, err := getControlPath(s.vControl)
   311  	if err != nil {
   312  		resp.StatusCode = StatusInternalServerError
   313  		resp.Status = "Invalid VControl"
   314  		return
   315  	}
   316  
   317  	aPath, err := getControlPath(s.aControl)
   318  	if err != nil {
   319  		resp.StatusCode = StatusInternalServerError
   320  		resp.Status = "Invalid AControl"
   321  		return
   322  	}
   323  
   324  	ts := req.Header.Get(FieldTransport)
   325  	resp.Header.Set(FieldTransport, ts) // 先回写transport
   326  
   327  	// 检查控制路径
   328  	chindex := -1
   329  	if setupPath == aPath || (aPath != "" && strings.LastIndex(setupPath, aPath) == len(setupPath)-len(aPath)) {
   330  		chindex = int(ChannelAudio)
   331  	} else if setupPath == vPath || (vPath != "" && strings.LastIndex(setupPath, vPath) == len(setupPath)-len(vPath)) {
   332  		chindex = int(ChannelVideo)
   333  	} else { // 找不到被 Setup 的资源
   334  		resp.StatusCode = StatusInternalServerError
   335  		resp.Status = fmt.Sprintf("SETUP Unkown control:%s", setupPath)
   336  		return
   337  	}
   338  
   339  	err = s.transport.ParseTransport(chindex, ts)
   340  	if err != nil {
   341  		resp.StatusCode = StatusInvalidParameter
   342  		resp.Status = err.Error()
   343  		return
   344  	}
   345  
   346  	// 检查和以前的命令是否一致
   347  	if s.mode == UnknownSession {
   348  		s.mode = s.transport.Mode
   349  	}
   350  
   351  	if s.mode != s.transport.Mode {
   352  		resp.StatusCode = StatusInvalidParameter
   353  		if s.mode == PlaySession {
   354  			resp.Status = "Current state can't setup as record"
   355  		} else {
   356  			resp.Status = "Current state can't setup as play"
   357  		}
   358  		return
   359  	}
   360  
   361  	// record 只支持 TCP 单播
   362  	if s.mode == RecordSession {
   363  		// 检查用户权限
   364  		if !s.checkPermission(auth.PushRight) {
   365  			resp.StatusCode = StatusForbidden
   366  			return
   367  		}
   368  
   369  		if s.transport.Type != RTPTCPUnicast {
   370  			resp.StatusCode = StatusUnsupportedTransport
   371  			resp.Status = "when mode = record,only support tcp unicast"
   372  		} else {
   373  			if s.status < statusReady { // 初始状态切换到Ready
   374  				s.status = statusReady
   375  			}
   376  		}
   377  		return
   378  	}
   379  
   380  	// 检查用户权限,播放
   381  	if !s.checkPermission(auth.PullRight) {
   382  		resp.StatusCode = StatusForbidden
   383  		return
   384  	}
   385  
   386  	if s.transport.Type == RTPMulticast { // 需要修改回复的transport
   387  		st := media.GetOrCreate(s.path)
   388  		if st == nil { // 没有找到源
   389  			resp.StatusCode = StatusNotFound
   390  			return
   391  		}
   392  		ma := st.Multicastable()
   393  		if ma == nil { // 不支持组播
   394  			resp.StatusCode = StatusUnsupportedTransport
   395  			return
   396  		}
   397  
   398  		ts = fmt.Sprintf("%s;destination=%s;port=%d-%d;source=%s;ttl=%d",
   399  			ts, ma.MulticastIP(),
   400  			ma.Port(chindex), ma.Port(chindex+1),
   401  			ma.SourceIP(), ma.TTL())
   402  		resp.Header.Set(FieldTransport, ts)
   403  	}
   404  
   405  	if s.status < statusReady { // 初始状态切换到Ready
   406  		s.status = statusReady
   407  	}
   408  }
   409  
   410  func (s *Session) onRecord(resp *Response, req *Request) {
   411  	if s.status == statusRecording {
   412  		return
   413  	}
   414  
   415  	// 传输模式、会话模式判断
   416  	if s.mode != RecordSession || s.transport.Type != RTPTCPUnicast {
   417  		resp.StatusCode = StatusMethodNotValidInThisState
   418  		return
   419  	}
   420  
   421  	if !s.checkPermission(auth.PushRight) {
   422  		resp.StatusCode = StatusForbidden
   423  		return
   424  	}
   425  
   426  	s.asTCPPusher()
   427  	s.status = statusRecording
   428  }
   429  
   430  func (s *Session) onPlay(resp *Response, req *Request) (err error) {
   431  	if s.status == statusPlaying {
   432  		return
   433  	}
   434  
   435  	// 传输模式、会话模式判断
   436  	if s.mode != PlaySession || s.transport.Type == RTPUnknownTrans {
   437  		resp.StatusCode = StatusMethodNotValidInThisState
   438  		return s.response(resp)
   439  	}
   440  
   441  	stream := media.GetOrCreate(s.path)
   442  	if stream == nil {
   443  		resp.StatusCode = StatusNotFound
   444  		return s.response(resp)
   445  	}
   446  
   447  	if !s.checkPermission(auth.PullRight) {
   448  		resp.StatusCode = StatusForbidden
   449  		return s.response(resp)
   450  	}
   451  
   452  	resp.Header.Set(FieldRange, req.Header.Get(FieldRange))
   453  	switch s.transport.Type {
   454  	case RTPTCPUnicast:
   455  		err = s.asTCPConsumer(stream, resp)
   456  	case RTPUDPUnicast:
   457  		err = s.asUDPConsumer(stream, resp)
   458  	default:
   459  		err = s.asMulticastConsumer(stream, resp)
   460  	}
   461  
   462  	if err == nil {
   463  		s.status = statusPlaying
   464  	}
   465  	return
   466  }
   467  
   468  func (s *Session) checkPermission(right auth.AccessRight) bool {
   469  	if s.authMode == auth.NoneAuth {
   470  		return true
   471  	}
   472  
   473  	if s.user == nil {
   474  		return false
   475  	}
   476  
   477  	return s.user.ValidatePermission(s.path, right)
   478  }
   479  
   480  func (s *Session) checkAuth(r *Request) (user *auth.User, err error) {
   481  	switch s.authMode {
   482  	case auth.BasicAuth:
   483  		username, password, has := r.BasicAuth()
   484  		if !has {
   485  			return nil, errors.New("require legal Authorization field")
   486  		}
   487  		user := auth.Get(username)
   488  		if user == nil {
   489  			return nil, errors.New("user not exist")
   490  		}
   491  		err = user.ValidatePassword(password)
   492  		if err != nil {
   493  			return nil, err
   494  		}
   495  		return user, nil
   496  
   497  	case auth.DigestAuth:
   498  		username, response, has := r.DigestAuth()
   499  		if !has {
   500  			return nil, errors.New("require legal Authorization field")
   501  		}
   502  		user := auth.Get(username)
   503  		if user == nil {
   504  			return nil, errors.New("user not exist")
   505  		}
   506  		resp2 := formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.Password)
   507  		if resp2 == response {
   508  			return user, nil
   509  		}
   510  		resp2 = formatDigestAuthResponse(realm, s.nonce, r.Method, r.URL.String(), username, user.PasswordMD5())
   511  		if resp2 == response {
   512  			return user, nil
   513  		}
   514  		s.nonce = security.NewID().MD5()
   515  		return nil, errors.New("require legal Authorization field")
   516  	default: // 无需验证
   517  		return nil, nil
   518  	}
   519  }
   520  
   521  func (s *Session) onPreprocess(resp *Response, req *Request) (continueProcess bool, err error) {
   522  	// Options 方法无需验证,直接回复
   523  	if req.Method == MethodOptions {
   524  		resp.Header.Set(FieldPublic, "DESCRIBE, SETUP, TEARDOWN, PLAY, OPTIONS, ANNOUNCE, RECORD")
   525  		err = s.response(resp)
   526  		return false, err
   527  	}
   528  
   529  	// 关闭请求
   530  	if req.Method == MethodTeardown {
   531  		// 发送响应
   532  		err = s.response(resp)
   533  		s.Close()
   534  		return false, err
   535  	}
   536  
   537  	// 检查状态下的方法
   538  	switch s.status {
   539  	case statusReady:
   540  		continueProcess = req.Method == MethodSetup ||
   541  			req.Method == MethodPlay || req.Method == MethodRecord
   542  	case statusPlaying:
   543  		continueProcess = req.Method == MethodPlay
   544  	case statusRecording:
   545  		continueProcess = req.Method == MethodRecord
   546  	default:
   547  		continueProcess = !(req.Method == MethodPlay || req.Method == MethodRecord)
   548  	}
   549  	if !continueProcess {
   550  		resp.StatusCode = StatusMethodNotValidInThisState
   551  		err = s.response(resp)
   552  		return false, err
   553  	}
   554  
   555  	// 检查认证
   556  	user, err2 := s.checkAuth(req)
   557  	if err2 != nil {
   558  		resp.StatusCode = StatusUnauthorized
   559  		if err2 != nil {
   560  			resp.Status = err2.Error()
   561  		}
   562  		err = s.response(resp)
   563  		return false, err
   564  	}
   565  
   566  	s.user = user
   567  	return true, nil
   568  }
   569  
   570  func (s *Session) response(resp *Response) error {
   571  	s.lockW.Lock()
   572  
   573  	var err error
   574  
   575  	if s.wsconn != nil { // websocket 客户端
   576  		buf := buffers.Get().(*bytes.Buffer)
   577  		buf.Reset()
   578  		defer buffers.Put(buf)
   579  
   580  		err = resp.Write(buf) // 保证写入包的完整性,简化前端分包
   581  		_, err = s.wsconn.Write(buf.Bytes())
   582  	} else {
   583  		err = resp.Write(s.conn)
   584  		if err == nil {
   585  			_, err = s.conn.Flush()
   586  		}
   587  	}
   588  
   589  	s.lockW.Unlock()
   590  
   591  	if err != nil {
   592  		s.logger.Errorf("send response error = %v", err)
   593  		return err
   594  	}
   595  
   596  	if s.logger.LevelEnabled(xlog.DebugLevel) {
   597  		s.logger.Debugf("===>>>\r\n%s", strings.TrimSpace(resp.String()))
   598  	}
   599  
   600  	return nil
   601  }
   602  
   603  func (s *Session) newResponse(code int, req *Request) *Response {
   604  	resp := &Response{
   605  		StatusCode: code,
   606  		Header:     make(Header),
   607  		Request:    req,
   608  	}
   609  
   610  	resp.Header.Set(FieldCSeq, req.Header.Get(FieldCSeq))
   611  	resp.Header.Set(FieldSession, s.lsession)
   612  
   613  	// 根据认证模式增加认证所需的字段
   614  	switch s.authMode {
   615  	case auth.BasicAuth:
   616  		resp.SetBasicAuth(realm)
   617  	case auth.DigestAuth:
   618  		resp.SetDigestAuth(realm, s.nonce)
   619  	}
   620  	return resp
   621  }
   622  
   623  func (s *Session) parseSdp(rawSdp string) (err error) {
   624  	// 从流中取 sdp
   625  	s.rawSdp = rawSdp
   626  	// 解析
   627  	s.sdp, err = sdp.ParseString(s.rawSdp)
   628  	if err != nil {
   629  		return
   630  	}
   631  
   632  	for _, media := range s.sdp.Media {
   633  		switch media.Type {
   634  		case "video":
   635  			s.vControl = media.Attributes.Get("control")
   636  			s.vCodec = media.Format[0].Name
   637  		case "audio":
   638  			s.aControl = media.Attributes.Get("control")
   639  			s.aCodec = media.Format[0].Name
   640  		}
   641  	}
   642  	return
   643  }
   644  
   645  func getControlPath(ctrl string) (path string, err error) {
   646  	if len(ctrl) >= len(rtspURLPrefix) && strings.EqualFold(ctrl[:len(rtspURLPrefix)], rtspURLPrefix) {
   647  		var ctrlURL *url.URL
   648  		ctrlURL, err = url.Parse(ctrl)
   649  		if err != nil {
   650  			return "", err
   651  		}
   652  		if ctrlURL.Port() == "" {
   653  			ctrlURL.Host = fmt.Sprintf("%s:554", ctrlURL.Hostname())
   654  		}
   655  		return ctrlURL.String(), nil
   656  	}
   657  	return ctrl, nil
   658  }