React on WindowAckSize, set PeerBandwidth hard

The SetPeerBandwidth command is now sent with the Hard(0) options
in order to force this value.

The WindowAckSize command from the peer is now respected if it is
not part of the handshake. Before this lead to no window size set
and therefore no ACKs sent to the peer. For clients that expect
such ACK (like Blackmagic ATEM Mini Pro) this caused aborting the
connection.
This commit is contained in:
Ingo Oppermann 2022-07-28 18:13:04 +02:00
parent 2102a8289c
commit 7098ea1efd

View File

@ -10,16 +10,17 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/datarhei/joy4/utils/bits/pio"
"github.com/datarhei/joy4/av"
"github.com/datarhei/joy4/av/avutil"
"github.com/datarhei/joy4/format/flv"
"github.com/datarhei/joy4/format/flv/flvio"
"io" "io"
"net" "net"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/datarhei/joy4/av"
"github.com/datarhei/joy4/av/avutil"
"github.com/datarhei/joy4/format/flv"
"github.com/datarhei/joy4/format/flv/flvio"
"github.com/datarhei/joy4/utils/bits/pio"
) )
var Debug bool var Debug bool
@ -55,7 +56,7 @@ func DialTimeout(uri string, timeout time.Duration) (conn *Conn, err error) {
return return
} }
var ErrServerClosed = errors.New("rtmp: Server closed") var ErrServerClosed = errors.New("server closed")
type Server struct { type Server struct {
Addr string Addr string
@ -98,7 +99,7 @@ func (self *Server) ListenAndServe() error {
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return fmt.Errorf("rtmp: %w", err) return err
} }
return self.Serve(listener) return self.Serve(listener)
@ -112,7 +113,7 @@ func (self *Server) ListenAndServeTLS(certFile, keyFile string) error {
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return fmt.Errorf("rtmp: %w", err) return err
} }
return self.ServeTLS(listener, certFile, keyFile) return self.ServeTLS(listener, certFile, keyFile)
@ -135,7 +136,7 @@ func (self *Server) ServeTLS(listener net.Listener, certFile, keyFile string) er
var err error var err error
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return fmt.Errorf("rtmp: %w", err) return err
} }
} }
@ -179,8 +180,6 @@ func (self *Server) Serve(listener net.Listener) error {
} }
}() }()
} }
return nil
} }
func (self *Server) Close() { func (self *Server) Close() {
@ -429,11 +428,11 @@ func (self *Conn) writeBasicConf() (err error) {
return return
} }
// > WindowAckSize // > WindowAckSize
if err = self.writeWindowAckSize(5000000); err != nil { if err = self.writeWindowAckSize(1024 * 1024 * 2); err != nil {
return return
} }
// > SetPeerBandwidth // > SetPeerBandwidth
if err = self.writeSetPeerBandwidth(5000000, 2); err != nil { if err = self.writeSetPeerBandwidth(1024*1024*2, 0); err != nil {
return return
} }
return return
@ -447,18 +446,18 @@ func (self *Conn) readConnect() (err error) {
return return
} }
if self.commandname != "connect" { if self.commandname != "connect" {
err = fmt.Errorf("rtmp: first command is not connect") err = fmt.Errorf("first command is not connect")
return return
} }
if self.commandobj == nil { if self.commandobj == nil {
err = fmt.Errorf("rtmp: connect command params invalid") err = fmt.Errorf("connect command params invalid")
return return
} }
var ok bool var ok bool
var _app, _tcurl interface{} var _app, _tcurl interface{}
if _app, ok = self.commandobj["app"]; !ok { if _app, ok = self.commandobj["app"]; !ok {
err = fmt.Errorf("rtmp: `connect` params missing `app`") err = fmt.Errorf("the `connect` params missing `app`")
return return
} }
connectpath, _ = _app.(string) connectpath, _ = _app.(string)
@ -521,7 +520,7 @@ func (self *Conn) readConnect() (err error) {
} }
if len(self.commandparams) < 1 { if len(self.commandparams) < 1 {
err = fmt.Errorf("rtmp: publish params invalid") err = fmt.Errorf("publish params invalid")
return return
} }
publishpath, _ := self.commandparams[0].(string) publishpath, _ := self.commandparams[0].(string)
@ -547,7 +546,7 @@ func (self *Conn) readConnect() (err error) {
} }
if cberr != nil { if cberr != nil {
err = fmt.Errorf("rtmp: OnPlayOrPublish check failed") err = fmt.Errorf("OnPlayOrPublish check failed")
return return
} }
@ -564,7 +563,7 @@ func (self *Conn) readConnect() (err error) {
} }
if len(self.commandparams) < 1 { if len(self.commandparams) < 1 {
err = fmt.Errorf("rtmp: command play params invalid") err = fmt.Errorf("command play params invalid")
return return
} }
playpath, _ := self.commandparams[0].(string) playpath, _ := self.commandparams[0].(string)
@ -606,8 +605,6 @@ func (self *Conn) readConnect() (err error) {
} }
} }
return
} }
func (self *Conn) checkConnectResult() (ok bool, errmsg string) { func (self *Conn) checkConnectResult() (ok bool, errmsg string) {
@ -657,7 +654,7 @@ func (self *Conn) probe() (err error) {
} }
if err = self.prober.PushTag(tag, int32(self.timestamp)); err != nil { if err = self.prober.PushTag(tag, int32(self.timestamp)); err != nil {
if Debug { if Debug {
fmt.Printf("error probing tag: %s\n", err.Error()) fmt.Printf("rtmp: error probing tag: %s\n", err.Error())
} }
} }
} }
@ -705,7 +702,7 @@ func (self *Conn) writeConnect(path string) (err error) {
var ok bool var ok bool
var errmsg string var errmsg string
if ok, errmsg = self.checkConnectResult(); !ok { if ok, errmsg = self.checkConnectResult(); !ok {
err = fmt.Errorf("rtmp: command connect failed: %s", errmsg) err = fmt.Errorf("command connect failed: %s", errmsg)
return return
} }
if Debug { if Debug {
@ -718,9 +715,9 @@ func (self *Conn) writeConnect(path string) (err error) {
if len(self.msgdata) == 4 { if len(self.msgdata) == 4 {
self.readAckSize = pio.U32BE(self.msgdata) self.readAckSize = pio.U32BE(self.msgdata)
} }
if err = self.writeWindowAckSize(0xffffffff); err != nil { //if err = self.writeWindowAckSize(0xffffffff); err != nil {
return // return
} //}
} }
} }
} }
@ -759,7 +756,7 @@ func (self *Conn) connectPublish() (err error) {
if self.commandname == "_result" { if self.commandname == "_result" {
var ok bool var ok bool
if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok { if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok {
err = fmt.Errorf("rtmp: createStream command failed") err = fmt.Errorf("createStream command failed")
return return
} }
break break
@ -819,7 +816,7 @@ func (self *Conn) connectPlay() (err error) {
if self.commandname == "_result" { if self.commandname == "_result" {
var ok bool var ok bool
if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok { if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok {
err = fmt.Errorf("rtmp: createStream command failed") err = fmt.Errorf("createStream command failed")
return return
} }
break break
@ -862,11 +859,9 @@ func (self *Conn) ReadPacket() (pkt av.Packet, err error) {
var ok bool var ok bool
if pkt, ok = self.prober.TagToPacket(tag, int32(self.timestamp)); ok { if pkt, ok = self.prober.TagToPacket(tag, int32(self.timestamp)); ok {
return return pkt, nil
} }
} }
return
} }
func (self *Conn) Prepare() (err error) { func (self *Conn) Prepare() (err error) {
@ -910,7 +905,7 @@ func (self *Conn) prepare(stage int, flags int) (err error) {
return return
} }
} else { } else {
err = fmt.Errorf("rtmp: call WriteHeader() before WritePacket()") err = fmt.Errorf("call WriteHeader() before WritePacket()")
return return
} }
} }
@ -1203,13 +1198,13 @@ func (self *Conn) readChunk() (err error) {
default: // Chunk basic header 1 default: // Chunk basic header 1
case 0: // Chunk basic header 2 case 0: // Chunk basic header 2
if _, err = io.ReadFull(self.bufr, b[:1]); err != nil { if _, err = io.ReadFull(self.bufr, b[:1]); err != nil {
return return fmt.Errorf("chunk basic header 2: %w", err)
} }
n += 1 n += 1
csid = uint32(b[0]) + 64 csid = uint32(b[0]) + 64
case 1: // Chunk basic header 3 case 1: // Chunk basic header 3
if _, err = io.ReadFull(self.bufr, b[:2]); err != nil { if _, err = io.ReadFull(self.bufr, b[:2]); err != nil {
return return fmt.Errorf("chunk basic header 3: %w", err)
} }
n += 2 n += 2
csid = uint32(pio.U16BE(b)) + 64 csid = uint32(pio.U16BE(b)) + 64
@ -1237,7 +1232,7 @@ func (self *Conn) readChunk() (err error) {
// //
// Figure 9 Chunk Message Header Type 0 // Figure 9 Chunk Message Header Type 0
if cs.msgdataleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft) err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
h := b[:11] h := b[:11]
@ -1274,7 +1269,7 @@ func (self *Conn) readChunk() (err error) {
// //
// Figure 10 Chunk Message Header Type 1 // Figure 10 Chunk Message Header Type 1
if cs.msgdataleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft) err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
h := b[:7] h := b[:7]
@ -1309,7 +1304,7 @@ func (self *Conn) readChunk() (err error) {
// //
// Figure 11 Chunk Message Header Type 2 // Figure 11 Chunk Message Header Type 2
if cs.msgdataleft != 0 { if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft) err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return return
} }
h := b[:3] h := b[:3]
@ -1361,7 +1356,7 @@ func (self *Conn) readChunk() (err error) {
} }
default: default:
err = fmt.Errorf("rtmp: invalid chunk msg header type=%d", msghdrtype) err = fmt.Errorf("invalid chunk msg header type=%d", msghdrtype)
return return
} }
@ -1389,15 +1384,17 @@ func (self *Conn) readChunk() (err error) {
} }
if err = self.handleMsg(cs.timenow, cs.msgsid, cs.msgtypeid, cs.msgdata); err != nil { if err = self.handleMsg(cs.timenow, cs.msgsid, cs.msgtypeid, cs.msgdata); err != nil {
return return fmt.Errorf("handleMsg: %w", err)
} }
} }
self.ackn += uint32(n) self.ackn += uint32(n)
if self.readAckSize != 0 && self.ackn > self.readAckSize { if self.readAckSize != 0 && self.ackn > self.readAckSize {
if err = self.writeAck(self.ackn); err != nil { if err = self.writeAck(self.ackn); err != nil {
return return fmt.Errorf("writeACK: %w", err)
} }
self.flushWrite()
self.ackn = 0 self.ackn = 0
} }
@ -1423,7 +1420,7 @@ func (self *Conn) handleCommandMsgAMF0(b []byte) (n int, err error) {
var ok bool var ok bool
if self.commandname, ok = name.(string); !ok { if self.commandname, ok = name.(string); !ok {
err = fmt.Errorf("rtmp: CommandMsgAMF0 command is not string") err = fmt.Errorf("CommandMsgAMF0 command is not string")
return return
} }
self.commandtransid, _ = transid.(float64) self.commandtransid, _ = transid.(float64)
@ -1438,7 +1435,7 @@ func (self *Conn) handleCommandMsgAMF0(b []byte) (n int, err error) {
self.commandparams = append(self.commandparams, obj) self.commandparams = append(self.commandparams, obj)
} }
if n < len(b) { if n < len(b) {
err = fmt.Errorf("rtmp: CommandMsgAMF0 left bytes=%d", len(b)-n) err = fmt.Errorf("CommandMsgAMF0 left bytes=%d", len(b)-n)
return return
} }
@ -1459,7 +1456,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidCommandMsgAMF3: case msgtypeidCommandMsgAMF3:
if len(msgdata) < 1 { if len(msgdata) < 1 {
err = fmt.Errorf("rtmp: short packet of CommandMsgAMF3") err = fmt.Errorf("short packet of CommandMsgAMF3")
return return
} }
// skip first byte // skip first byte
@ -1469,7 +1466,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidUserControl: case msgtypeidUserControl:
if len(msgdata) < 2 { if len(msgdata) < 2 {
err = fmt.Errorf("rtmp: short packet of UserControl") err = fmt.Errorf("short packet of UserControl")
return return
} }
self.eventtype = pio.U16BE(msgdata) self.eventtype = pio.U16BE(msgdata)
@ -1487,7 +1484,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
self.datamsgvals = append(self.datamsgvals, obj) self.datamsgvals = append(self.datamsgvals, obj)
} }
if n < len(b) { if n < len(b) {
err = fmt.Errorf("rtmp: DataMsgAMF0 left bytes=%d", len(b)-n) err = fmt.Errorf("DataMsgAMF0 left bytes=%d", len(b)-n)
return return
} }
@ -1499,7 +1496,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
switch x.(type) { switch x.(type) {
case string: case string:
if x.(string) == "onMetaData" { if x.(string) == "onMetaData" {
metaindex = i+1 metaindex = i + 1
} }
} }
} }
@ -1537,11 +1534,17 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidSetChunkSize: case msgtypeidSetChunkSize:
if len(msgdata) < 4 { if len(msgdata) < 4 {
err = fmt.Errorf("rtmp: short packet of SetChunkSize") err = fmt.Errorf("short packet of SetChunkSize")
return return
} }
self.readMaxChunkSize = int(pio.U32BE(msgdata)) self.readMaxChunkSize = int(pio.U32BE(msgdata))
return return
case msgtypeidWindowAckSize:
if len(self.msgdata) != 4 {
return fmt.Errorf("invalid packet of WindowAckSize")
}
self.readAckSize = pio.U32BE(self.msgdata)
return
} }
self.gotmsg = true self.gotmsg = true
@ -1660,7 +1663,7 @@ func (self *Conn) handshakeClient() (err error) {
} }
if Debug { if Debug {
fmt.Println("rtmp: handshakeClient: server version", S1[4], S1[5], S1[6], S1[7]) fmt.Println("handshakeClient: server version", S1[4], S1[5], S1[6], S1[7])
} }
if ver := pio.U32BE(S1[4:8]); ver != 0 { if ver := pio.U32BE(S1[4:8]); ver != 0 {
@ -1698,7 +1701,7 @@ func (self *Conn) handshakeServer() (err error) {
return return
} }
if C0[0] != 3 { if C0[0] != 3 {
err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0]) err = fmt.Errorf("handshake version=%d invalid", C0[0])
return return
} }
@ -1713,7 +1716,7 @@ func (self *Conn) handshakeServer() (err error) {
var ok bool var ok bool
var digest []byte var digest []byte
if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok { if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok {
err = fmt.Errorf("rtmp: handshake server: C1 invalid") err = fmt.Errorf("handshake server: C1 invalid")
return return
} }
hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey) hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey)