github.com/bluenviron/mediacommon@v1.9.3/pkg/codecs/h265/dts_extractor.go (about)

     1  package h265
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"time"
     7  
     8  	"github.com/bluenviron/mediacommon/pkg/bits"
     9  	"github.com/bluenviron/mediacommon/pkg/codecs/h264"
    10  )
    11  
    12  const (
    13  	maxBytesToGetPOC = 12
    14  )
    15  
    16  func getPTSDTSDiff(buf []byte, sps *SPS, pps *PPS) (uint32, error) {
    17  	typ := NALUType((buf[0] >> 1) & 0b111111)
    18  
    19  	buf = buf[1:]
    20  	lb := len(buf)
    21  
    22  	if lb > maxBytesToGetPOC {
    23  		lb = maxBytesToGetPOC
    24  	}
    25  
    26  	buf = h264.EmulationPreventionRemove(buf[:lb])
    27  	pos := 8
    28  
    29  	firstSliceSegmentInPicFlag, err := bits.ReadFlag(buf, &pos)
    30  	if err != nil {
    31  		return 0, err
    32  	}
    33  
    34  	if !firstSliceSegmentInPicFlag {
    35  		return 0, fmt.Errorf("first_slice_segment_in_pic_flag = 0 is not supported")
    36  	}
    37  
    38  	if typ >= NALUType_BLA_W_LP && typ <= NALUType_RSV_IRAP_VCL23 {
    39  		_, err := bits.ReadFlag(buf, &pos) // no_output_of_prior_pics_flag
    40  		if err != nil {
    41  			return 0, err
    42  		}
    43  	}
    44  
    45  	_, err = bits.ReadGolombUnsigned(buf, &pos) // slice_pic_parameter_set_id
    46  	if err != nil {
    47  		return 0, err
    48  	}
    49  
    50  	if pps.NumExtraSliceHeaderBits > 0 {
    51  		err := bits.HasSpace(buf, pos, int(pps.NumExtraSliceHeaderBits))
    52  		if err != nil {
    53  			return 0, err
    54  		}
    55  		pos += int(pps.NumExtraSliceHeaderBits)
    56  	}
    57  
    58  	sliceType, err := bits.ReadGolombUnsigned(buf, &pos) // slice_type
    59  	if err != nil {
    60  		return 0, err
    61  	}
    62  
    63  	if pps.OutputFlagPresentFlag {
    64  		_, err := bits.ReadFlag(buf, &pos) // pic_output_flag
    65  		if err != nil {
    66  			return 0, err
    67  		}
    68  	}
    69  
    70  	if sps.SeparateColourPlaneFlag {
    71  		_, err := bits.ReadBits(buf, &pos, 2) // colour_plane_id
    72  		if err != nil {
    73  			return 0, err
    74  		}
    75  	}
    76  
    77  	_, err = bits.ReadBits(buf, &pos, int(sps.Log2MaxPicOrderCntLsbMinus4+4)) // pic_order_cnt_lsb
    78  	if err != nil {
    79  		return 0, err
    80  	}
    81  
    82  	shortTermRefPicSetSpsFlag, err := bits.ReadFlag(buf, &pos)
    83  	if err != nil {
    84  		return 0, err
    85  	}
    86  
    87  	var rps *SPS_ShortTermRefPicSet
    88  
    89  	if !shortTermRefPicSetSpsFlag {
    90  		rps = &SPS_ShortTermRefPicSet{}
    91  		err = rps.unmarshal(buf, &pos, uint32(len(sps.ShortTermRefPicSets)), uint32(len(sps.ShortTermRefPicSets)), nil)
    92  		if err != nil {
    93  			return 0, err
    94  		}
    95  	} else {
    96  		if len(sps.ShortTermRefPicSets) == 0 {
    97  			return 0, fmt.Errorf("invalid short_term_ref_pic_set_idx")
    98  		}
    99  
   100  		b := int(math.Ceil(math.Log2(float64(len(sps.ShortTermRefPicSets)))))
   101  		tmp, err := bits.ReadBits(buf, &pos, b)
   102  		if err != nil {
   103  			return 0, err
   104  		}
   105  		shortTermRefPicSetIdx := int(tmp)
   106  
   107  		if len(sps.ShortTermRefPicSets) <= shortTermRefPicSetIdx {
   108  			return 0, fmt.Errorf("invalid short_term_ref_pic_set_idx")
   109  		}
   110  
   111  		rps = sps.ShortTermRefPicSets[shortTermRefPicSetIdx]
   112  	}
   113  
   114  	var v uint32
   115  
   116  	if sliceType == 0 { // B-frame
   117  		if typ == NALUType_TRAIL_N || typ == NALUType_RASL_N {
   118  			v = sps.MaxNumReorderPics[0] - uint32(len(rps.DeltaPocS1Minus1))
   119  		} else if typ == NALUType_TRAIL_R || typ == NALUType_RASL_R {
   120  			if len(rps.DeltaPocS0Minus1) == 0 {
   121  				return 0, fmt.Errorf("invalid delta_poc_s0_minus1")
   122  			}
   123  			v = rps.DeltaPocS0Minus1[0] + sps.MaxNumReorderPics[0] - 1
   124  		}
   125  	} else { // I or P-frame
   126  		if len(rps.DeltaPocS0Minus1) == 0 {
   127  			return 0, fmt.Errorf("invalid delta_poc_s0_minus1")
   128  		}
   129  		v = rps.DeltaPocS0Minus1[0] + sps.MaxNumReorderPics[0]
   130  	}
   131  
   132  	return v, nil
   133  }
   134  
   135  // DTSExtractor allows to extract DTS from PTS.
   136  type DTSExtractor struct {
   137  	spsp          *SPS
   138  	ppsp          *PPS
   139  	prevDTSFilled bool
   140  	prevDTS       time.Duration
   141  }
   142  
   143  // NewDTSExtractor allocates a DTSExtractor.
   144  func NewDTSExtractor() *DTSExtractor {
   145  	return &DTSExtractor{}
   146  }
   147  
   148  func (d *DTSExtractor) extractInner(au [][]byte, pts time.Duration) (time.Duration, error) {
   149  	var idr []byte
   150  	var nonIDR []byte
   151  
   152  	for _, nalu := range au {
   153  		typ := NALUType((nalu[0] >> 1) & 0b111111)
   154  		switch typ {
   155  		case NALUType_SPS_NUT:
   156  			var spsp SPS
   157  			err := spsp.Unmarshal(nalu)
   158  			if err != nil {
   159  				return 0, fmt.Errorf("invalid SPS: %w", err)
   160  			}
   161  			d.spsp = &spsp
   162  
   163  		case NALUType_PPS_NUT:
   164  			var ppsp PPS
   165  			err := ppsp.Unmarshal(nalu)
   166  			if err != nil {
   167  				return 0, fmt.Errorf("invalid PPS: %w", err)
   168  			}
   169  			d.ppsp = &ppsp
   170  
   171  		case NALUType_IDR_W_RADL, NALUType_IDR_N_LP:
   172  			idr = nalu
   173  
   174  		case NALUType_TRAIL_N, NALUType_TRAIL_R, NALUType_CRA_NUT, NALUType_RASL_N, NALUType_RASL_R:
   175  			nonIDR = nalu
   176  		}
   177  	}
   178  
   179  	if d.spsp == nil {
   180  		return 0, fmt.Errorf("SPS not received yet")
   181  	}
   182  
   183  	if d.ppsp == nil {
   184  		return 0, fmt.Errorf("PPS not received yet")
   185  	}
   186  
   187  	if len(d.spsp.MaxNumReorderPics) != 1 || d.spsp.MaxNumReorderPics[0] == 0 {
   188  		return pts, nil
   189  	}
   190  
   191  	if d.spsp.VUI == nil || d.spsp.VUI.TimingInfo == nil {
   192  		return pts, nil
   193  	}
   194  
   195  	var samplesDiff uint32
   196  
   197  	switch {
   198  	case idr != nil:
   199  		samplesDiff = d.spsp.MaxNumReorderPics[0]
   200  
   201  	case nonIDR != nil:
   202  		var err error
   203  		samplesDiff, err = getPTSDTSDiff(nonIDR, d.spsp, d.ppsp)
   204  		if err != nil {
   205  			return 0, err
   206  		}
   207  
   208  	default:
   209  		return 0, fmt.Errorf("access unit doesn't contain an IDR or non-IDR NALU")
   210  	}
   211  
   212  	timeDiff := time.Duration(samplesDiff) * time.Second *
   213  		time.Duration(d.spsp.VUI.TimingInfo.NumUnitsInTick) / time.Duration(d.spsp.VUI.TimingInfo.TimeScale)
   214  	dts := pts - timeDiff
   215  
   216  	return dts, nil
   217  }
   218  
   219  // Extract extracts the DTS of a access unit.
   220  func (d *DTSExtractor) Extract(au [][]byte, pts time.Duration) (time.Duration, error) {
   221  	dts, err := d.extractInner(au, pts)
   222  	if err != nil {
   223  		return 0, err
   224  	}
   225  
   226  	if dts > pts {
   227  		return 0, fmt.Errorf("DTS is greater than PTS")
   228  	}
   229  
   230  	if d.prevDTSFilled && dts <= d.prevDTS {
   231  		return 0, fmt.Errorf("DTS is not monotonically increasing, was %v, now is %v",
   232  			d.prevDTS, dts)
   233  	}
   234  
   235  	d.prevDTSFilled = true
   236  	d.prevDTS = dts
   237  
   238  	return dts, err
   239  }