add message support

This commit is contained in:
nareix 2016-06-23 19:00:33 +08:00
parent 17f1aca838
commit 2385aafae2
2 changed files with 354 additions and 92 deletions

View File

@ -106,6 +106,7 @@ func parseChal(b []byte, peerKey []byte, key []byte) (dig []byte, err int) {
ver := b[5:9] ver := b[5:9]
l.Printf("handshake: epoch %v ver %v", epoch, ver) l.Printf("handshake: epoch %v ver %v", epoch, ver)
// random
var offs int var offs int
if offs = findDigest(b[1:], peerKey, 772); offs == -1 { if offs = findDigest(b[1:], peerKey, 772); offs == -1 {
if offs = findDigest(b[1:], peerKey, 8); offs == -1 { if offs = findDigest(b[1:], peerKey, 8); offs == -1 {
@ -123,7 +124,7 @@ func parseChal(b []byte, peerKey []byte, key []byte) (dig []byte, err int) {
func handShake(rw io.ReadWriter) { func handShake(rw io.ReadWriter) {
b := ReadBuf(rw, 1537) b := ReadBuf(rw, 1537) // C0+C1
l.Printf("handshake: got client chal") l.Printf("handshake: got client chal")
dig, err := parseChal(b, clientKey2, serverKey) dig, err := parseChal(b, clientKey2, serverKey)
if err != 0 { if err != 0 {
@ -132,14 +133,14 @@ func handShake(rw io.ReadWriter) {
createChal(b, serverVersion, serverKey2) createChal(b, serverVersion, serverKey2)
l.Printf("handshake: send server chal") l.Printf("handshake: send server chal")
rw.Write(b) rw.Write(b) // S0+S1
b = make([]byte, 1536) b = make([]byte, 1536)
createResp(b, dig) createResp(b, dig)
l.Printf("handshake: send server resp") l.Printf("handshake: send server resp")
rw.Write(b) rw.Write(b) // S2
b = ReadBuf(rw, 1536) b = ReadBuf(rw, 1536) // C2
l.Printf("handshake: got client resp") l.Printf("handshake: got client resp")
} }

437
new.go
View File

@ -8,6 +8,7 @@ import (
"encoding/hex" "encoding/hex"
"io" "io"
"github.com/nareix/pio" "github.com/nareix/pio"
"github.com/nareix/flv/flvio"
) )
type Publisher struct { type Publisher struct {
@ -23,16 +24,10 @@ type Server struct {
} }
func (self *Server) handleConn(conn *Conn) (err error) { func (self *Server) handleConn(conn *Conn) (err error) {
if err = conn.Handshake(); err != nil { if err = conn.determineType(); err != nil {
fmt.Println("rtmp: conn closed:", err)
return return
} }
for {
if err = conn.ReadChunk(); err != nil {
return
}
}
return return
} }
@ -58,13 +53,7 @@ func (self *Server) ListenAndServe() (err error) {
return return
} }
conn := &Conn{} conn := newConn(netconn)
conn.csmap = make(map[uint32]*chunkStream)
conn.maxChunkSize = 128
conn.bufr = bufio.NewReaderSize(netconn, 512)
conn.bufw = bufio.NewWriterSize(netconn, 512)
conn.br = pio.NewReader(conn.bufr)
conn.bw = pio.NewWriter(conn.bufw)
go self.handleConn(conn) go self.handleConn(conn)
} }
} }
@ -74,32 +63,271 @@ type Conn struct {
bw *pio.Writer bw *pio.Writer
bufr *bufio.Reader bufr *bufio.Reader
bufw *bufio.Writer bufw *bufio.Writer
intw *pio.Writer
maxChunkSize int writeMaxChunkSize int
readMaxChunkSize int
lastcsid uint32 lastcsid uint32
lastcs *chunkStream lastcs *chunkStream
csmap map[uint32]*chunkStream csmap map[uint32]*chunkStream
publishing, playing bool
gotcommand bool
command string
commandr *pio.Reader
commandobj flvio.AMFMap
commandtransid float64
gotmsg bool
msgdata []byte
msgtypeid uint8
}
func newConn(netconn net.Conn) *Conn {
conn := &Conn{}
conn.csmap = make(map[uint32]*chunkStream)
conn.readMaxChunkSize = 128
conn.writeMaxChunkSize = 128
conn.bufr = bufio.NewReaderSize(netconn, 512)
conn.bufw = bufio.NewWriterSize(netconn, 512)
conn.br = pio.NewReader(conn.bufr)
conn.bw = pio.NewWriter(conn.bufw)
conn.intw = pio.NewWriter(nil)
return conn
} }
type chunkStream struct { type chunkStream struct {
TimestampNow uint32 timenow uint32
TimestampDelta uint32 timedelta uint32
HasTimestampExt bool hastimeext bool
Msgsid uint32 msgsid uint32
Msgtypeid uint8 msgtypeid uint8
Msglen uint32 msgdatalen uint32
Msgleft uint32 msgdataleft uint32
Msghdrtype uint8 msghdrtype uint8
Msgdata []byte msgdata []byte
} }
func (self *chunkStream) Start() { func (self *chunkStream) Start() {
self.Msgleft = self.Msglen self.msgdataleft = self.msgdatalen
self.Msgdata = make([]byte, self.Msglen) self.msgdata = make([]byte, self.msgdatalen)
} }
func (self *Conn) ReadChunk() ( err error) { const (
msgtypeidUserControl = 4
msgtypeidWindowAckSize = 5
msgtypeidSetPeerBandwidth = 6
msgtypeidSetChunkSize = 1
msgtypeidCommandMsgAMF0 = 20
msgtypeidCommandMsgAMF3 = 17
)
const (
eventtypeStreamBegin = 0
)
func (self *Conn) pollCommand() (err error) {
for {
if err = self.readChunk(); err != nil {
return
}
if self.gotcommand {
self.gotcommand = false
return
}
}
}
func (self *Conn) pollMsg() (err error) {
for {
if err = self.readChunk(); err != nil {
return
}
if self.gotmsg {
self.gotmsg = false
return
}
}
}
func (self *Conn) determineType() (err error) {
if err = self.handshake(); err != nil {
return
}
// < connect
if err = self.pollCommand(); err != nil {
return
}
if self.command != "connect" {
err = fmt.Errorf("rtmp: first command is not connect")
return
}
// > WindowAckSize
if err = self.writeWindowAckSize(5000000); err != nil {
return
}
// > SetPeerBandwidth
if err = self.writeSetPeerBandwidth(5000000, 2); err != nil {
return
}
// > SetChunkSize
if err = self.writeSetChunkSize(uint32(self.writeMaxChunkSize)); err != nil {
return
}
// > _result("NetConnection.Connect.Success")
w := self.writeCommandMsgStart()
flvio.WriteAMF0Val(w, "_result")
flvio.WriteAMF0Val(w, self.commandtransid)
flvio.WriteAMF0Val(w, flvio.AMFMap{
"fmtVer": "FMS/3,0,1,123",
"capabilities": 31,
})
flvio.WriteAMF0Val(w, flvio.AMFMap{
"level": "status",
"code": "NetConnection.Connect.Success",
"description": "Connection Success.",
"objectEncoding": 0,
})
self.writeCommandMsgEnd()
if err = self.pollCommand(); err != nil {
return
}
if err = self.pollCommand(); err != nil {
return
}
return
}
func (self *Conn) writeSetChunkSize(size uint32) (err error) {
w := self.writeProtoCtrlMsgStart()
w.WriteU32BE(size)
return self.writeProtoCtrlMsgEnd(msgtypeidSetChunkSize)
}
func (self *Conn) writeWindowAckSize(size uint32) (err error) {
w := self.writeProtoCtrlMsgStart()
w.WriteU32BE(size)
return self.writeProtoCtrlMsgEnd(msgtypeidWindowAckSize)
}
func (self *Conn) writeSetPeerBandwidth(acksize uint32, limittype uint8) (err error) {
w := self.writeProtoCtrlMsgStart()
w.WriteU32BE(acksize)
w.WriteU8(limittype)
return self.writeProtoCtrlMsgEnd(msgtypeidSetPeerBandwidth)
}
func (self *Conn) writeProtoCtrlMsgStart() *pio.Writer {
self.intw.SaveToVecOn()
return self.intw
}
func (self *Conn) writeProtoCtrlMsgEnd(msgtypeid uint8) (err error) {
msgdatav := self.intw.SaveToVecOff()
return self.writeChunks(2, 0, msgtypeid, 0, msgdatav)
}
func (self *Conn) writeCommandMsgStart() *pio.Writer {
self.intw.SaveToVecOn()
return self.intw
}
func (self *Conn) writeCommandMsgEnd() (err error) {
msgdatav := self.intw.SaveToVecOff()
return self.writeChunks(3, 0, msgtypeidCommandMsgAMF0, 0, msgdatav)
}
func (self *Conn) writeUserControlMsgStart(eventtype uint16) *pio.Writer {
self.intw.SaveToVecOn()
self.intw.WriteU16BE(eventtype)
return self.intw
}
func (self *Conn) writeUserControlMsgEnd() (err error) {
msgdatav := self.intw.SaveToVecOff()
return self.writeChunks(2, 0, msgtypeidUserControl, 0, msgdatav)
}
func (self *Conn) writeStreamBegin(msgcsid uint32) (err error) {
w := self.writeUserControlMsgStart(eventtypeStreamBegin)
w.WriteU32BE(msgcsid)
return self.writeUserControlMsgEnd()
}
func (self *Conn) writeChunks(csid uint32, timestamp uint32, msgtypeid uint8, msgcsid uint32, msgdatav [][]byte) (err error) {
msgdatalen := pio.VecLen(msgdatav)
// [Type 0][Type 3][Type 3][Type 3]
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | timestamp |message length |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | message length (cont) |message type id| msg stream id |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | message stream id (cont) |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//
// Figure 9 Chunk Message Header Type 0
if err = self.bw.WriteU8(byte(csid)&0x3f); err != nil {
return
}
if err = self.bw.WriteU24BE(timestamp); err != nil {
return
}
if err = self.bw.WriteU24BE(uint32(msgdatalen)); err != nil {
return
}
if err = self.bw.WriteU8(msgtypeid); err != nil {
return
}
if err = self.bw.WriteU32BE(msgcsid); err != nil {
return
}
msgdataoff := 0
for {
size := msgdatalen - msgdataoff
if size > self.writeMaxChunkSize {
size = self.writeMaxChunkSize
}
write := pio.VecSlice(msgdatav, msgdataoff, msgdataoff+size)
for _, b := range write {
if _, err = self.bw.Write(b); err != nil {
return
}
}
msgdataoff += size
if msgdataoff == msgdatalen {
break
}
// Type 3
if err = self.bw.WriteU8(byte(csid)&0x3f|3<<6); err != nil {
return
}
}
fmt.Printf("rtmp: write chunk msgdatalen=%d\n", msgdatalen)
if err = self.bufw.Flush(); err != nil {
return
}
return
}
func (self *Conn) readChunk() (err error) {
var msghdrtype uint8 var msghdrtype uint8
var csid uint32 var csid uint32
var header uint8 var header uint8
@ -150,8 +378,8 @@ func (self *Conn) ReadChunk() ( err error) {
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// //
// Figure 9 Chunk Message Header Type 0 // Figure 9 Chunk Message Header Type 0
if cs.Msgleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgleft=%d invalid", cs.Msgleft) err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
var h[]byte var h[]byte
@ -159,19 +387,19 @@ func (self *Conn) ReadChunk() ( err error) {
return return
} }
timestamp = pio.GetU24BE(h[0:3]) timestamp = pio.GetU24BE(h[0:3])
cs.Msghdrtype = msghdrtype cs.msghdrtype = msghdrtype
cs.Msglen = pio.GetU24BE(h[3:6]) cs.msgdatalen = pio.GetU24BE(h[3:6])
cs.Msgtypeid = h[6] cs.msgtypeid = h[6]
cs.Msgsid = pio.GetU32BE(h[7:11]) cs.msgsid = pio.GetU32BE(h[7:11])
if timestamp == 0xffffff { if timestamp == 0xffffff {
if timestamp, err = self.br.ReadU32BE(); err != nil { if timestamp, err = self.br.ReadU32BE(); err != nil {
return return
} }
cs.HasTimestampExt = true cs.hastimeext = true
} else { } else {
cs.HasTimestampExt = false cs.hastimeext = false
} }
cs.TimestampNow = timestamp cs.timenow = timestamp
cs.Start() cs.Start()
case 1: case 1:
@ -184,8 +412,8 @@ func (self *Conn) ReadChunk() ( err error) {
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// //
// Figure 10 Chunk Message Header Type 1 // Figure 10 Chunk Message Header Type 1
if cs.Msgleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgleft=%d invalid", cs.Msgleft) err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
var h[]byte var h[]byte
@ -193,19 +421,19 @@ func (self *Conn) ReadChunk() ( err error) {
return return
} }
timestamp = pio.GetU24BE(h[0:3]) timestamp = pio.GetU24BE(h[0:3])
cs.Msghdrtype = msghdrtype cs.msghdrtype = msghdrtype
cs.Msglen = pio.GetU24BE(h[3:6]) cs.msgdatalen = pio.GetU24BE(h[3:6])
cs.Msgtypeid = h[6] cs.msgtypeid = h[6]
if timestamp == 0xffffff { if timestamp == 0xffffff {
if timestamp, err = self.br.ReadU32BE(); err != nil { if timestamp, err = self.br.ReadU32BE(); err != nil {
return return
} }
cs.HasTimestampExt = true cs.hastimeext = true
} else { } else {
cs.HasTimestampExt = false cs.hastimeext = false
} }
cs.TimestampDelta = timestamp cs.timedelta = timestamp
cs.TimestampNow += timestamp cs.timenow += timestamp
cs.Start() cs.Start()
case 2: case 2:
@ -216,47 +444,47 @@ func (self *Conn) ReadChunk() ( err error) {
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// //
// Figure 11 Chunk Message Header Type 2 // Figure 11 Chunk Message Header Type 2
if cs.Msgleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgleft=%d invalid", cs.Msgleft) err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
var h[]byte var h[]byte
if h, err = self.br.ReadBytes(3); err != nil { if h, err = self.br.ReadBytes(3); err != nil {
return return
} }
cs.Msghdrtype = msghdrtype cs.msghdrtype = msghdrtype
timestamp = pio.GetU24BE(h[0:3]) timestamp = pio.GetU24BE(h[0:3])
if timestamp == 0xffffff { if timestamp == 0xffffff {
if timestamp, err = self.br.ReadU32BE(); err != nil { if timestamp, err = self.br.ReadU32BE(); err != nil {
return return
} }
cs.HasTimestampExt = true cs.hastimeext = true
} else { } else {
cs.HasTimestampExt = false cs.hastimeext = false
} }
cs.TimestampDelta = timestamp cs.timedelta = timestamp
cs.TimestampNow += timestamp cs.timenow += timestamp
cs.Start() cs.Start()
case 3: case 3:
if cs.Msgleft == 0 { if cs.msgdataleft == 0 {
switch cs.Msghdrtype { switch cs.msghdrtype {
case 0: case 0:
if cs.HasTimestampExt { if cs.hastimeext {
if timestamp, err = self.br.ReadU32BE(); err != nil { if timestamp, err = self.br.ReadU32BE(); err != nil {
return return
} }
cs.TimestampNow = timestamp cs.timenow = timestamp
} }
case 1, 2: case 1, 2:
if cs.HasTimestampExt { if cs.hastimeext {
if timestamp, err = self.br.ReadU32BE(); err != nil { if timestamp, err = self.br.ReadU32BE(); err != nil {
return return
} }
} else { } else {
timestamp = cs.TimestampDelta timestamp = cs.timedelta
} }
cs.TimestampNow += timestamp cs.timenow += timestamp
} }
cs.Start() cs.Start()
} }
@ -266,36 +494,75 @@ func (self *Conn) ReadChunk() ( err error) {
return return
} }
size := int(cs.Msgleft) size := int(cs.msgdataleft)
if size > self.maxChunkSize { if size > self.readMaxChunkSize {
size = self.maxChunkSize size = self.readMaxChunkSize
} }
off := cs.Msglen-cs.Msgleft off := cs.msgdatalen-cs.msgdataleft
buf := cs.Msgdata[off:int(off)+size] buf := cs.msgdata[off:int(off)+size]
if _, err = io.ReadFull(self.br, buf); err != nil { if _, err = io.ReadFull(self.br, buf); err != nil {
return return
} }
cs.Msgleft -= uint32(size) cs.msgdataleft -= uint32(size)
if true { if true {
fmt.Printf("rtmp: chunk csid=%d msgsid=%d msgtypeid=%d msghdrtype=%d len=%d left=%d\n", fmt.Printf("rtmp: chunk csid=%d msgsid=%d msgtypeid=%d msghdrtype=%d len=%d left=%d\n",
csid, cs.Msgsid, cs.Msgtypeid, cs.Msghdrtype, cs.Msglen, cs.Msgleft) csid, cs.msgsid, cs.msgtypeid, cs.msghdrtype, cs.msgdatalen, cs.msgdataleft)
} }
if cs.Msgleft == 0 { if cs.msgdataleft == 0 {
if true { if true {
fmt.Println("rtmp: chunk data") fmt.Println("rtmp: chunk data")
fmt.Print(hex.Dump(cs.Msgdata)) fmt.Print(hex.Dump(cs.msgdata))
fmt.Printf("%x\n", cs.Msgdata) fmt.Printf("%x\n", cs.msgdata)
}
if err = self.handleMsg(cs.msgtypeid, cs.msgdata); err != nil {
return
} }
} }
return return
} }
func (self *Conn) Handshake() (err error) { func (self *Conn) handleMsg(msgtypeid uint8, msgdata []byte) (err error) {
// C0 switch msgtypeid {
case msgtypeidCommandMsgAMF0:
r := pio.NewReaderBytes(msgdata)
command, _ := flvio.ReadAMF0Val(r)
commandtransid, _ := flvio.ReadAMF0Val(r)
commandobj, _ := flvio.ReadAMF0Val(r)
var ok bool
if self.command, ok = command.(string); !ok {
err = fmt.Errorf("rtmp: CommandMsgAMF0 command is not string")
return
}
self.commandobj, _ = commandobj.(flvio.AMFMap)
self.commandtransid, _ = commandtransid.(float64)
self.commandr = r
self.gotcommand = true
case msgtypeidSetPeerBandwidth:
case msgtypeidSetChunkSize:
case msgtypeidWindowAckSize:
self.msgdata = msgdata
self.msgtypeid = msgtypeid
self.gotmsg = true
}
return
}
func (self *Conn) handshake() (err error) {
var time uint32
var version uint8 var version uint8
random := make([]byte, 1528)
// C0
if version, err = self.br.ReadU8(); err != nil { if version, err = self.br.ReadU8(); err != nil {
return return
} }
@ -303,14 +570,21 @@ func (self *Conn) Handshake() (err error) {
err = fmt.Errorf("rtmp: handshake c0: version=%d invalid", version) err = fmt.Errorf("rtmp: handshake c0: version=%d invalid", version)
return return
} }
// C1
if time, err = self.br.ReadU32BE(); err != nil {
return
}
if _, err = self.br.ReadU32BE(); err != nil {
return
}
if _, err = io.ReadFull(self.br, random); err != nil {
return
}
// S0 // S0
if err = self.bw.WriteU8(0x3); err != nil { if err = self.bw.WriteU8(0x3); err != nil {
return return
} }
random := make([]byte, 1528)
// S1 // S1
if err = self.bw.WriteU32BE(0); err != nil { if err = self.bw.WriteU32BE(0); err != nil {
return return
@ -324,19 +598,6 @@ func (self *Conn) Handshake() (err error) {
if err = self.bufw.Flush(); err != nil { if err = self.bufw.Flush(); err != nil {
return return
} }
// C1
var time uint32
if time, err = self.br.ReadU32BE(); err != nil {
return
}
if _, err = self.br.ReadU32BE(); err != nil {
return
}
if _, err = io.ReadFull(self.br, random); err != nil {
return
}
// S2 // S2
if err = self.bw.WriteU32BE(0); err != nil { if err = self.bw.WriteU32BE(0); err != nil {
return return