diff --git a/example/test.go b/example/test.go index 3a033fd..0fe5750 100644 --- a/example/test.go +++ b/example/test.go @@ -216,7 +216,7 @@ func main() { } w.PCR = sample.PCR bw := &bytes.Buffer{} - if err = ts.WritePES(bw, pes, sample.Data); err != nil { + if err = ts.WritePES(bw, pes, bytes.NewReader(sample.Data)); err != nil { return } if err = w.Write(bw.Bytes(), false); err != nil { @@ -253,5 +253,6 @@ func main() { } } } + } diff --git a/writer.go b/writer.go index c949603..3ec0d1f 100644 --- a/writer.go +++ b/writer.go @@ -274,11 +274,13 @@ func bswap32(v uint) uint { return (v>>24)|((v>>16)&0xff)<<8|((v>>8)&0xff)<<16|(v&0xff)<<24 } -func WritePES(w io.Writer, self PESHeader, data []byte) (err error) { +func WritePES(w io.Writer, self PESHeader, data io.ReadSeeker) (err error) { // http://dvd.sourceforge.net/dvdinfo/pes-hdr.html var pts_dts_flags, header_length, packet_length uint + dataLen := getSeekerLength(data) + // start code(24) 000001 // StreamId(8) // packet_length(16) @@ -315,7 +317,7 @@ func WritePES(w io.Writer, self PESHeader, data []byte) (err error) { if pts_dts_flags & DTS != 0 { header_length += 5 } - packet_length = 3+header_length+uint(len(data)) + packet_length = 3+header_length+uint(dataLen) if DebugWriter { fmt.Printf("pesw: packet_length=%d\n", packet_length) @@ -359,7 +361,7 @@ func WritePES(w io.Writer, self PESHeader, data []byte) (err error) { } // data - if _, err = w.Write(data); err != nil { + if _, err = io.Copy(w, data); err != nil { return } @@ -475,14 +477,124 @@ func WritePMT(w io.Writer, self PMT) (err error) { type SimpleH264Writer struct { W io.Writer - headerHasWritten bool + TimeScale int + + SPS []byte + PPS []byte + + tsw *TSWriter + pts uint64 + pcr uint64 + prepared bool } -func (self *SimpleH264Writer) WriteSample(data []byte) (err error) { +func (self *SimpleH264Writer) prepare() (err error) { + writePAT := func() (err error) { + w := &TSWriter{ + W: self.W, + PID: 0, + DisableHeaderPadding: true, + } + pat := PAT{ + Entries: []PATEntry{ + {ProgramNumber: 1, ProgramMapPID: 0x1000}, + }, + } + bw := &bytes.Buffer{} + if err = WritePAT(bw, pat); err != nil { + return + } + if err = w.Write(bw.Bytes(), false); err != nil { + return + } + return + } + + writePMT := func() (err error) { + w := &TSWriter{ + W: self.W, + PID: 0x1000, + DisableHeaderPadding: true, + } + pmt := PMT{ + PCRPID: 0x100, + ElementaryStreamInfos: []ElementaryStreamInfo{ + {StreamType: ElementaryStreamTypeH264, ElementaryPID: 0x100}, + }, + } + bw := &bytes.Buffer{} + if err = WritePMT(bw, pmt); err != nil { + return + } + if err = w.Write(bw.Bytes(), false); err != nil { + return + } + return + } + + if err = writePAT(); err != nil { + return + } + + if err = writePMT(); err != nil { + return + } + + self.tsw = &TSWriter{ + W: self.W, + PID: 0x100, + } + self.pts = PTS_HZ + self.pcr = PCR_HZ + return } -func (self *SimpleH264Writer) WriteNALU(data []byte) (err error) { +func (self *SimpleH264Writer) writeData(data io.ReadSeeker, duration int) (err error) { + pes := PESHeader{ + StreamId: StreamIdH264, + PTS: self.pts, + } + self.tsw.PCR = self.pcr + + self.pts += uint64(duration)*PTS_HZ/uint64(self.TimeScale) + self.pcr += uint64(duration)*PCR_HZ/uint64(self.TimeScale) + + bw := &bytes.Buffer{} + if err = WritePES(bw, pes, data); err != nil { + return + } + if err = self.tsw.Write(bw.Bytes(), false); err != nil { + return + } + return } +func (self *SimpleH264Writer) writeNALUs(nalus [][]byte, duration int) (err error) { + readers := []io.ReadSeeker{} + for _, nalu := range nalus { + startCode := bytes.NewReader([]byte{0,0,1}) + readers = append(readers, startCode) + readers = append(readers, bytes.NewReader(nalu)) + } + return self.writeData(&multiReadSeeker{readers: readers}, duration) +} + +func (self *SimpleH264Writer) WriteNALU(sync bool, duration int, nalu []byte) (err error) { + nalus := [][]byte{} + + if !self.prepared { + if err = self.prepare(); err != nil { + return + } + self.prepared = true + nalus = append(nalus, self.SPS) + nalus = append(nalus, self.PPS) + } + + nalus = append(nalus, nalu) + + return self.writeNALUs(nalus, duration) +} +