github.com/puellanivis/breton@v0.2.16/lib/mpeg/ts/demux.go

package ts

import (
	"context"
	"io"
	"sync"

	"github.com/pkg/errors"
	"github.com/puellanivis/breton/lib/glog"
	"github.com/puellanivis/breton/lib/io/bufpipe"
	"github.com/puellanivis/breton/lib/mpeg/ts/packet"
	"github.com/puellanivis/breton/lib/mpeg/ts/pes"
	"github.com/puellanivis/breton/lib/mpeg/ts/psi"
)

// Reference glog so the import is retained even when no log calls are present.
var _ = glog.Info

// Demux demultiplexes an MPEG Transport Stream: it takes a composed byte stream
// and decomposes it into its multiple elementary streams.
type Demux struct {
	TransportStream

	src io.Reader

	closed chan struct{}

	mu       sync.Mutex
	programs map[uint16]*bufpipe.Pipe

	complete  chan struct{}
	pending   map[uint16]*bufpipe.Pipe
	pendingWG sync.WaitGroup
}

// NewDemux returns a new Demux that uses the given io.Reader as its source byte stream.
func NewDemux(rd io.Reader, opts ...Option) *Demux {
	d := &Demux{
		src: rd,

		closed: make(chan struct{}),

		programs: make(map[uint16]*bufpipe.Pipe),

		complete: make(chan struct{}),
		pending:  make(map[uint16]*bufpipe.Pipe),
	}

	for _, opt := range opts {
		_ = opt(&d.TransportStream)
	}

	return d
}

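// An illustrative construction sketch (the helper name is an assumption, not
// part of this package): NewDemux only assembles state and forwards any Option
// values to the embedded TransportStream; no bytes are read from src until
// Serve is started.
func exampleNewDemux(ctx context.Context, src io.Reader) (*Demux, <-chan error) {
	d := NewDemux(src)

	// Serve drives the demultiplexing; Reader and ReaderByPID streams only
	// receive data while it is running.
	errch := d.Serve(ctx)

	return d, errch
}
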
const (
	pidPAT  uint16 = 0
	pidNULL uint16 = 0x1FFF
)

func (d *Demux) getPipe(ctx context.Context, pid uint16) (*bufpipe.Pipe, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if _, exists := d.programs[pid]; exists {
		return nil, errors.Errorf("pid 0x%04X is already assigned", pid)
	}

	pipe := d.pending[pid]
	if pipe == nil {
		// We assign a context closer below, so don’t assign it here.
		pipe = bufpipe.New(nil, bufpipe.WithNoAutoFlush())
	}
	delete(d.pending, pid)

	// Here we set the context closer for both paths.
	// This way, if a pending pipe were made in the Serve goroutine,
	// we properly tie it to _this_ context, and not the Serve context.
	pipe.CloseOnContext(ctx)

	d.programs[pid] = pipe
	return pipe, nil
}

// Reader locates the specified stream_id, and returns an io.ReadCloser that corresponds to that stream only.
func (d *Demux) Reader(ctx context.Context, streamID uint16) (io.ReadCloser, error) {
	if streamID == 0 {
		return nil, errors.Errorf("stream_id 0x%04X is invalid", streamID)
	}

	select {
	case <-d.closed:
		return nil, errors.New("Demux is closed")
	default:
	}

	s := &stream{
		ready: make(chan struct{}),
	}

	d.pendingWG.Add(1)
	go func() {
		defer d.pendingWG.Done()
		defer s.makeReady()

		pat := d.GetPAT()

		pmtPID, ok := pat[streamID]
		if !ok {
			s.err = errors.Errorf("no PMT found for stream_id 0x%04X", streamID)
			return
		}

		pmtRD, err := d.ReaderByPID(ctx, pmtPID, false)
		if err != nil {
			s.err = err
			return
		}
		defer pmtRD.Close()

		b := make([]byte, 1024)
		n, err := pmtRD.Read(b)
		if err != nil {
			s.err = err
			return
		}
		b = b[:n]

		if n < 1 {
			s.err = errors.Errorf("zero-length read for pmt on pid 0x%04X", pmtPID)
			return
		}

		tbl, err := psi.Unmarshal(b)
		if err != nil {
			s.err = err
			return
		}

		pmt, ok := tbl.(*psi.PMT)
		if !ok {
			s.err = errors.Errorf("unexpected table on pid 0x%04X: %v", pmtPID, tbl.TableID())
			return
		}

		// Use the first elementary stream listed in the PMT.
		var pid uint16
		for _, es := range pmt.Streams {
			pid = es.PID
			break
		}

		pipe, err := d.getPipe(ctx, pid)
		if err != nil {
			s.err = err
			return
		}

		s.pid = pid
		s.rd = pes.NewReader(pipe)
		s.closer = func() error {
			d.mu.Lock()
			defer d.mu.Unlock()

			return d.closePID(pid)
		}
	}()

	return s, nil
}

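// A hedged usage sketch for Reader; the helper name and destination writer are
// illustrative assumptions. Reader resolves the stream_id through the PAT and
// its PMT in the background, so Serve must already be running (with its error
// channel drained elsewhere) before the returned reader will produce data.
func exampleReadStream(ctx context.Context, d *Demux, streamID uint16, dst io.Writer) error {
	rd, err := d.Reader(ctx, streamID)
	if err != nil {
		return err
	}
	defer rd.Close()

	// The elementary stream arrives PES-decoded; copy it until the Demux
	// closes the underlying pipe.
	_, err = io.Copy(dst, rd)
	return err
}
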
// ReaderByPID returns a new io.ReadCloser that reads the raw data carried on the given PID.
// If isPES is true, that data is additionally decoded as a Packetized Elementary Stream.
func (d *Demux) ReaderByPID(ctx context.Context, pid uint16, isPES bool) (io.ReadCloser, error) {
	if pid == pidNULL {
		return nil, errors.Errorf("pid 0x%04X is invalid", pid)
	}

	select {
	case <-d.closed:
		return nil, errors.New("Demux is closed")
	default:
	}

	pipe, err := d.getPipe(ctx, pid)
	if err != nil {
		return nil, err
	}

	var rd io.Reader = pipe
	if isPES {
		rd = pes.NewReader(rd)
	}

	ready := make(chan struct{})
	close(ready)

	return &stream{
		ready: ready,

		pid: pid,

		rd: rd,
		closer: func() error {
			d.mu.Lock()
			defer d.mu.Unlock()

			return d.closePID(pid)
		},
	}, nil
}

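// A hedged sketch for ReaderByPID (the PID value 0x0100 is purely an assumption
// for illustration): this path skips PAT/PMT resolution and attaches directly to
// a single PID; isPES controls whether the returned reader additionally runs the
// data through a PES decoder.
func exampleReadRawPID(ctx context.Context, d *Demux, dst io.Writer) error {
	rd, err := d.ReaderByPID(ctx, 0x0100, true) // isPES: decode as Packetized Elementary Stream.
	if err != nil {
		return err
	}
	defer rd.Close()

	_, err = io.Copy(dst, rd)
	return err
}
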
func (d *Demux) closePending(pid uint16) {
	pipe := d.pending[pid]
	if pipe == nil {
		return
	}

	delete(d.pending, pid)
	pipe.Close()
}

func (d *Demux) closePID(pid uint16) error {
	pipe := d.programs[pid]
	if pipe == nil {
		return nil
	}

	delete(d.programs, pid)
	return pipe.Close()
}

// Close closes and ends the TransportStream, finishing all reads.
//
// It returns a channel of errors that is closed when all streams and programs are complete.
func (d *Demux) Close() <-chan error {
	errch := make(chan error)

	go func() {
		defer close(errch)

		d.mu.Lock()
		defer d.mu.Unlock()

		var pids []uint16
		for pid := range d.programs {
			pids = append(pids, pid)
		}

		for _, pid := range pids {
			if err := d.closePID(pid); err != nil {
				errch <- err
			}
		}

		select {
		case <-d.closed:
		default:
			close(d.closed)
		}
	}()

	return errch
}

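// A hedged shutdown sketch (the helper name and error handling are illustrative):
// Close returns its error channel immediately, and that channel is only closed
// once every program pipe has been shut down, so the caller should drain it.
func exampleShutdown(d *Demux) error {
	var first error
	for err := range d.Close() {
		if first == nil {
			first = err
		}
	}
	return first
}
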
func (d *Demux) get(pkt *packet.Packet) (wr *bufpipe.Pipe, debug func(*packet.Packet)) {
	d.mu.Lock()
	defer d.mu.Unlock()

	debug = d.getDebug()

	wr = d.programs[pkt.PID]
	if wr != nil {
		return wr, debug
	}

	select {
	case <-d.complete:
		return nil, debug
	default:
	}

	wr = d.pending[pkt.PID]
	if wr != nil {
		return wr, debug
	}

	// Make a new bufpipe.Pipe with no context closer.
	// A context closer will be attached,
	// only if this is transformed from pending.
	wr = bufpipe.New(nil, bufpipe.WithNoAutoFlush())

	d.pending[pkt.PID] = wr

	return wr, debug
}

func (d *Demux) readOne(b []byte) (bool, error) {
	// A Transport Stream packet has a fixed length, so do not accept a short read here.
	if _, err := io.ReadFull(d.src, b); err != nil {
		return true, err
	}

	pkt := new(packet.Packet)
	if err := pkt.Unmarshal(b); err != nil {
		return false, err
	}

	wr, debug := d.get(pkt)

	if debug != nil {
		debug(pkt)
	}

	if wr == nil {
		return false, nil
	}

	if pkt.PUSI {
		if err := wr.Sync(); err != nil {
			d.mu.Lock()
			defer d.mu.Unlock()

			return false, d.closePID(pkt.PID)
		}
	}

	if _, err := wr.Write(pkt.Bytes()); err != nil {
		d.mu.Lock()
		defer d.mu.Unlock()

		return false, d.closePID(pkt.PID)
	}

	return false, nil
}

func retError(err error) <-chan error {
	errch := make(chan error, 1)
	errch <- err
	close(errch)
	return errch
}

// Serve handles the decomposition of the byte stream into the various Programs and Streams.
//
// It returns a channel of errors that is closed when service has completely finished.
func (d *Demux) Serve(ctx context.Context) <-chan error {
	rdPAT, err := d.ReaderByPID(ctx, pidPAT, false)
	if err != nil {
		return retError(err)
	}

	errch := make(chan error)
	done := make(chan struct{})

	go func() {
		d.pendingWG.Wait()

		close(d.complete)

		d.mu.Lock()
		defer d.mu.Unlock()

		// Close any pipes that are still pending and were never claimed by a Reader.
		var pids []uint16
		for pid := range d.pending {
			pids = append(pids, pid)
		}

		for _, pid := range pids {
			d.closePending(pid)
		}
	}()

	go func() {
		//defer d.setPAT(nil)

		b := make([]byte, 1024)

		var ver byte = 0xFF

		for {
			n, err := rdPAT.Read(b)
			if err != nil {
				if err != io.EOF {
					select {
					case <-done:
					default:
						errch <- err
					}
				}

				select {
				case <-d.patReady:
					return
				default:
				}

				continue
			}

			if n < 1 {
				// empty-reads are real possibilities.
				continue
			}

			// Only parse the bytes that were actually read.
			tbl, err := psi.Unmarshal(b[:n])
			if err != nil {
				select {
				case <-done:
				default:
					errch <- err
				}

				select {
				case <-d.patReady:
					return
				default:
				}

				continue
			}

			pat, ok := tbl.(*psi.PAT)
			if !ok {
				errch <- errors.Errorf("unexpected table on pid 0x0000: %v", tbl.TableID())
				continue
			}

			if pat.Syntax != nil {
				if ver == pat.Syntax.Version {
					continue
				}
				ver = pat.Syntax.Version
			}

			newPAT := make(map[uint16]uint16)
			for _, m := range pat.Map {
				newPAT[m.ProgramNumber] = m.PID
			}

			d.setPAT(newPAT)
		}
	}()

	go func() {
		defer func() {
			for err := range d.Close() {
				errch <- err
			}

			close(done)
			close(errch)
		}()

		b := make([]byte, packet.Length)

		for {
			select {
			case <-ctx.Done():
				return
			default:
			}

			isFatal, err := d.readOne(b)
			if err != nil {
				if err == io.EOF {
					return
				}

				errch <- err
			}

			if isFatal {
				return
			}
		}
	}()

	return errch
}
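
// A hedged end-to-end sketch tying the pieces together; the stream_id value 1,
// the helper name, and the error handling are assumptions for illustration, not
// a tested example. Serve is started first so the PAT and PMT can be observed,
// its error channel is drained concurrently, one elementary stream is copied
// out, and the drain goroutine ends when Serve closes the channel.
func exampleDemuxLifecycle(ctx context.Context, src io.Reader, dst io.Writer) error {
	d := NewDemux(src)

	errch := d.Serve(ctx)

	// Drain Serve's errors concurrently; the sends inside Serve are unbuffered,
	// so a caller that never reads them can stall the demultiplexer.
	var srvErr error
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		for err := range errch {
			if srvErr == nil {
				srvErr = err
			}
		}
	}()

	rd, err := d.Reader(ctx, 1)
	if err != nil {
		return err
	}
	defer rd.Close()

	if _, err := io.Copy(dst, rd); err != nil {
		return err
	}

	<-drained
	return srvErr
}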