React on WindowAckSize, set PeerBandwidth hard

The SetPeerBandwidth command is now sent with the Hard(0) options
in order to force this value.

The WindowAckSize command from the peer is now respected if it is
not part of the handshake. Before this lead to no window size set
and therefore no ACKs sent to the peer. For clients that expect
such ACK (like Blackmagic ATEM Mini Pro) this caused aborting the
connection.
This commit is contained in:
Ingo Oppermann 2022-07-28 18:13:04 +02:00
parent 2102a8289c
commit 7098ea1efd

View File

@ -10,16 +10,17 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/datarhei/joy4/utils/bits/pio"
"github.com/datarhei/joy4/av"
"github.com/datarhei/joy4/av/avutil"
"github.com/datarhei/joy4/format/flv"
"github.com/datarhei/joy4/format/flv/flvio"
"io"
"net"
"net/url"
"strings"
"time"
"github.com/datarhei/joy4/av"
"github.com/datarhei/joy4/av/avutil"
"github.com/datarhei/joy4/format/flv"
"github.com/datarhei/joy4/format/flv/flvio"
"github.com/datarhei/joy4/utils/bits/pio"
)
var Debug bool
@ -55,7 +56,7 @@ func DialTimeout(uri string, timeout time.Duration) (conn *Conn, err error) {
return
}
var ErrServerClosed = errors.New("rtmp: Server closed")
var ErrServerClosed = errors.New("server closed")
type Server struct {
Addr string
@ -98,7 +99,7 @@ func (self *Server) ListenAndServe() error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("rtmp: %w", err)
return err
}
return self.Serve(listener)
@ -112,7 +113,7 @@ func (self *Server) ListenAndServeTLS(certFile, keyFile string) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("rtmp: %w", err)
return err
}
return self.ServeTLS(listener, certFile, keyFile)
@ -135,7 +136,7 @@ func (self *Server) ServeTLS(listener net.Listener, certFile, keyFile string) er
var err error
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("rtmp: %w", err)
return err
}
}
@ -179,8 +180,6 @@ func (self *Server) Serve(listener net.Listener) error {
}
}()
}
return nil
}
func (self *Server) Close() {
@ -429,11 +428,11 @@ func (self *Conn) writeBasicConf() (err error) {
return
}
// > WindowAckSize
if err = self.writeWindowAckSize(5000000); err != nil {
if err = self.writeWindowAckSize(1024 * 1024 * 2); err != nil {
return
}
// > SetPeerBandwidth
if err = self.writeSetPeerBandwidth(5000000, 2); err != nil {
if err = self.writeSetPeerBandwidth(1024*1024*2, 0); err != nil {
return
}
return
@ -447,18 +446,18 @@ func (self *Conn) readConnect() (err error) {
return
}
if self.commandname != "connect" {
err = fmt.Errorf("rtmp: first command is not connect")
err = fmt.Errorf("first command is not connect")
return
}
if self.commandobj == nil {
err = fmt.Errorf("rtmp: connect command params invalid")
err = fmt.Errorf("connect command params invalid")
return
}
var ok bool
var _app, _tcurl interface{}
if _app, ok = self.commandobj["app"]; !ok {
err = fmt.Errorf("rtmp: `connect` params missing `app`")
err = fmt.Errorf("the `connect` params missing `app`")
return
}
connectpath, _ = _app.(string)
@ -521,7 +520,7 @@ func (self *Conn) readConnect() (err error) {
}
if len(self.commandparams) < 1 {
err = fmt.Errorf("rtmp: publish params invalid")
err = fmt.Errorf("publish params invalid")
return
}
publishpath, _ := self.commandparams[0].(string)
@ -547,7 +546,7 @@ func (self *Conn) readConnect() (err error) {
}
if cberr != nil {
err = fmt.Errorf("rtmp: OnPlayOrPublish check failed")
err = fmt.Errorf("OnPlayOrPublish check failed")
return
}
@ -564,7 +563,7 @@ func (self *Conn) readConnect() (err error) {
}
if len(self.commandparams) < 1 {
err = fmt.Errorf("rtmp: command play params invalid")
err = fmt.Errorf("command play params invalid")
return
}
playpath, _ := self.commandparams[0].(string)
@ -606,8 +605,6 @@ func (self *Conn) readConnect() (err error) {
}
}
return
}
func (self *Conn) checkConnectResult() (ok bool, errmsg string) {
@ -657,7 +654,7 @@ func (self *Conn) probe() (err error) {
}
if err = self.prober.PushTag(tag, int32(self.timestamp)); err != nil {
if Debug {
fmt.Printf("error probing tag: %s\n", err.Error())
fmt.Printf("rtmp: error probing tag: %s\n", err.Error())
}
}
}
@ -705,7 +702,7 @@ func (self *Conn) writeConnect(path string) (err error) {
var ok bool
var errmsg string
if ok, errmsg = self.checkConnectResult(); !ok {
err = fmt.Errorf("rtmp: command connect failed: %s", errmsg)
err = fmt.Errorf("command connect failed: %s", errmsg)
return
}
if Debug {
@ -718,9 +715,9 @@ func (self *Conn) writeConnect(path string) (err error) {
if len(self.msgdata) == 4 {
self.readAckSize = pio.U32BE(self.msgdata)
}
if err = self.writeWindowAckSize(0xffffffff); err != nil {
return
}
//if err = self.writeWindowAckSize(0xffffffff); err != nil {
// return
//}
}
}
}
@ -759,7 +756,7 @@ func (self *Conn) connectPublish() (err error) {
if self.commandname == "_result" {
var ok bool
if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok {
err = fmt.Errorf("rtmp: createStream command failed")
err = fmt.Errorf("createStream command failed")
return
}
break
@ -819,7 +816,7 @@ func (self *Conn) connectPlay() (err error) {
if self.commandname == "_result" {
var ok bool
if ok, self.avmsgsid = self.checkCreateStreamResult(); !ok {
err = fmt.Errorf("rtmp: createStream command failed")
err = fmt.Errorf("createStream command failed")
return
}
break
@ -862,11 +859,9 @@ func (self *Conn) ReadPacket() (pkt av.Packet, err error) {
var ok bool
if pkt, ok = self.prober.TagToPacket(tag, int32(self.timestamp)); ok {
return
return pkt, nil
}
}
return
}
func (self *Conn) Prepare() (err error) {
@ -910,7 +905,7 @@ func (self *Conn) prepare(stage int, flags int) (err error) {
return
}
} else {
err = fmt.Errorf("rtmp: call WriteHeader() before WritePacket()")
err = fmt.Errorf("call WriteHeader() before WritePacket()")
return
}
}
@ -1203,13 +1198,13 @@ func (self *Conn) readChunk() (err error) {
default: // Chunk basic header 1
case 0: // Chunk basic header 2
if _, err = io.ReadFull(self.bufr, b[:1]); err != nil {
return
return fmt.Errorf("chunk basic header 2: %w", err)
}
n += 1
csid = uint32(b[0]) + 64
case 1: // Chunk basic header 3
if _, err = io.ReadFull(self.bufr, b[:2]); err != nil {
return
return fmt.Errorf("chunk basic header 3: %w", err)
}
n += 2
csid = uint32(pio.U16BE(b)) + 64
@ -1237,7 +1232,7 @@ func (self *Conn) readChunk() (err error) {
//
// Figure 9 Chunk Message Header Type 0
if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return
}
h := b[:11]
@ -1274,7 +1269,7 @@ func (self *Conn) readChunk() (err error) {
//
// Figure 10 Chunk Message Header Type 1
if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return
}
h := b[:7]
@ -1309,7 +1304,7 @@ func (self *Conn) readChunk() (err error) {
//
// Figure 11 Chunk Message Header Type 2
if cs.msgdataleft != 0 {
err = fmt.Errorf("rtmp: chunk msgdataleft=%d invalid", cs.msgdataleft)
err = fmt.Errorf("chunk msgdataleft=%d invalid", cs.msgdataleft)
return
}
h := b[:3]
@ -1361,7 +1356,7 @@ func (self *Conn) readChunk() (err error) {
}
default:
err = fmt.Errorf("rtmp: invalid chunk msg header type=%d", msghdrtype)
err = fmt.Errorf("invalid chunk msg header type=%d", msghdrtype)
return
}
@ -1389,15 +1384,17 @@ func (self *Conn) readChunk() (err error) {
}
if err = self.handleMsg(cs.timenow, cs.msgsid, cs.msgtypeid, cs.msgdata); err != nil {
return
return fmt.Errorf("handleMsg: %w", err)
}
}
self.ackn += uint32(n)
if self.readAckSize != 0 && self.ackn > self.readAckSize {
if err = self.writeAck(self.ackn); err != nil {
return
return fmt.Errorf("writeACK: %w", err)
}
self.flushWrite()
self.ackn = 0
}
@ -1423,7 +1420,7 @@ func (self *Conn) handleCommandMsgAMF0(b []byte) (n int, err error) {
var ok bool
if self.commandname, ok = name.(string); !ok {
err = fmt.Errorf("rtmp: CommandMsgAMF0 command is not string")
err = fmt.Errorf("CommandMsgAMF0 command is not string")
return
}
self.commandtransid, _ = transid.(float64)
@ -1438,7 +1435,7 @@ func (self *Conn) handleCommandMsgAMF0(b []byte) (n int, err error) {
self.commandparams = append(self.commandparams, obj)
}
if n < len(b) {
err = fmt.Errorf("rtmp: CommandMsgAMF0 left bytes=%d", len(b)-n)
err = fmt.Errorf("CommandMsgAMF0 left bytes=%d", len(b)-n)
return
}
@ -1459,7 +1456,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidCommandMsgAMF3:
if len(msgdata) < 1 {
err = fmt.Errorf("rtmp: short packet of CommandMsgAMF3")
err = fmt.Errorf("short packet of CommandMsgAMF3")
return
}
// skip first byte
@ -1469,7 +1466,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidUserControl:
if len(msgdata) < 2 {
err = fmt.Errorf("rtmp: short packet of UserControl")
err = fmt.Errorf("short packet of UserControl")
return
}
self.eventtype = pio.U16BE(msgdata)
@ -1487,7 +1484,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
self.datamsgvals = append(self.datamsgvals, obj)
}
if n < len(b) {
err = fmt.Errorf("rtmp: DataMsgAMF0 left bytes=%d", len(b)-n)
err = fmt.Errorf("DataMsgAMF0 left bytes=%d", len(b)-n)
return
}
@ -1499,7 +1496,7 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
switch x.(type) {
case string:
if x.(string) == "onMetaData" {
metaindex = i+1
metaindex = i + 1
}
}
}
@ -1537,11 +1534,17 @@ func (self *Conn) handleMsg(timestamp uint32, msgsid uint32, msgtypeid uint8, ms
case msgtypeidSetChunkSize:
if len(msgdata) < 4 {
err = fmt.Errorf("rtmp: short packet of SetChunkSize")
err = fmt.Errorf("short packet of SetChunkSize")
return
}
self.readMaxChunkSize = int(pio.U32BE(msgdata))
return
case msgtypeidWindowAckSize:
if len(self.msgdata) != 4 {
return fmt.Errorf("invalid packet of WindowAckSize")
}
self.readAckSize = pio.U32BE(self.msgdata)
return
}
self.gotmsg = true
@ -1660,7 +1663,7 @@ func (self *Conn) handshakeClient() (err error) {
}
if Debug {
fmt.Println("rtmp: handshakeClient: server version", S1[4], S1[5], S1[6], S1[7])
fmt.Println("handshakeClient: server version", S1[4], S1[5], S1[6], S1[7])
}
if ver := pio.U32BE(S1[4:8]); ver != 0 {
@ -1698,7 +1701,7 @@ func (self *Conn) handshakeServer() (err error) {
return
}
if C0[0] != 3 {
err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0])
err = fmt.Errorf("handshake version=%d invalid", C0[0])
return
}
@ -1713,7 +1716,7 @@ func (self *Conn) handshakeServer() (err error) {
var ok bool
var digest []byte
if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok {
err = fmt.Errorf("rtmp: handshake server: C1 invalid")
err = fmt.Errorf("handshake server: C1 invalid")
return
}
hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey)