github.com/bluenviron/mediacommon@v1.9.3/pkg/formats/mpegts/reader.go (about)

     1  package mpegts
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"strings"
     9  
    10  	"github.com/asticode/go-astits"
    11  
    12  	"github.com/bluenviron/mediacommon/pkg/codecs/ac3"
    13  	"github.com/bluenviron/mediacommon/pkg/codecs/h264"
    14  	"github.com/bluenviron/mediacommon/pkg/codecs/mpeg1audio"
    15  	"github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio"
    16  )
    17  
    18  // ReaderOnDecodeErrorFunc is the prototype of the callback passed to OnDecodeError.
    19  type ReaderOnDecodeErrorFunc func(err error)
    20  
    21  // ReaderOnDataH26xFunc is the prototype of the callback passed to OnDataH26x.
    22  type ReaderOnDataH26xFunc func(pts int64, dts int64, au [][]byte) error
    23  
    24  // ReaderOnDataMPEGxVideoFunc is the prototype of the callback passed to OnDataMPEGxVideo.
    25  type ReaderOnDataMPEGxVideoFunc func(pts int64, frame []byte) error
    26  
    27  // ReaderOnDataOpusFunc is the prototype of the callback passed to OnDataOpus.
    28  type ReaderOnDataOpusFunc func(pts int64, packets [][]byte) error
    29  
    30  // ReaderOnDataMPEG4AudioFunc is the prototype of the callback passed to OnDataMPEG4Audio.
    31  type ReaderOnDataMPEG4AudioFunc func(pts int64, aus [][]byte) error
    32  
    33  // ReaderOnDataMPEG1AudioFunc is the prototype of the callback passed to OnDataMPEG1Audio.
    34  type ReaderOnDataMPEG1AudioFunc func(pts int64, frames [][]byte) error
    35  
    36  // ReaderOnDataAC3Func is the prototype of the callback passed to OnDataAC3.
    37  type ReaderOnDataAC3Func func(pts int64, frame []byte) error
    38  
    39  func findPMT(dem *astits.Demuxer) (*astits.PMTData, error) {
    40  	for {
    41  		data, err := dem.NextData()
    42  		if err != nil {
    43  			return nil, err
    44  		}
    45  
    46  		if data.PMT != nil {
    47  			return data.PMT, nil
    48  		}
    49  	}
    50  }
    51  
    52  // Reader is a MPEG-TS reader.
    53  type Reader struct {
    54  	tracks        []*Track
    55  	dem           *astits.Demuxer
    56  	onDecodeError ReaderOnDecodeErrorFunc
    57  	onData        map[uint16]func(int64, int64, []byte) error
    58  }
    59  
    60  // NewReader allocates a Reader.
    61  func NewReader(br io.Reader) (*Reader, error) {
    62  	rr := &recordReader{r: br}
    63  
    64  	dem := astits.NewDemuxer(
    65  		context.Background(),
    66  		rr,
    67  		astits.DemuxerOptPacketSize(188))
    68  
    69  	pmt, err := findPMT(dem)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	var tracks []*Track //nolint:prealloc
    75  
    76  	for _, es := range pmt.ElementaryStreams {
    77  		var track Track
    78  		err := track.unmarshal(dem, es)
    79  		if err != nil {
    80  			if errors.Is(err, errUnsupportedCodec) {
    81  				continue
    82  			}
    83  			return nil, err
    84  		}
    85  
    86  		tracks = append(tracks, &track)
    87  	}
    88  
    89  	if tracks == nil {
    90  		return nil, fmt.Errorf("no tracks with supported codecs found")
    91  	}
    92  
    93  	// rewind demuxer
    94  	dem = astits.NewDemuxer(
    95  		context.Background(),
    96  		&playbackReader{r: br, buf: rr.buf},
    97  		astits.DemuxerOptPacketSize(188))
    98  
    99  	return &Reader{
   100  		tracks:        tracks,
   101  		dem:           dem,
   102  		onDecodeError: func(error) {},
   103  		onData:        make(map[uint16]func(int64, int64, []byte) error),
   104  	}, nil
   105  }
   106  
   107  // Tracks returns detected tracks.
   108  func (r *Reader) Tracks() []*Track {
   109  	return r.tracks
   110  }
   111  
   112  // OnDecodeError sets a callback that is called when a non-fatal decode error occurs.
   113  func (r *Reader) OnDecodeError(cb ReaderOnDecodeErrorFunc) {
   114  	r.onDecodeError = cb
   115  }
   116  
   117  // OnDataH26x sets a callback that is called when data from an H26x track is received.
   118  func (r *Reader) OnDataH26x(track *Track, cb ReaderOnDataH26xFunc) {
   119  	r.onData[track.PID] = func(pts int64, dts int64, data []byte) error {
   120  		au, err := h264.AnnexBUnmarshal(data)
   121  		if err != nil {
   122  			r.onDecodeError(err)
   123  			return nil
   124  		}
   125  
   126  		return cb(pts, dts, au)
   127  	}
   128  }
   129  
   130  // OnDataMPEGxVideo sets a callback that is called when data from an MPEG-1/2/4 Video track is received.
   131  func (r *Reader) OnDataMPEGxVideo(track *Track, cb ReaderOnDataMPEGxVideoFunc) {
   132  	r.onData[track.PID] = func(pts int64, _ int64, data []byte) error {
   133  		return cb(pts, data)
   134  	}
   135  }
   136  
   137  // OnDataOpus sets a callback that is called when data from an Opus track is received.
   138  func (r *Reader) OnDataOpus(track *Track, cb ReaderOnDataOpusFunc) {
   139  	r.onData[track.PID] = func(pts int64, dts int64, data []byte) error {
   140  		if pts != dts {
   141  			r.onDecodeError(fmt.Errorf("PTS is not equal to DTS"))
   142  			return nil
   143  		}
   144  
   145  		pos := 0
   146  		var packets [][]byte
   147  
   148  		for {
   149  			var au opusAccessUnit
   150  			n, err := au.unmarshal(data[pos:])
   151  			if err != nil {
   152  				r.onDecodeError(err)
   153  				return nil
   154  			}
   155  			pos += n
   156  
   157  			packets = append(packets, au.Packet)
   158  
   159  			if len(data[pos:]) == 0 {
   160  				break
   161  			}
   162  		}
   163  
   164  		return cb(pts, packets)
   165  	}
   166  }
   167  
   168  // OnDataMPEG4Audio sets a callback that is called when data from an MPEG-4 Audio track is received.
   169  func (r *Reader) OnDataMPEG4Audio(track *Track, cb ReaderOnDataMPEG4AudioFunc) {
   170  	r.onData[track.PID] = func(pts int64, dts int64, data []byte) error {
   171  		if pts != dts {
   172  			r.onDecodeError(fmt.Errorf("PTS is not equal to DTS"))
   173  			return nil
   174  		}
   175  
   176  		var pkts mpeg4audio.ADTSPackets
   177  		err := pkts.Unmarshal(data)
   178  		if err != nil {
   179  			r.onDecodeError(fmt.Errorf("invalid ADTS: %w", err))
   180  			return nil
   181  		}
   182  
   183  		aus := make([][]byte, len(pkts))
   184  		for i, pkt := range pkts {
   185  			aus[i] = pkt.AU
   186  		}
   187  
   188  		return cb(pts, aus)
   189  	}
   190  }
   191  
   192  // OnDataMPEG1Audio sets a callback that is called when data from an MPEG-1 Audio track is received.
   193  func (r *Reader) OnDataMPEG1Audio(track *Track, cb ReaderOnDataMPEG1AudioFunc) {
   194  	r.onData[track.PID] = func(pts int64, dts int64, data []byte) error {
   195  		if pts != dts {
   196  			r.onDecodeError(fmt.Errorf("PTS is not equal to DTS"))
   197  			return nil
   198  		}
   199  
   200  		var frames [][]byte
   201  
   202  		for len(data) > 0 {
   203  			var h mpeg1audio.FrameHeader
   204  			err := h.Unmarshal(data)
   205  			if err != nil {
   206  				r.onDecodeError(err)
   207  				return nil
   208  			}
   209  
   210  			fl := h.FrameLen()
   211  			if len(data) < fl {
   212  				r.onDecodeError(fmt.Errorf("buffer is too short"))
   213  				return nil
   214  			}
   215  
   216  			var frame []byte
   217  			frame, data = data[:fl], data[fl:]
   218  
   219  			frames = append(frames, frame)
   220  		}
   221  
   222  		return cb(pts, frames)
   223  	}
   224  }
   225  
   226  // OnDataAC3 sets a callback that is called when data from an AC-3 track is received.
   227  func (r *Reader) OnDataAC3(track *Track, cb ReaderOnDataAC3Func) {
   228  	r.onData[track.PID] = func(pts int64, dts int64, data []byte) error {
   229  		if pts != dts {
   230  			r.onDecodeError(fmt.Errorf("PTS is not equal to DTS"))
   231  			return nil
   232  		}
   233  
   234  		var syncInfo ac3.SyncInfo
   235  		err := syncInfo.Unmarshal(data)
   236  		if err != nil {
   237  			r.onDecodeError(err)
   238  			return nil
   239  		}
   240  		size := syncInfo.FrameSize()
   241  
   242  		if size != len(data) {
   243  			r.onDecodeError(fmt.Errorf("unexpected frame size: got %d, expected %d", len(data), size))
   244  			return nil
   245  		}
   246  
   247  		return cb(pts, data)
   248  	}
   249  }
   250  
   251  // Read reads data.
   252  func (r *Reader) Read() error {
   253  	for {
   254  		data, err := r.dem.NextData()
   255  		if err != nil {
   256  			// https://github.com/asticode/go-astits/blob/b0b19247aa31633650c32638fb55f597fa6e2468/packet_buffer.go#L133C1-L133C5
   257  			if errors.Is(err, astits.ErrNoMorePackets) || strings.Contains(err.Error(), "astits: reading ") {
   258  				return err
   259  			}
   260  			r.onDecodeError(err)
   261  			continue
   262  		}
   263  
   264  		if data.PES == nil {
   265  			return nil
   266  		}
   267  
   268  		if data.PES.Header.OptionalHeader == nil ||
   269  			data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorNoPTSOrDTS ||
   270  			data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorIsForbidden {
   271  			r.onDecodeError(fmt.Errorf("PTS is missing"))
   272  			return nil
   273  		}
   274  
   275  		pts := data.PES.Header.OptionalHeader.PTS.Base
   276  
   277  		var dts int64
   278  		if data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorBothPresent {
   279  			dts = data.PES.Header.OptionalHeader.DTS.Base
   280  		} else {
   281  			dts = pts
   282  		}
   283  
   284  		onData, ok := r.onData[data.PID]
   285  		if !ok {
   286  			return nil
   287  		}
   288  
   289  		return onData(pts, dts, data.PES.Data)
   290  	}
   291  }