From 2432a072da1055ab9ecab74dcb85dd646484a50c Mon Sep 17 00:00:00 2001 From: nareix Date: Mon, 27 Jun 2016 01:32:19 +0800 Subject: [PATCH] add publishing handle --- server.go | 158 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 124 insertions(+), 34 deletions(-) diff --git a/server.go b/server.go index a351b00..8261bab 100644 --- a/server.go +++ b/server.go @@ -49,26 +49,40 @@ func DialTimeout(uri string, timeout time.Duration) (conn *Conn, err error) { } type Server struct { + Debug bool + DebugConn bool Addr string HandlePublish func(*Conn) HandlePlay func(*Conn) } func (self *Server) handleConn(conn *Conn) (err error) { - if err = conn.handshake(); err != nil { + if err = conn.handshakeServer(); err != nil { return } - if err = conn.determineType(); err != nil { - fmt.Println("rtmp: conn closed:", err) return } if conn.playing { if self.HandlePlay != nil { self.HandlePlay(conn) - conn.Close() } + } else if conn.publishing { + if self.HandlePublish != nil { + self.HandlePublish(conn) + } + + for { + conn.pollMsg() + if conn.msgtypeid == msgtypeidAudioMsg || conn.msgtypeid == msgtypeidVideoMsg { + break + } + } + } + + if err = conn.Close(); err != nil { + return } return @@ -96,7 +110,12 @@ func (self *Server) ListenAndServe() (err error) { return } + if self.Debug { + fmt.Println("rtmp: server: accepted") + } + conn := NewConn(netconn) + conn.Debug = self.DebugConn conn.isserver = true go self.handleConn(conn) } @@ -120,8 +139,6 @@ type Conn struct { writeMaxChunkSize int readMaxChunkSize int - lastcsid uint32 - lastcs *chunkStream csmap map[uint32]*chunkStream isserver bool @@ -152,8 +169,8 @@ func NewConn(netconn net.Conn) *Conn { conn.csmap = make(map[uint32]*chunkStream) conn.readMaxChunkSize = 128 conn.writeMaxChunkSize = 128 - conn.bufr = bufio.NewReaderSize(netconn, 4096) - conn.bufw = bufio.NewWriterSize(netconn, 4096) + conn.bufr = bufio.NewReaderSize(netconn, 2048) + conn.bufw = bufio.NewWriterSize(netconn, 2048) conn.br = pio.NewReader(conn.bufr) conn.bw = pio.NewWriter(conn.bufw) conn.intw = pio.NewWriter(nil) @@ -227,7 +244,7 @@ func (self *Conn) pollMsg() (err error) { } func (self *Conn) determineType() (err error) { - var connectpath, playpath string + var connectpath string // < connect("app") if err = self.pollCommand(); err != nil { @@ -290,13 +307,46 @@ func (self *Conn) determineType() (err error) { flvio.WriteAMF0Val(w, self.avmsgsid) // streamid=1 self.writeCommandMsgEnd(3, 0) - // < play("path") - case "play": + // < publish("path") + case "publish": + if self.Debug { + fmt.Println("rtmp: < publish") + } + if len(self.commandparams) < 1 { - err = fmt.Errorf("rtmp: play params invalid") + err = fmt.Errorf("rtmp: publish params invalid") return } - playpath, _ = self.commandparams[0].(string) + publishpath, _ := self.commandparams[0].(string) + + // > onStatus() + w := self.writeCommandMsgStart() + flvio.WriteAMF0Val(w, "onStatus") + flvio.WriteAMF0Val(w, self.commandtransid) + flvio.WriteAMF0Val(w, nil) + flvio.WriteAMF0Val(w, flvio.AMFMap{ + "level": "status", + "code": "NetStream.Publish.Start", + "description": "Start publishing", + }) + self.writeCommandMsgEnd(5, self.avmsgsid) + + self.Path = fmt.Sprintf("/%s/%s", connectpath, publishpath) + self.publishing = true + self.reading = true + return + + // < play("path") + case "play": + if self.Debug { + fmt.Println("rtmp: < play") + } + + if len(self.commandparams) < 1 { + err = fmt.Errorf("rtmp: command play params invalid") + return + } + playpath, _ := self.commandparams[0].(string) // > streamBegin(streamid) self.writeStreamBegin(self.avmsgsid) @@ -320,10 +370,6 @@ func (self *Conn) determineType() (err error) { flvio.WriteAMF0Val(w, true) self.writeDataMsgEnd(5, self.avmsgsid) - if self.Debug { - fmt.Println("rtmp: playing") - } - self.Path = fmt.Sprintf("/%s/%s", connectpath, playpath) self.playing = true self.writing = true @@ -592,6 +638,12 @@ func (self *Conn) ReadPacket() (pkt av.Packet, err error) { pkt.Data = tag.Data pkt.Idx = int8(self.audiostreamidx) break poll + + case msgtypeidUserControl: + + default: + err = fmt.Errorf("debug %d %v", self.msgtypeid, self.msgdata) + return } } @@ -601,7 +653,7 @@ func (self *Conn) ReadPacket() (pkt av.Packet, err error) { func (self *Conn) ReadHeader() (err error) { if !self.reading && !self.writing { - if err = self.handshake(); err != nil { + if err = self.handshakeClient(); err != nil { return } if err = self.connectPlay(); err != nil { @@ -897,6 +949,11 @@ func (self *Conn) writeChunks(csid uint32, timestamp uint32, msgtypeid uint8, ms if self.Debug { fmt.Printf("rtmp: write chunk msgdatalen=%d msgsid=%d\n", msgdatalen, msgsid) + b := []byte{} + for _, a := range msgdatav { + b = append(b, a...) + } + fmt.Print(hex.Dump(b)) } if err = self.bufw.Flush(); err != nil { @@ -932,15 +989,11 @@ func (self *Conn) readChunk() (err error) { csid = uint32(i)+64 } - var cs *chunkStream - if self.lastcs != nil && self.lastcsid == csid { - cs = self.lastcs - } else { + cs := self.csmap[csid] + if cs == nil { cs = &chunkStream{} self.csmap[csid] = cs } - self.lastcs = cs - self.lastcsid = csid var timestamp uint32 @@ -1085,8 +1138,8 @@ func (self *Conn) readChunk() (err error) { cs.msgdataleft -= uint32(size) if self.Debug { - fmt.Printf("rtmp: chunk csid=%d msgsid=%d msgtypeid=%d msghdrtype=%d len=%d left=%d\n", - csid, cs.msgsid, cs.msgtypeid, cs.msghdrtype, cs.msgdatalen, cs.msgdataleft) + fmt.Printf("rtmp: chunk msgsid=%d msgtypeid=%d msghdrtype=%d len=%d left=%d\n", + cs.msgsid, cs.msgtypeid, cs.msghdrtype, cs.msgdatalen, cs.msgdataleft) } if cs.msgdataleft == 0 { @@ -1245,14 +1298,6 @@ func hsCreateC1(p []byte) { copy(p[gap:], digest) } -func (self *Conn) handshake() (err error) { - if self.isserver { - return self.handshakeServer() - } else { - return self.handshakeClient() - } -} - func (self *Conn) handshakeClient() (err error) { var random [(1+1536*2)*2]byte @@ -1289,6 +1334,9 @@ func (self *Conn) handshakeClient() (err error) { } if S1[4] >= 3 { + // TODO + err = fmt.Errorf("rtmp: newstyle handshake unspported") + return } else { C2 = S1 } @@ -1305,6 +1353,48 @@ func (self *Conn) handshakeClient() (err error) { } func (self *Conn) handshakeServer() (err error) { + var random [(1+1536*2)*2]byte + + C0C1C2 := random[:1536*2+1] + C0 := C0C1C2[:1] + C1 := C0C1C2[1:1536+1] + C0C1 := C0C1C2[:1536+1] + C2 := C0C1C2[1536+1:] + + S0S1S2 := random[1536*2+1:] + S0 := S0S1S2[:1] + S1 := S0S1S2[1:1536+1] + //S0S1 := S0S1S2[:1536+1] + S2 := S0S1S2[1536+1:] + + // < C0C1 + if _, err = io.ReadFull(self.br, C0C1); err != nil { + return + } + if C0[0] != 3 { + err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0]) + return + } + + S0[0] = 3 + copy(S1[0:4], C1[0:4]) + rand.Read(S1[8:]) + copy(S2[0:4], C1[0:4]) + copy(S2[8:], C1[8:]) + + // > S0S1S2 + if _, err = self.bw.Write(S0S1S2); err != nil { + return + } + if err = self.bufw.Flush(); err != nil { + return + } + + // < C2 + if _, err = io.ReadFull(self.br, C2); err != nil { + return + } + return }