Skip to content

Commit be8b1c8

Browse files
Add ability to preserve PAT/PMT information (#74)
* Add ability to preserve PAT/PMT information * revert comment deletion * add test data
1 parent f593538 commit be8b1c8

3 files changed

Lines changed: 165 additions & 5 deletions

File tree

muxer.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ type Muxer struct {
2929
packetSize int
3030
tablesRetransmitPeriod int // period in PES packets
3131

32+
transportStreamID uint16
33+
pmtPID uint16
34+
3235
pm *programMap // pid -> programNumber
3336
pmUpdated bool
3437
pmt PMTData
@@ -68,6 +71,22 @@ func MuxerOptTablesRetransmitPeriod(newPeriod int) func(*Muxer) {
6871
}
6972
}
7073

74+
// WithTransportStreamID sets the transport stream ID written into PAT.
75+
// Default is 0.
76+
func WithTransportStreamID(id uint16) func(*Muxer) {
77+
return func(m *Muxer) {
78+
m.transportStreamID = id
79+
}
80+
}
81+
82+
// WithPMTPID sets the PID used for PMT packets and advertised in PAT.
83+
// Default is 0x1000.
84+
func WithPMTPID(pid uint16) func(*Muxer) {
85+
return func(m *Muxer) {
86+
m.pmtPID = pid
87+
}
88+
}
89+
7190
// TODO MuxerOptAutodetectPCRPID selecting first video PID for each PMT, falling back to first audio, falling back to any other
7291

7392
func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
@@ -78,6 +97,8 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
7897
packetSize: MpegTsPacketSize, // no 192-byte packet support yet
7998
tablesRetransmitPeriod: 40,
8099

100+
pmtPID: pmtStartPID,
101+
81102
pm: newProgramMap(),
82103
pmt: PMTData{
83104
ElementaryStreams: []*PMTElementaryStream{},
@@ -97,14 +118,14 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
97118
m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
98119
m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w})
99120

100-
// TODO multiple programs support
101-
m.pm.setUnlocked(pmtStartPID, programNumberStart)
102-
m.pmUpdated = true
103-
104121
for _, opt := range opts {
105122
opt(m)
106123
}
107124

125+
// TODO multiple programs support
126+
m.pm.setUnlocked(m.pmtPID, programNumberStart)
127+
m.pmUpdated = true
128+
108129
// to output tables at the very start
109130
m.tablesRetransmitCounter = m.tablesRetransmitPeriod
110131

@@ -322,6 +343,7 @@ func (m *Muxer) WriteTables() (int, error) {
322343

323344
func (m *Muxer) generatePAT() error {
324345
d := m.pm.toPATDataUnlocked()
346+
d.TransportStreamID = m.transportStreamID
325347

326348
versionNumber := m.patVersion.get()
327349
if m.pmUpdated {
@@ -432,7 +454,7 @@ func (m *Muxer) generatePMT() error {
432454
Header: PacketHeader{
433455
HasPayload: true,
434456
PayloadUnitStartIndicator: true,
435-
PID: pmtStartPID, // FIXME multiple programs support
457+
PID: m.pmtPID, // FIXME multiple programs support
436458
ContinuityCounter: uint8(m.pmtCC.inc()),
437459
},
438460
Payload: m.buf.Bytes(),

roundtrip_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package astits
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"os"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
type pesRecord struct {
15+
pid uint16
16+
pes *PESData
17+
af *PacketAdaptationField
18+
}
19+
20+
func TestRoundTrip(t *testing.T) {
21+
originalBytes, err := os.ReadFile("testdata/ts/silent_audio.ts")
22+
require.NoError(t, err)
23+
24+
// Phase 1: Demux the original TS file
25+
dmx := NewDemuxer(context.Background(), bytes.NewReader(originalBytes), DemuxerOptPacketSize(MpegTsPacketSize))
26+
27+
var originalPAT *PATData
28+
var originalPMT *PMTData
29+
var originalPMTPID uint16 = 0xFFFF
30+
31+
for {
32+
d, err := dmx.NextData()
33+
if errors.Is(err, ErrNoMorePackets) {
34+
break
35+
}
36+
require.NoError(t, err)
37+
38+
if d.PAT != nil {
39+
originalPAT = d.PAT
40+
originalPMTPID = d.PAT.Programs[0].ProgramMapID
41+
}
42+
if d.PMT != nil {
43+
originalPMT = d.PMT
44+
}
45+
46+
if originalPMT != nil && originalPAT != nil {
47+
break
48+
}
49+
}
50+
require.NotNil(t, originalPAT)
51+
require.NotNil(t, originalPMT)
52+
require.NotEqual(t, 0xFFFF, originalPMTPID)
53+
54+
// Phase 2: Mux everything back into a new TS stream, preserving PAT/PMT identifiers
55+
var buf bytes.Buffer
56+
muxer := NewMuxer(context.Background(), &buf,
57+
WithTransportStreamID(originalPAT.TransportStreamID),
58+
WithPMTPID(originalPMTPID),
59+
)
60+
61+
for _, es := range originalPMT.ElementaryStreams {
62+
err := muxer.AddElementaryStream(PMTElementaryStream{
63+
ElementaryPID: es.ElementaryPID,
64+
StreamType: es.StreamType,
65+
ElementaryStreamDescriptors: es.ElementaryStreamDescriptors,
66+
})
67+
require.NoError(t, err)
68+
}
69+
muxer.SetPCRPID(originalPMT.PCRPID)
70+
muxer.pmt.ProgramDescriptors = originalPMT.ProgramDescriptors
71+
_, err = muxer.WriteTables()
72+
require.NoError(t, err)
73+
74+
// Phase 3: Demux the round-tripped output
75+
dmx2 := NewDemuxer(context.Background(), bytes.NewReader(buf.Bytes()), DemuxerOptPacketSize(MpegTsPacketSize))
76+
77+
var rtPAT *PATData
78+
var rtPMT *PMTData
79+
80+
for {
81+
d, err := dmx2.NextData()
82+
if errors.Is(err, ErrNoMorePackets) {
83+
break
84+
}
85+
require.NoError(t, err)
86+
87+
if d.PAT != nil {
88+
rtPAT = d.PAT
89+
}
90+
if d.PMT != nil {
91+
rtPMT = d.PMT
92+
}
93+
94+
if rtPAT != nil && rtPMT != nil {
95+
break
96+
}
97+
}
98+
require.NotNil(t, rtPAT)
99+
require.NotNil(t, rtPMT)
100+
101+
// Phase 4: Validate round-trip preserved all meaningful information
102+
// --- PAT ---
103+
assert.Equal(t, originalPAT.TransportStreamID, rtPAT.TransportStreamID, "PAT TransportStreamID mismatch")
104+
require.Equal(t, len(originalPAT.Programs), len(rtPAT.Programs), "PAT program count mismatch")
105+
for i, origProg := range originalPAT.Programs {
106+
assert.Equalf(t, origProg.ProgramNumber, rtPAT.Programs[i].ProgramNumber,
107+
"PAT Programs[%d].ProgramNumber mismatch", i)
108+
assert.Equalf(t, origProg.ProgramMapID, rtPAT.Programs[i].ProgramMapID,
109+
"PAT Programs[%d].ProgramMapID mismatch", i)
110+
}
111+
112+
// --- PMT ---
113+
assert.Equal(t, originalPMT.PCRPID, rtPMT.PCRPID)
114+
assert.Equal(t, originalPMT.ProgramNumber, rtPMT.ProgramNumber)
115+
require.Equal(t, len(originalPMT.ProgramDescriptors), len(rtPMT.ProgramDescriptors))
116+
for i, desc := range originalPMT.ProgramDescriptors {
117+
assert.Equalf(t, desc.Tag, rtPMT.ProgramDescriptors[i].Tag,
118+
"PMT ProgramDescriptors[%d].Tag mismatch", i)
119+
assert.Equalf(t, desc.Length, rtPMT.ProgramDescriptors[i].Length,
120+
"PMT ProgramDescriptors[%d].Length mismatch", i)
121+
}
122+
require.Equal(t, len(originalPMT.ElementaryStreams), len(rtPMT.ElementaryStreams))
123+
for i, es := range originalPMT.ElementaryStreams {
124+
rtES := rtPMT.ElementaryStreams[i]
125+
assert.Equalf(t, es.ElementaryPID, rtES.ElementaryPID,
126+
"PMT ElementaryStreams[%d].ElementaryPID mismatch", i)
127+
assert.Equalf(t, es.StreamType, rtES.StreamType,
128+
"PMT ElementaryStreams[%d].StreamType mismatch", i)
129+
require.Equalf(t, len(es.ElementaryStreamDescriptors), len(rtES.ElementaryStreamDescriptors),
130+
"PMT ElementaryStreams[%d].ElementaryStreamDescriptors count mismatch", i)
131+
for j, desc := range es.ElementaryStreamDescriptors {
132+
assert.Equalf(t, desc.Tag, rtES.ElementaryStreamDescriptors[j].Tag,
133+
"PMT ElementaryStreams[%d].Descriptors[%d].Tag mismatch", i, j)
134+
assert.Equalf(t, desc.Length, rtES.ElementaryStreamDescriptors[j].Length,
135+
"PMT ElementaryStreams[%d].Descriptors[%d].Length mismatch", i, j)
136+
}
137+
}
138+
}

testdata/ts/silent_audio.ts

752 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)