github.com/bluenviron/mediacommon@v1.9.3/pkg/formats/fmp4/init.go (about)

     1  package fmp4
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/abema/go-mp4"
     8  
     9  	"github.com/bluenviron/mediacommon/pkg/codecs/av1"
    10  	"github.com/bluenviron/mediacommon/pkg/codecs/h265"
    11  	"github.com/bluenviron/mediacommon/pkg/codecs/mpeg4audio"
    12  )
    13  
// Object type indications compared against the ObjectTypeIndication field of
// the decoder config descriptor found in an 'esds' box.
// Specification: ISO 14496-1, Table 5
const (
	// Video; Init.Unmarshal maps it to CodecMPEG4Video.
	objectTypeIndicationVisualISO14496part2    = 0x20
	// Audio; Init.Unmarshal maps it to CodecMPEG4Audio.
	objectTypeIndicationAudioISO14496part3     = 0x40
	// Video; Init.Unmarshal maps it to CodecMPEG1Video.
	// NOTE(review): "ISO1318" in the name is presumably a typo for ISO 13818 — confirm.
	objectTypeIndicationVisualISO1318part2Main = 0x61
	// Audio; Init.Unmarshal maps it to CodecMPEG1Audio.
	objectTypeIndicationAudioISO11172part3     = 0x6B
	// Video; Init.Unmarshal maps it to CodecMJPEG.
	objectTypeIndicationVisualISO10918part1    = 0x6C
)
    22  
// Stream type values.
// Specification: ISO 14496-1, Table 6
// NOTE(review): not referenced by the code visible in this file section;
// presumably used when writing 'esds' boxes elsewhere in the package — confirm.
const (
	streamTypeVisualStream = 0x04
	streamTypeAudioStream  = 0x05
)
    28  
    29  func av1FindSequenceHeader(bs []byte) ([]byte, error) {
    30  	tu, err := av1.BitstreamUnmarshal(bs, true)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	for _, obu := range tu {
    36  		var h av1.OBUHeader
    37  		err := h.Unmarshal(obu)
    38  		if err != nil {
    39  			return nil, err
    40  		}
    41  
    42  		if h.Type == av1.OBUTypeSequenceHeader {
    43  			return obu, nil
    44  		}
    45  	}
    46  
    47  	return nil, fmt.Errorf("sequence header not found")
    48  }
    49  
    50  func h265FindParams(params []mp4.HEVCNaluArray) ([]byte, []byte, []byte, error) {
    51  	var vps []byte
    52  	var sps []byte
    53  	var pps []byte
    54  
    55  	for _, arr := range params {
    56  		switch h265.NALUType(arr.NaluType) {
    57  		case h265.NALUType_VPS_NUT, h265.NALUType_SPS_NUT, h265.NALUType_PPS_NUT:
    58  			if arr.NumNalus != 1 {
    59  				return nil, nil, nil, fmt.Errorf("multiple VPS/SPS/PPS are not supported")
    60  			}
    61  		}
    62  
    63  		switch h265.NALUType(arr.NaluType) {
    64  		case h265.NALUType_VPS_NUT:
    65  			vps = arr.Nalus[0].NALUnit
    66  
    67  		case h265.NALUType_SPS_NUT:
    68  			sps = arr.Nalus[0].NALUnit
    69  
    70  		case h265.NALUType_PPS_NUT:
    71  			pps = arr.Nalus[0].NALUnit
    72  		}
    73  	}
    74  
    75  	if vps == nil {
    76  		return nil, nil, nil, fmt.Errorf("VPS not provided")
    77  	}
    78  
    79  	if sps == nil {
    80  		return nil, nil, nil, fmt.Errorf("SPS not provided")
    81  	}
    82  
    83  	if pps == nil {
    84  		return nil, nil, nil, fmt.Errorf("PPS not provided")
    85  	}
    86  
    87  	return vps, sps, pps, nil
    88  }
    89  
    90  func h264FindParams(avcc *mp4.AVCDecoderConfiguration) ([]byte, []byte, error) {
    91  	if len(avcc.SequenceParameterSets) > 1 {
    92  		return nil, nil, fmt.Errorf("multiple SPS are not supported")
    93  	}
    94  
    95  	var sps []byte
    96  	if len(avcc.SequenceParameterSets) == 1 {
    97  		sps = avcc.SequenceParameterSets[0].NALUnit
    98  	}
    99  
   100  	if len(avcc.PictureParameterSets) > 1 {
   101  		return nil, nil, fmt.Errorf("multiple PPS are not supported")
   102  	}
   103  
   104  	var pps []byte
   105  	if len(avcc.PictureParameterSets) == 1 {
   106  		pps = avcc.PictureParameterSets[0].NALUnit
   107  	}
   108  
   109  	return sps, pps, nil
   110  }
   111  
   112  func esdsFindDecoderConf(descriptors []mp4.Descriptor) *mp4.DecoderConfigDescriptor {
   113  	for _, desc := range descriptors {
   114  		if desc.Tag == mp4.DecoderConfigDescrTag {
   115  			return desc.DecoderConfigDescriptor
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  func esdsFindDecoderSpecificInfo(descriptors []mp4.Descriptor) []byte {
   122  	for _, desc := range descriptors {
   123  		if desc.Tag == mp4.DecSpecificInfoTag {
   124  			return desc.Data
   125  		}
   126  	}
   127  	return nil
   128  }
   129  
// Init is a fMP4 initialization block.
type Init struct {
	// Tracks lists the tracks of the block. Unmarshal appends one entry per
	// successfully parsed 'trak' box; Marshal writes one track and one 'trex'
	// box per entry.
	Tracks []*InitTrack
}
   134  
// Unmarshal decodes a fMP4 initialization block.
//
// The box structure read from r is walked with mp4.ReadBoxStructure and a
// small state machine: every 'trak' box starts a new InitTrack that is
// appended to i.Tracks, then 'tkhd' supplies its ID, 'mdhd' its time scale and
// a sample entry plus its configuration box supply its codec. Boxes that
// arrive in a state that does not expect them produce an "unexpected box"
// error. When a box type is not supported by the mp4 library while a track is
// still being parsed, that track is removed from i.Tracks.
//
// An error is returned if reading or parsing fails, if the walk ends while a
// track is still incomplete ("parse error"), or if i.Tracks is empty at the
// end ("no tracks found"). Tracks are appended to whatever i.Tracks already
// contains.
func (i *Init) Unmarshal(r io.ReadSeeker) error {
	type readState int

	// Each state names the box the parser expects next for the current track.
	const (
		waitingTrak readState = iota // no track in progress
		waitingTkhd
		waitingMdhd
		waitingCodec // a sample entry box (avc1, vp09, mp4a, ...)
		waitingAv1C
		waitingVpcC
		waitingHvcC
		waitingAvcC
		waitingVideoEsds
		waitingAudioEsds
		waitingDOps
		waitingDac3
		waitingPcmC
	)

	state := waitingTrak
	var curTrack *InitTrack
	// Values captured from a sample entry box and consumed when the codec is
	// built from the configuration box that follows it.
	var width int
	var height int
	var sampleRate int
	var channelCount int

	_, err := mp4.ReadBoxStructure(r, func(h *mp4.ReadHandle) (interface{}, error) {
		if !h.BoxInfo.IsSupportedType() {
			// Unsupported box type: if a track is in progress, discard it and
			// wait for the next 'trak'.
			if state != waitingTrak {
				i.Tracks = i.Tracks[:len(i.Tracks)-1]
				state = waitingTrak
			}
		} else {
			switch h.BoxInfo.Type.String() {
			case "moov":
				return h.Expand()

			case "trak":
				if state != waitingTrak {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				// Register the track immediately; it is removed again only by
				// the unsupported-box branch above.
				curTrack = &InitTrack{}
				i.Tracks = append(i.Tracks, curTrack)
				state = waitingTkhd
				return h.Expand()

			case "tkhd":
				if state != waitingTkhd {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				tkhd := box.(*mp4.Tkhd)

				curTrack.ID = int(tkhd.TrackID)
				state = waitingMdhd

			case "mdia":
				return h.Expand()

			case "mdhd":
				if state != waitingMdhd {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				mdhd := box.(*mp4.Mdhd)

				curTrack.TimeScale = mdhd.Timescale
				state = waitingCodec

			case "minf", "stbl", "stsd":
				return h.Expand()

			case "avc1":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}
				state = waitingAvcC
				return h.Expand()

			case "avcC":
				if state != waitingAvcC {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				avcc := box.(*mp4.AVCDecoderConfiguration)

				sps, pps, err := h264FindParams(avcc)
				if err != nil {
					return nil, err
				}

				curTrack.Codec = &CodecH264{
					SPS: sps,
					PPS: pps,
				}
				state = waitingTrak

			case "vp09":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				vp09 := box.(*mp4.VisualSampleEntry)

				// Remembered for the CodecVP9 built in the 'vpcC' case.
				width = int(vp09.Width)
				height = int(vp09.Height)
				state = waitingVpcC
				return h.Expand()

			case "vpcC":
				if state != waitingVpcC {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				vpcc := box.(*mp4.VpcC)

				curTrack.Codec = &CodecVP9{
					Width:             width,
					Height:            height,
					Profile:           vpcc.Profile,
					BitDepth:          vpcc.BitDepth,
					ChromaSubsampling: vpcc.ChromaSubsampling,
					ColorRange:        vpcc.VideoFullRangeFlag != 0,
				}
				state = waitingTrak

			// NOTE(review): the state is not changed here, so a track that
			// reached waitingCodec stays there; a later 'trak' then fails with
			// "unexpected box" or the final check reports "parse error".
			// Confirm whether the track should be dropped instead.
			case "vp08": // VP8, not supported yet
				return nil, nil

			case "hev1", "hvc1":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}
				state = waitingHvcC
				return h.Expand()

			case "hvcC":
				if state != waitingHvcC {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				hvcc := box.(*mp4.HvcC)

				vps, sps, pps, err := h265FindParams(hvcc.NaluArrays)
				if err != nil {
					return nil, err
				}

				curTrack.Codec = &CodecH265{
					VPS: vps,
					SPS: sps,
					PPS: pps,
				}
				state = waitingTrak

			case "av01":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}
				state = waitingAv1C
				return h.Expand()

			case "av1C":
				if state != waitingAv1C {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				av1c := box.(*mp4.Av1C)

				sequenceHeader, err := av1FindSequenceHeader(av1c.ConfigOBUs)
				if err != nil {
					return nil, err
				}

				curTrack.Codec = &CodecAV1{
					SequenceHeader: sequenceHeader,
				}
				state = waitingTrak

			case "Opus":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}
				state = waitingDOps
				return h.Expand()

			case "dOps":
				if state != waitingDOps {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				dops := box.(*mp4.DOps)

				curTrack.Codec = &CodecOpus{
					ChannelCount: int(dops.OutputChannelCount),
				}
				state = waitingTrak

			case "mp4v":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				mp4v := box.(*mp4.VisualSampleEntry)

				// Remembered for the MJPEG branch of the 'esds' case.
				width = int(mp4v.Width)
				height = int(mp4v.Height)
				state = waitingVideoEsds
				return h.Expand()

			case "mp4a":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				mp4a := box.(*mp4.AudioSampleEntry)

				// NOTE(review): SampleRate is presumably a 16.16 fixed-point
				// value, the division keeping its integer part — confirm.
				sampleRate = int(mp4a.SampleRate / 65536)
				channelCount = int(mp4a.ChannelCount)
				state = waitingAudioEsds
				return h.Expand()

			case "esds":
				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				esds := box.(*mp4.Esds)

				conf := esdsFindDecoderConf(esds.Descriptors)
				if conf == nil {
					return nil, fmt.Errorf("unable to find decoder config")
				}

				// The expected codec family depends on which sample entry
				// ('mp4v' or 'mp4a') preceded this box.
				switch state {
				case waitingVideoEsds:
					switch conf.ObjectTypeIndication {
					case objectTypeIndicationVisualISO14496part2:
						spec := esdsFindDecoderSpecificInfo(esds.Descriptors)
						if spec == nil {
							return nil, fmt.Errorf("unable to find decoder specific info")
						}

						curTrack.Codec = &CodecMPEG4Video{
							Config: spec,
						}

					case objectTypeIndicationVisualISO1318part2Main:
						spec := esdsFindDecoderSpecificInfo(esds.Descriptors)
						if spec == nil {
							return nil, fmt.Errorf("unable to find decoder specific info")
						}

						curTrack.Codec = &CodecMPEG1Video{
							Config: spec,
						}

					case objectTypeIndicationVisualISO10918part1:
						curTrack.Codec = &CodecMJPEG{
							Width:  width,
							Height: height,
						}

					default:
						return nil, fmt.Errorf("unsupported object type indication: 0x%.2x", conf.ObjectTypeIndication)
					}

					// Redundant with the assignment after the outer switch.
					state = waitingTrak

				case waitingAudioEsds:
					switch conf.ObjectTypeIndication {
					case objectTypeIndicationAudioISO14496part3:
						spec := esdsFindDecoderSpecificInfo(esds.Descriptors)
						if spec == nil {
							return nil, fmt.Errorf("unable to find decoder specific info")
						}

						var c mpeg4audio.Config
						err := c.Unmarshal(spec)
						if err != nil {
							return nil, fmt.Errorf("invalid MPEG-4 Audio configuration: %w", err)
						}

						curTrack.Codec = &CodecMPEG4Audio{
							Config: c,
						}

					case objectTypeIndicationAudioISO11172part3:
						curTrack.Codec = &CodecMPEG1Audio{
							SampleRate:   sampleRate,
							ChannelCount: channelCount,
						}

					default:
						return nil, fmt.Errorf("unsupported object type indication: 0x%.2x", conf.ObjectTypeIndication)
					}

				default:
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				state = waitingTrak

			case "ac-3":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				ac3 := box.(*mp4.AudioSampleEntry)

				sampleRate = int(ac3.SampleRate / 65536)
				channelCount = int(ac3.ChannelCount)
				state = waitingDac3
				return h.Expand()

			case "dac3":
				if state != waitingDac3 {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				dac3 := box.(*mp4.Dac3)

				curTrack.Codec = &CodecAC3{
					SampleRate:   sampleRate,
					ChannelCount: channelCount,
					Fscod:        dac3.Fscod,
					Bsid:         dac3.Bsid,
					Bsmod:        dac3.Bsmod,
					Acmod:        dac3.Acmod,
					LfeOn:        dac3.LfeOn != 0,
					BitRateCode:  dac3.BitRateCode,
				}
				state = waitingTrak

			case "ipcm":
				if state != waitingCodec {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				// Despite its name, this holds the 'ipcm' sample entry.
				ac3 := box.(*mp4.AudioSampleEntry)

				sampleRate = int(ac3.SampleRate / 65536)
				channelCount = int(ac3.ChannelCount)
				state = waitingPcmC
				return h.Expand()

			case "pcmC":
				if state != waitingPcmC {
					return nil, fmt.Errorf("unexpected box '%v'", h.BoxInfo.Type)
				}

				box, _, err := h.ReadPayload()
				if err != nil {
					return nil, err
				}
				pcmc := box.(*mp4.PcmC)

				curTrack.Codec = &CodecLPCM{
					// Bit 0 of FormatFlags selects little-endian samples.
					LittleEndian: (pcmc.FormatFlags & 0x01) != 0,
					BitDepth:     int(pcmc.PCMSampleSize),
					SampleRate:   sampleRate,
					ChannelCount: channelCount,
				}
				state = waitingTrak
			}
		}

		// Boxes not handled above are not expanded.
		return nil, nil
	})
	if err != nil {
		return err
	}

	// A track was started but its codec configuration never arrived.
	if state != waitingTrak {
		return fmt.Errorf("parse error")
	}

	if len(i.Tracks) == 0 {
		return fmt.Errorf("no tracks found")
	}

	return nil
}
   572  
   573  // Marshal encodes a fMP4 initialization file.
   574  func (i *Init) Marshal(w io.WriteSeeker) error {
   575  	/*
   576  		|ftyp|
   577  		|moov|
   578  		|    |mvhd|
   579  		|    |trak|
   580  		|    |trak|
   581  		|    |....|
   582  		|    |mvex|
   583  		|    |    |trex|
   584  		|    |    |trex|
   585  		|    |    |....|
   586  	*/
   587  
   588  	mw := newMP4Writer(w)
   589  
   590  	_, err := mw.writeBox(&mp4.Ftyp{ // <ftyp/>
   591  		MajorBrand:   [4]byte{'m', 'p', '4', '2'},
   592  		MinorVersion: 1,
   593  		CompatibleBrands: []mp4.CompatibleBrandElem{
   594  			{CompatibleBrand: [4]byte{'m', 'p', '4', '1'}},
   595  			{CompatibleBrand: [4]byte{'m', 'p', '4', '2'}},
   596  			{CompatibleBrand: [4]byte{'i', 's', 'o', 'm'}},
   597  			{CompatibleBrand: [4]byte{'h', 'l', 's', 'f'}},
   598  		},
   599  	})
   600  	if err != nil {
   601  		return err
   602  	}
   603  
   604  	_, err = mw.writeBoxStart(&mp4.Moov{}) // <moov>
   605  	if err != nil {
   606  		return err
   607  	}
   608  
   609  	_, err = mw.writeBox(&mp4.Mvhd{ // <mvhd/>
   610  		Timescale:   1000,
   611  		Rate:        65536,
   612  		Volume:      256,
   613  		Matrix:      [9]int32{0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000},
   614  		NextTrackID: 4294967295,
   615  	})
   616  	if err != nil {
   617  		return err
   618  	}
   619  
   620  	for _, track := range i.Tracks {
   621  		err := track.marshal(mw)
   622  		if err != nil {
   623  			return err
   624  		}
   625  	}
   626  
   627  	_, err = mw.writeBoxStart(&mp4.Mvex{}) // <mvex>
   628  	if err != nil {
   629  		return err
   630  	}
   631  
   632  	for _, track := range i.Tracks {
   633  		_, err = mw.writeBox(&mp4.Trex{ // <trex/>
   634  			TrackID:                       uint32(track.ID),
   635  			DefaultSampleDescriptionIndex: 1,
   636  		})
   637  		if err != nil {
   638  			return err
   639  		}
   640  	}
   641  
   642  	err = mw.writeBoxEnd() // </mvex>
   643  	if err != nil {
   644  		return err
   645  	}
   646  
   647  	err = mw.writeBoxEnd() // </moov>
   648  	if err != nil {
   649  		return err
   650  	}
   651  
   652  	return nil
   653  }