github.com/bluenviron/mediacommon@v1.9.3/pkg/codecs/h264/dts_extractor.go (about)

     1  package h264
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/bluenviron/mediacommon/pkg/bits"
     9  )
    10  
    11  const (
    12  	maxReorderedFrames = 10
    13  	/*
    14  		(max_size(first_mb_in_slice) + max_size(slice_type) + max_size(pic_parameter_set_id) +
    15  		max_size(frame_num) + max_size(pic_order_cnt_lsb)) * 4 / 3 =
    16  		(3 * max_size(golomb) + (max(Log2MaxFrameNumMinus4) + 4) / 8 + (max(Log2MaxPicOrderCntLsbMinus4) + 4) / 8) * 4 / 3 =
    17  		(3 * 4 + 2 + 2) * 4 / 3 = 22
    18  	*/
    19  	maxBytesToGetPOC = 22
    20  )
    21  
    22  func getPictureOrderCount(buf []byte, sps *SPS) (uint32, error) {
    23  	buf = buf[1:]
    24  	lb := len(buf)
    25  
    26  	if lb > maxBytesToGetPOC {
    27  		lb = maxBytesToGetPOC
    28  	}
    29  
    30  	buf = EmulationPreventionRemove(buf[:lb])
    31  	pos := 0
    32  
    33  	_, err := bits.ReadGolombUnsigned(buf, &pos) // first_mb_in_slice
    34  	if err != nil {
    35  		return 0, err
    36  	}
    37  
    38  	_, err = bits.ReadGolombUnsigned(buf, &pos) // slice_type
    39  	if err != nil {
    40  		return 0, err
    41  	}
    42  
    43  	_, err = bits.ReadGolombUnsigned(buf, &pos) // pic_parameter_set_id
    44  	if err != nil {
    45  		return 0, err
    46  	}
    47  
    48  	_, err = bits.ReadBits(buf, &pos, int(sps.Log2MaxFrameNumMinus4+4)) // frame_num
    49  	if err != nil {
    50  		return 0, err
    51  	}
    52  
    53  	picOrderCntLsb, err := bits.ReadBits(buf, &pos, int(sps.Log2MaxPicOrderCntLsbMinus4+4))
    54  	if err != nil {
    55  		return 0, err
    56  	}
    57  
    58  	return uint32(picOrderCntLsb), nil
    59  }
    60  
    61  func getPictureOrderCountDiff(a uint32, b uint32, sps *SPS) int32 {
    62  	max := uint32(1 << (sps.Log2MaxPicOrderCntLsbMinus4 + 4))
    63  	d := (a - b) & (max - 1)
    64  	if d > (max / 2) {
    65  		return int32(d) - int32(max)
    66  	}
    67  	return int32(d)
    68  }
    69  
    70  // DTSExtractor allows to extract DTS from PTS.
    71  type DTSExtractor struct {
    72  	sps             []byte
    73  	spsp            *SPS
    74  	prevDTSFilled   bool
    75  	prevDTS         time.Duration
    76  	expectedPOC     uint32
    77  	reorderedFrames int
    78  	pauseDTS        int
    79  	pocIncrement    int
    80  }
    81  
    82  // NewDTSExtractor allocates a DTSExtractor.
    83  func NewDTSExtractor() *DTSExtractor {
    84  	return &DTSExtractor{
    85  		pocIncrement: 2,
    86  	}
    87  }
    88  
    89  func (d *DTSExtractor) extractInner(au [][]byte, pts time.Duration) (time.Duration, bool, error) {
    90  	var idr []byte
    91  	var nonIDR []byte
    92  
    93  	for _, nalu := range au {
    94  		typ := NALUType(nalu[0] & 0x1F)
    95  		switch typ {
    96  		case NALUTypeSPS:
    97  			if !bytes.Equal(d.sps, nalu) {
    98  				var spsp SPS
    99  				err := spsp.Unmarshal(nalu)
   100  				if err != nil {
   101  					return 0, false, fmt.Errorf("invalid SPS: %w", err)
   102  				}
   103  				d.sps = nalu
   104  				d.spsp = &spsp
   105  
   106  				// reset state
   107  				d.reorderedFrames = 0
   108  				d.pocIncrement = 2
   109  			}
   110  
   111  		case NALUTypeIDR:
   112  			idr = nalu
   113  
   114  		case NALUTypeNonIDR:
   115  			nonIDR = nalu
   116  		}
   117  	}
   118  
   119  	if d.spsp == nil {
   120  		return 0, false, fmt.Errorf("SPS not received yet")
   121  	}
   122  
   123  	if d.spsp.PicOrderCntType == 2 || !d.spsp.FrameMbsOnlyFlag {
   124  		return pts, false, nil
   125  	}
   126  
   127  	if d.spsp.PicOrderCntType == 1 {
   128  		return 0, false, fmt.Errorf("pic_order_cnt_type = 1 is not supported yet")
   129  	}
   130  
   131  	switch {
   132  	case idr != nil:
   133  		d.expectedPOC = 0
   134  		d.pauseDTS = 0
   135  
   136  		if !d.prevDTSFilled || d.reorderedFrames == 0 {
   137  			return pts, false, nil
   138  		}
   139  
   140  		return d.prevDTS + (pts-d.prevDTS)/time.Duration(d.reorderedFrames+1), false, nil
   141  
   142  	case nonIDR != nil:
   143  		d.expectedPOC += uint32(d.pocIncrement)
   144  		d.expectedPOC &= ((1 << (d.spsp.Log2MaxPicOrderCntLsbMinus4 + 4)) - 1)
   145  
   146  		if d.pauseDTS > 0 {
   147  			d.pauseDTS--
   148  			return d.prevDTS + 1*time.Millisecond, true, nil
   149  		}
   150  
   151  		poc, err := getPictureOrderCount(nonIDR, d.spsp)
   152  		if err != nil {
   153  			return 0, false, err
   154  		}
   155  
   156  		if d.pocIncrement == 2 && (poc%2) != 0 {
   157  			d.pocIncrement = 1
   158  			d.expectedPOC /= 2
   159  		}
   160  
   161  		pocDiff := int(getPictureOrderCountDiff(poc, d.expectedPOC, d.spsp)) / d.pocIncrement
   162  		limit := -(d.reorderedFrames + 1)
   163  
   164  		// this happens when there are B-frames immediately following an IDR frame
   165  		if pocDiff < limit {
   166  			increase := limit - pocDiff
   167  			if (d.reorderedFrames + increase) > maxReorderedFrames {
   168  				return 0, false, fmt.Errorf("too many reordered frames (%d)", d.reorderedFrames+increase)
   169  			}
   170  
   171  			d.reorderedFrames += increase
   172  			d.pauseDTS = increase
   173  			return d.prevDTS + 1*time.Millisecond, true, nil
   174  		}
   175  
   176  		if pocDiff == limit {
   177  			return pts, false, nil
   178  		}
   179  
   180  		if pocDiff > d.reorderedFrames {
   181  			increase := pocDiff - d.reorderedFrames
   182  			if (d.reorderedFrames + increase) > maxReorderedFrames {
   183  				return 0, false, fmt.Errorf("too many reordered frames (%d)", d.reorderedFrames+increase)
   184  			}
   185  
   186  			d.reorderedFrames += increase
   187  			d.pauseDTS = increase - 1
   188  			return d.prevDTS + 1*time.Millisecond, false, nil
   189  		}
   190  
   191  		return d.prevDTS + (pts-d.prevDTS)/time.Duration(pocDiff+d.reorderedFrames+1), false, nil
   192  
   193  	default:
   194  		return 0, false, fmt.Errorf("access unit doesn't contain an IDR or non-IDR NALU")
   195  	}
   196  }
   197  
   198  // Extract extracts the DTS of an access unit.
   199  func (d *DTSExtractor) Extract(au [][]byte, pts time.Duration) (time.Duration, error) {
   200  	dts, skipChecks, err := d.extractInner(au, pts)
   201  	if err != nil {
   202  		return 0, err
   203  	}
   204  
   205  	if !skipChecks && dts > pts {
   206  		return 0, fmt.Errorf("DTS is greater than PTS")
   207  	}
   208  
   209  	if d.prevDTSFilled && dts <= d.prevDTS {
   210  		return 0, fmt.Errorf("DTS is not monotonically increasing, was %v, now is %v",
   211  			d.prevDTS, dts)
   212  	}
   213  
   214  	d.prevDTS = dts
   215  	d.prevDTSFilled = true
   216  
   217  	return dts, err
   218  }