github.com/puellanivis/breton@v0.2.16/lib/mpeg/ts/mux.go (about)

     1  package ts
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sort"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  	"github.com/puellanivis/breton/lib/glog"
    12  	"github.com/puellanivis/breton/lib/io/bufpipe"
    13  	"github.com/puellanivis/breton/lib/mpeg/ts/dvb"
    14  	"github.com/puellanivis/breton/lib/mpeg/ts/packet"
    15  	"github.com/puellanivis/breton/lib/mpeg/ts/pcr"
    16  	"github.com/puellanivis/breton/lib/mpeg/ts/pes"
    17  	"github.com/puellanivis/breton/lib/mpeg/ts/psi"
    18  )
    19  
// Reference glog through the blank identifier so the glog import always
// stays in use, independent of the logging calls elsewhere in this file.
var _ = glog.Info
    21  
// Mux defines an MPEG Transport Stream, which will take multiple elementary streams,
// and compose them into a single byte stream.
type Mux struct {
	TransportStream

	// pcrSrc supplies the PCR values written into the AdaptationField of the
	// first (PUSI) packet of each PES packet sequence.
	pcrSrc *pcr.Source

	// closed is closed by Close; after that, NewProgram and WriterByPID refuse
	// new work and Serve's loop returns.
	closed chan struct{}
	// ready is closed (by markReady) once a preamble has been written;
	// PES stream packetizers wait on it before writing anything.
	ready  chan struct{}
	chain  chan struct{} // chain is used to ensure that the first packet of each stream is in order.

	// mu serializes swapping chain (chainLink) and closing ready (markReady).
	mu          sync.Mutex
	// outstanding counts running PES stream packetizers; Close waits on it.
	outstanding sync.WaitGroup
}
    36  
    37  // NewMux returns a new TransportStream Mux that composes the streams into the given io.Writer.
    38  func NewMux(wr io.Writer, opts ...Option) *Mux {
    39  	chain := make(chan struct{})
    40  	close(chain)
    41  
    42  	m := &Mux{
    43  		TransportStream: TransportStream{
    44  			sink:   wr,
    45  			ticker: make(chan struct{}),
    46  		},
    47  
    48  		pcrSrc: pcr.NewSource(),
    49  
    50  		closed: make(chan struct{}),
    51  		ready:  make(chan struct{}),
    52  		chain:  chain,
    53  	}
    54  
    55  	m.TransportStream.m = m
    56  
    57  	for _, opt := range opts {
    58  		_ = opt(&m.TransportStream)
    59  	}
    60  
    61  	return m
    62  }
    63  
    64  // NewProgram allocates and returns a new Program within the Mux assigned to the given stream id.
    65  func (m *Mux) NewProgram(ctx context.Context, streamID uint16) (*Program, error) {
    66  	if streamID == 0 {
    67  		return nil, errors.Errorf("stream_id 0x%04X is invalid", streamID)
    68  	}
    69  
    70  	select {
    71  	case <-m.closed:
    72  		return nil, errors.New("Mux is closed")
    73  	default:
    74  	}
    75  
    76  	p, err := m.TransportStream.NewProgram(streamID)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	wr, err := m.WriterByPID(ctx, p.PID(), false)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	p.wr = wr
    86  
    87  	return p, nil
    88  }
    89  
    90  // Writer allocates and returns a new Stream in a new Program.
    91  func (m *Mux) Writer(ctx context.Context, streamID uint16, typ ProgramType) (io.WriteCloser, error) {
    92  	p, err := m.NewProgram(ctx, streamID)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return p.NewWriter(ctx, typ)
    98  }
    99  
   100  const (
   101  	maxLengthAllowingStuffing = packet.MaxPayload - packet.AdaptationFieldMinLength
   102  )
   103  
   104  func (m *Mux) packetizer(rd io.ReadCloser, s *stream, isPES bool, next chan struct{}) {
   105  	defer func() {
   106  		if err := rd.Close(); err != nil {
   107  			glog.Errorf("packetizer: 0x%04X: rd.Close: %+v", s.pid, err)
   108  		}
   109  	}()
   110  
   111  	var continuity byte
   112  
   113  	// PSI table length is limited to 1021 bytes. This is significantly less than 0x10000 bytes.
   114  	// PES packet limited to payload length 0xFFFF, but a header of at least 6, so must be > 0x10000 bytes.
   115  	// So, we use 0x20000 bytes just to be sure we get a whole packet sequence.
   116  	//
   117  	// N.B. It is the Write/Reader’s responsibility to ensure that a Read completes only on full packets,
   118  	// and that said packet sequence will not be > 0x20000 bytes.
   119  	buf := make([]byte, 0x20000)
   120  
   121  	for {
   122  		n, err := rd.Read(buf)
   123  		if err != nil {
   124  			if err != io.EOF {
   125  				glog.Errorf("packetizer: 0x%04X: %+v", s.pid, err)
   126  			}
   127  
   128  			return
   129  		}
   130  
   131  		if n == len(buf) {
   132  			glog.Warningf("packetizer: 0x%04X: unexpected full read of packet buffer", s.pid)
   133  		}
   134  
   135  		// trunc the buffer to only what was read.
   136  		data := buf[:n]
   137  
   138  		pusi := true
   139  
   140  		for len(data) > 0 {
   141  			var af *packet.AdaptationField
   142  
   143  			switch {
   144  			case !isPES:
   145  				// Don’t do anything, PSI tables don’t get an AdaptationField.
   146  
   147  			case pusi:
   148  				af = &packet.AdaptationField{
   149  					Discontinuity: s.getDiscontinuity(),
   150  					RandomAccess:  true, // TODO: make this configurable.
   151  					PCR:           new(pcr.PCR),
   152  				}
   153  
   154  				m.pcrSrc.Read(af.PCR)
   155  
   156  			case len(data) < maxLengthAllowingStuffing:
   157  				// If the remaining payload is small enough to add stuffing and finish this sequence.
   158  				af = &packet.AdaptationField{
   159  					Stuffing: maxLengthAllowingStuffing - len(data),
   160  				}
   161  
   162  			case len(data) < packet.MaxPayload:
   163  				// We don’t have enough room to add stuffing and finish this sequence.
   164  				// So, we add an empty AdaptationField here with 0-bytes of stuffing,
   165  				// which adds 2-bytes to the header, and overflows the last byte
   166  				// of payload into the next packet,  where we will surely have enough room
   167  				// to actually add stuffing.
   168  				// TODO: check if we can just say adaptation_field_length is 0, which would add only one-byte instead of two?
   169  				af = &packet.AdaptationField{}
   170  			}
   171  
   172  			l := packet.MaxPayload - af.Len()
   173  
   174  			if l > len(data) {
   175  				switch {
   176  				case isPES:
   177  					if af != nil {
   178  						af.Stuffing -= len(data) - l
   179  						l = packet.MaxPayload - af.Len()
   180  					}
   181  
   182  					if l > len(data) {
   183  						glog.Errorf("calculated bad payload space: %d > %d", l, len(data))
   184  					}
   185  
   186  				default:
   187  					l = len(data)
   188  				}
   189  			}
   190  
   191  			pkt := &packet.Packet{
   192  				PID:             s.pid,
   193  				PUSI:            pusi,
   194  				Continuity:      continuity,
   195  				AdaptationField: af,
   196  				Payload:         data[:l],
   197  			}
   198  
   199  			pusi = false
   200  			continuity = (continuity + 1) & 0x0F
   201  
   202  			if _, err := m.writePackets(pkt); err != nil {
   203  				glog.Errorf("m.writePackets: 0x%04X: %+v", s.pid, err)
   204  				return
   205  			}
   206  
   207  			if next != nil {
   208  				close(next)
   209  				next = nil
   210  			}
   211  
   212  			data = data[l:]
   213  		}
   214  	}
   215  
   216  }
   217  
   218  func (m *Mux) chainLink() (<-chan struct{}, chan struct{}) {
   219  	m.mu.Lock()
   220  	defer m.mu.Unlock()
   221  
   222  	wait, next := m.chain, make(chan struct{})
   223  	m.chain = next
   224  
   225  	return wait, next
   226  }
   227  
   228  // WriterByPID creates a new io.WriteCloser on the specified PID.
   229  func (m *Mux) WriterByPID(ctx context.Context, pid uint16, isPES bool) (io.WriteCloser, error) {
   230  	glog.Infof("pid:x%04X, isPES:%v", pid, isPES)
   231  
   232  	if pid == pidNULL {
   233  		return nil, errors.Errorf("pid 0x%04X is invalid", pid)
   234  	}
   235  
   236  	select {
   237  	case <-m.closed:
   238  		return nil, errors.Errorf("pid 0x%04X: mux is closed", pid)
   239  	default:
   240  	}
   241  
   242  	pipe := bufpipe.New(ctx)
   243  
   244  	var rd io.ReadCloser = pipe
   245  	// if !isPES: bufpipe.Pipe -> Packetizer
   246  	// if  isPES: bufpipe.Pipe -> ReadAll -> pes.Writer -> io.Pipe -> Packetizer
   247  
   248  	var wait <-chan struct{}
   249  	var next chan struct{}
   250  
   251  	if isPES {
   252  		wait, next = m.chainLink()
   253  
   254  		var wr io.WriteCloser
   255  
   256  		rd, wr = io.Pipe() // synchronous pipe, don’t write to it without a Reader available.
   257  
   258  		pesWR := pes.NewWriter(0xC0, wr) // TODO: don’t hardcode a value for audio.
   259  		//pesWR.Stream.Header.PTS = new(uint64) // we would need to extract this from the input stream…
   260  
   261  		pesHdrLen, err := pesWR.HeaderLength()
   262  		if err != nil {
   263  			return nil, err
   264  		}
   265  
   266  		// 176                       : first payload size (MaxPayload - len(AF{PCR:xxx}))
   267  		// pes.HeaderLength          : PES header size
   268  		// 14 * packet.MaxPayload    : 14 packets of full payload
   269  		// maxLengthAllowingStuffing : 1 packet with enough room for len(AF{Stuffing:xxx})
   270  		// = |PUSI:payload[176]|, 14 × |payload[184]|, |AF{Stuffing}:payload[<182]|
   271  
   272  		// TODO: don’t hard code thise.
   273  		bufpipe.WithMaxOutstanding(176 - pesHdrLen + 14*packet.MaxPayload + maxLengthAllowingStuffing)(pipe)
   274  		bufpipe.WithNoAutoFlush()(pipe)
   275  
   276  		go func() {
   277  			defer wr.Close()
   278  
   279  			for {
   280  				data, err := pipe.ReadAll()
   281  				if err != nil {
   282  					if err != io.EOF {
   283  						glog.Errorf("pipe.ReadAll: %+v", err)
   284  					}
   285  					return
   286  				}
   287  
   288  				if _, err := pesWR.Write(data); err != nil {
   289  					glog.Errorf("mpeg/ts/pes.Writer.Write: %+v", err)
   290  					return
   291  				}
   292  			}
   293  		}()
   294  	}
   295  
   296  	if isPES {
   297  		// We only want to wg.Wait on PES streams.
   298  		m.outstanding.Add(1)
   299  	}
   300  
   301  	ready := make(chan struct{})
   302  	close(ready)
   303  
   304  	s := &stream{
   305  		ready: ready,
   306  
   307  		pid: pid,
   308  
   309  		wr: pipe,
   310  		closer: func() error {
   311  			return pipe.Close()
   312  		},
   313  	}
   314  
   315  	go func() {
   316  		if isPES {
   317  			defer m.outstanding.Done()
   318  
   319  			// Here, we wait until we’ve written the initial PAT and PMTs.
   320  			<-m.ready
   321  
   322  			// Here, we wait for our turn in the chain,
   323  			// to ensure a deterministic order of first packets.
   324  			<-wait
   325  		}
   326  
   327  		m.packetizer(rd, s, isPES, next)
   328  	}()
   329  
   330  	return s, nil
   331  }
   332  
   333  // Close closes and ends the TransportStream finishing all writes.
   334  //
   335  // It returns a channel of errors that is closed when all streams and programs are complete.
   336  func (m *Mux) Close() <-chan error {
   337  	errch := make(chan error)
   338  
   339  	go func() {
   340  		defer close(errch)
   341  
   342  		close(m.closed)
   343  
   344  		m.outstanding.Wait()
   345  	}()
   346  
   347  	return errch
   348  }
   349  
   350  func (m *Mux) markReady() {
   351  	m.mu.Lock()
   352  	defer m.mu.Unlock()
   353  
   354  	select {
   355  	case <-m.ready:
   356  	default:
   357  		close(m.ready)
   358  	}
   359  }
   360  
   361  func (m *Mux) marshalPAT() ([]byte, error) {
   362  	pat := m.GetPAT()
   363  
   364  	var keys []uint16
   365  	for key := range pat {
   366  		keys = append(keys, key)
   367  	}
   368  	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
   369  
   370  	pmap := make([]psi.ProgramMap, len(keys))
   371  	for i, key := range keys {
   372  		pmap[i].Set(key, pat[key])
   373  	}
   374  
   375  	tbl := &psi.PAT{
   376  		Syntax: &psi.SectionSyntax{
   377  			TableIDExtension: 0x1,
   378  			Current:          true,
   379  		},
   380  		Map: pmap,
   381  	}
   382  
   383  	return tbl.Marshal()
   384  }
   385  
   386  // this needs to be moved into sink, as right now it violates sink’s internal details.
   387  func (m *Mux) preamble(continuity byte) error {
   388  	var pkts []*packet.Packet
   389  
   390  	continuity = continuity & 0x0F
   391  
   392  	if sdt := m.TransportStream.getDVBSDT(); sdt != nil {
   393  		if payload, err := sdt.Marshal(); err == nil {
   394  			// In this specific case, if we get an error, just ignore the packet entirely.
   395  			pkts = append(pkts, &packet.Packet{
   396  				PID:        dvb.ServiceDescriptionPID,
   397  				PUSI:       true,
   398  				Continuity: continuity,
   399  				Payload:    payload,
   400  			})
   401  		} else {
   402  			glog.Warningf("dvb.ServiceDescriptorTable.Marshal: %+v", err)
   403  		}
   404  	}
   405  
   406  	payload, err := m.marshalPAT()
   407  	if err != nil {
   408  		return err
   409  	}
   410  
   411  	pkts = append(pkts, &packet.Packet{
   412  		PID:        pidPAT,
   413  		PUSI:       true,
   414  		Continuity: continuity,
   415  		Payload:    payload,
   416  	})
   417  
   418  	for _, p := range m.GetPMTs() {
   419  		pkt, err := p.packet(continuity)
   420  		if err != nil {
   421  			return err
   422  		}
   423  
   424  		pkts = append(pkts, pkt)
   425  	}
   426  
   427  	if _, err := m.writePackets(pkts...); err != nil {
   428  		return err
   429  	}
   430  
   431  	m.markReady()
   432  
   433  	return nil
   434  }
   435  
   436  // Serve handles the composition of the various Streams into Programs into the Transport Stream.
   437  //
   438  // It returns a channel of errors that is closed when service has completely finished.
   439  func (m *Mux) Serve(ctx context.Context) <-chan error {
   440  	wrPAT, err := m.WriterByPID(ctx, pidPAT, false)
   441  	if err != nil {
   442  		return retError(err)
   443  	}
   444  
   445  	var continuity byte
   446  	if err := m.preamble(continuity); err != nil {
   447  		return retError(err)
   448  	}
   449  	continuity++
   450  
   451  	errch := make(chan error)
   452  
   453  	go func() {
   454  		defer func() {
   455  			if err := wrPAT.Close(); err != nil {
   456  				errch <- err
   457  			}
   458  
   459  			close(errch)
   460  		}()
   461  
   462  		// TODO: what do the specifications say?
   463  		timer := time.NewTimer(m.getUpdateRate())
   464  		defer timer.Stop()
   465  
   466  		for {
   467  			timer.Reset(m.getUpdateRate())
   468  
   469  			select {
   470  			case <-ctx.Done():
   471  				return
   472  			case <-m.closed:
   473  				return
   474  			case <-timer.C:
   475  			case <-m.ticker:
   476  			}
   477  
   478  			if err := m.preamble(continuity); err != nil {
   479  				errch <- err
   480  				return
   481  			}
   482  			continuity++
   483  		}
   484  	}()
   485  
   486  	return errch
   487  }