From bced7bd91538a56a8fdf729952b1b2a74c9e068a Mon Sep 17 00:00:00 2001 From: nareix Date: Wed, 22 Jun 2016 12:56:28 +0800 Subject: [PATCH] change block search mechanism again, improve relocate --- client.go | 505 +++++++++++++++++++++++++++++++++--------------------- stream.go | 3 + 2 files changed, 312 insertions(+), 196 deletions(-) diff --git a/client.go b/client.go index 0f804e6..bc4a3aa 100644 --- a/client.go +++ b/client.go @@ -45,7 +45,7 @@ type Client struct { url *url.URL conn *connWithTimeout - rconn io.Reader + brconn *bufio.Reader requestUri string cseq uint streams []*Stream @@ -61,14 +61,12 @@ type Request struct { } type Response struct { - BlockLength int - Block []byte - BlockNo int - StatusCode int - Header textproto.MIMEHeader + Headers textproto.MIMEHeader ContentLength int Body []byte + + Block []byte } func DialTimeout(uri string, timeout time.Duration) (self *Client, err error) { @@ -94,7 +92,7 @@ func DialTimeout(uri string, timeout time.Duration) (self *Client, err error) { self = &Client{ conn: connt, - rconn: connt, + brconn: bufio.NewReaderSize(connt, 256), url: URL, requestUri: u2.String(), } @@ -116,11 +114,7 @@ func (self *Client) Streams() (streams []av.CodecData, err error) { } func (self *Client) SendRtpKeepalive() (err error) { - if self.RtpKeepAliveTimeout > 0 && self.rtpKeepaliveEnterCnt == 0 { - self.rtpKeepaliveEnterCnt++ - defer func() { - self.rtpKeepaliveEnterCnt-- - }() + if self.RtpKeepAliveTimeout > 0 { if self.rtpKeepaliveTimer.IsZero() { self.rtpKeepaliveTimer = time.Now() } else if time.Now().Sub(self.rtpKeepaliveTimer) > self.RtpKeepAliveTimeout { @@ -128,7 +122,11 @@ func (self *Client) SendRtpKeepalive() (err error) { if self.DebugRtsp { fmt.Println("rtp: keep alive") } - if err = self.Options(); err != nil { + req := Request{ + Method: "OPTIONS", + Uri: self.requestUri, + } + if err = self.WriteRequest(req); err != nil { return } } @@ -175,15 +173,12 @@ func (self *Client) WriteRequest(req Request) (err error) { return } -func (self *Client) probeBlockHeader(h []byte) (length int, no int, valid bool) { +func (self *Client) parseBlockHeader(h []byte) (length int, no int, valid bool) { length = int(h[2])<<8 + int(h[3]) no = int(h[1]) if no/2 >= len(self.streams) { return } - if no%2 != 0 { - return - } if length < 8 { return } @@ -195,53 +190,50 @@ func (self *Client) probeBlockHeader(h []byte) (length int, no int, valid bool) return } -func (self *Client) readBlock(h []byte) (res Response, err error) { - self.conn.Timeout = self.RtpTimeout +func (self *Client) parseHeaders(b []byte) (statusCode int, headers textproto.MIMEHeader, err error) { + var line string + r := textproto.NewReader(bufio.NewReader(bytes.NewReader(b))) + if line, err = r.ReadLine(); err != nil { + err = fmt.Errorf("rtsp: header invalid") + return + } - for { - var valid bool - if res.BlockLength, res.BlockNo, valid = self.probeBlockHeader(h); valid { - break - } - if self.DebugRtp { - fmt.Println("rtp: block: relocate try") - } - - for { - if _, err = self.rconn.Read(h[:1]); err != nil { - return - } - if h[0] == 36 { - break - } - } - - if _, err = io.ReadFull(self.rconn, h[1:]); err != nil { + if codes := strings.Split(line, " "); len(codes) >= 2 { + if statusCode, err = strconv.Atoi(codes[1]); err != nil { + err = fmt.Errorf("rtsp: header invalid: %s", err) return } } - if self.DebugRtp { - fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo) - } - - res.Block = make([]byte, res.BlockLength) - copy(res.Block[:len(h)-4], h[4:]) - if _, err = io.ReadFull(self.rconn, res.Block[len(h)-4:]); err != nil { + if headers, err = r.ReadMIMEHeader(); err != nil { return } return } -func (self *Client) handle401(header textproto.MIMEHeader) (err error) { +func (self *Client) handleResp(res *Response) (err error) { + if sess := res.Headers.Get("Session"); sess != "" && self.session == "" { + if fields := strings.Split(sess, ";"); len(fields) > 0 { + self.session = fields[0] + } + } + if res.StatusCode == 401 { + if err = self.handle401(res); err != nil { + return + } + } + return +} + +func (self *Client) handle401(res *Response) (err error) { /* RTSP/1.0 401 Unauthorized CSeq: 2 Date: Wed, May 04 2016 10:10:51 GMT WWW-Authenticate: Digest realm="LIVE555 Streaming Media", nonce="c633aaf8b83127633cbe98fac1d20d87" */ - authval := header.Get("WWW-Authenticate") + authval := res.Headers.Get("WWW-Authenticate") hdrval := strings.SplitN(authval, " ", 2) var realm, nonce string @@ -291,87 +283,187 @@ func (self *Client) handle401(header textproto.MIMEHeader) (err error) { return } -func (self *Client) ReadResponse() (res Response, err error) { - if err = self.SendRtpKeepalive(); err != nil { - return - } +func (self *Client) findRTSP() (block []byte, data []byte, err error) { + const ( + R = iota+1 + T + S + Header + Dollar + ) + var _peek [8]byte + peek := _peek[0:0] + stat := 0 - self.conn.Timeout = self.RtspTimeout - - h := make([]byte, 16) - if _, err = io.ReadFull(self.rconn, h); err != nil { - return - } - - if h[0] == 82 && h[1] == 84 && h[2] == 83 && h[3] == 80 { - // RTSP 200 OK - self.rconn = io.MultiReader(bytes.NewReader(h), self.rconn) - } else { - return self.readBlock(h) - } - - br := bufio.NewReader(self.rconn) - tp := textproto.NewReader(br) - - var line string - if line, err = tp.ReadLine(); err != nil { - return - } - if self.DebugRtsp { - fmt.Println("<", line) - } - - fline := strings.SplitN(line, " ", 3) - if len(fline) < 2 { - err = fmt.Errorf("rtsp: malformed response line") - return - } - - if res.StatusCode, err = strconv.Atoi(fline[1]); err != nil { - return - } - var header textproto.MIMEHeader - if header, err = tp.ReadMIMEHeader(); err != nil { - return - } - - if self.DebugRtsp { - for k, s := range header { - fmt.Println(k, s) + for { + var b byte + if b, err = self.brconn.ReadByte(); err != nil { + return + } + switch b { + case 'R': + if stat == 0 { + stat = R + } + case 'T': + if stat == R { + stat = T + } + case 'S': + if stat == T { + stat = S + } + case 'P': + if stat == S { + stat = Header + } + case '$': + stat = Dollar + peek = _peek[0:0] + default: + if stat != Dollar { + stat = 0 + peek = _peek[0:0] + } } - fmt.Println() - } - switch res.StatusCode { - case 401: - if err = self.handle401(header); err != nil { + if stat != 0 { + peek = append(peek, b) + } + if stat == Header { + data = peek return } - case 200: - - default: - err = fmt.Errorf("rtsp: StatusCode=%d invalid", res.StatusCode) - return - } - - res.ContentLength, _ = strconv.Atoi(header.Get("Content-Length")) - if sess := header.Get("Session"); sess != "" && self.session == "" { - if fields := strings.Split(sess, ";"); len(fields) > 0 { - self.session = fields[0] + if stat == Dollar && len(peek) >= 8 { + if blocklen, _, ok := self.parseBlockHeader(peek); ok { + left := blocklen+4-len(peek) + block = append(peek, make([]byte, left)...) + if _, err = io.ReadFull(self.brconn, block[len(peek):]); err != nil { + return + } + return + } + stat = 0 + peek = _peek[0:0] } } - buf, _ := br.Peek(br.Buffered()) - self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn) + return +} +func (self *Client) readLFLF() (block []byte, data []byte, err error) { + const ( + LF = iota+1 + LFLF + ) + peek := []byte{} + stat := 0 + dollarpos := -1 + lpos := 0 + pos := 0 + + for { + var b byte + if b, err = self.brconn.ReadByte(); err != nil { + return + } + switch b { + case '\n': + if stat == 0 { + stat = LF + lpos = pos + } else if stat == LF { + if pos - lpos <= 2 { + stat = LFLF + } else { + lpos = pos + } + } + case '$': + dollarpos = pos + } + peek = append(peek, b) + + if stat == LFLF { + data = peek + return + } else if dollarpos != -1 && dollarpos - pos >= 8 { + hdrlen := dollarpos-pos + start := len(peek)-hdrlen + if blocklen, _, ok := self.parseBlockHeader(peek[start:]); ok { + block = append(peek[start:], make([]byte, blocklen+4-hdrlen)...) + if _, err = io.ReadFull(self.brconn, block[hdrlen:]); err != nil { + return + } + return + } + dollarpos = -1 + } + + pos++ + } + + return +} + +func (self *Client) readResp(b []byte) (res Response, err error) { + if res.StatusCode, res.Headers, err = self.parseHeaders(b); err != nil { + return + } + res.ContentLength, _ = strconv.Atoi(res.Headers.Get("Content-Length")) if res.ContentLength > 0 { res.Body = make([]byte, res.ContentLength) - if _, err = io.ReadFull(self.rconn, res.Body); err != nil { + if _, err = io.ReadFull(self.brconn, res.Body); err != nil { return } } + if err = self.handleResp(&res); err != nil { + return + } + return +} +func (self *Client) poll() (res Response, err error) { + var block []byte + var rtsp []byte + var headers []byte + + self.conn.Timeout = self.RtspTimeout + for { + if block, rtsp, err = self.findRTSP(); err != nil { + return + } + if len(block) > 0 { + res.Block = block + return + } else { + if block, headers, err = self.readLFLF(); err != nil { + return + } + if len(block) > 0 { + res.Block = block + return + } + if res, err = self.readResp(append(rtsp, headers...)); err != nil { + return + } + } + return + } + + return +} + +func (self *Client) ReadResponse() (res Response, err error) { + for { + if res, err = self.poll(); err != nil { + return + } + if res.StatusCode > 0 { + return + } + } return } @@ -616,6 +708,11 @@ func (self *Stream) handleBuggyAnnexbH264Packet(timestamp uint32, packet []byte) } func (self *Stream) handleH264Payload(timestamp uint32, packet []byte) (err error) { + if len(packet) < 2 { + err = fmt.Errorf("rtp: h264 packet too short") + return + } + var isBuggy bool if isBuggy, err = self.handleBuggyAnnexbH264Packet(timestamp, packet); isBuggy { return @@ -786,45 +883,21 @@ func (self *Stream) handleH264Payload(timestamp uint32, packet []byte) (err erro return } -func (self *Stream) handlePacket(timestamp uint32, packet []byte) (err error) { +func (self *Stream) handleRtpPacket(packet []byte) (err error) { if self.isCodecDataChange() { err = ErrCodecDataChange return } - switch self.Type() { - case av.H264: - if err = self.handleH264Payload(timestamp, packet); err != nil { - return + if self.client != nil && self.client.DebugRtp { + fmt.Println("rtp: packet len", len(packet)) + dumpsize := len(packet) + if dumpsize > 32 { + dumpsize = 32 } - - case av.AAC: - self.gotpkt = true - self.pkt.Data = packet[4:] - self.timestamp = timestamp - - default: - self.gotpkt = true - self.pkt.Data = packet - self.timestamp = timestamp + fmt.Print(hex.Dump(packet[:dumpsize])) } - return -} - -func (self *Client) parseBlock(blockNo int, packet []byte) (streamIndex int, err error) { - if blockNo%2 != 0 { - // rtcp block - return - } - - streamIndex = blockNo / 2 - if streamIndex >= len(self.streams) { - err = fmt.Errorf("rtsp: parseBlock: streamIndex=%d invalid", streamIndex) - return - } - stream := self.streams[streamIndex] - /* 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -839,17 +912,15 @@ func (self *Client) parseBlock(blockNo int, packet []byte) (streamIndex int, err | .... | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ - if len(packet) < 8 { err = fmt.Errorf("rtp: packet too short") return } payloadOffset := 12 + int(packet[0]&0xf)*4 - if payloadOffset+2 > len(packet) { + if payloadOffset > len(packet) { err = fmt.Errorf("rtp: packet too short") return } - timestamp := binary.BigEndian.Uint32(packet[4:8]) payload := packet[payloadOffset:] @@ -897,15 +968,21 @@ func (self *Client) parseBlock(blockNo int, packet []byte) (streamIndex int, err */ //payloadType := packet[1]&0x7f - if self.DebugRtp { - //fmt.Println("packet:", stream.Type(), "offset", payloadOffset, "pt", payloadType) - if len(packet) > 24 { - fmt.Println(hex.Dump(packet[:24])) + switch self.Type() { + case av.H264: + if err = self.handleH264Payload(timestamp, payload); err != nil { + return } - } - if err = stream.handlePacket(timestamp, payload); err != nil { - return + case av.AAC: + self.gotpkt = true + self.pkt.Data = payload[4:] // TODO: remove this hack + self.timestamp = timestamp + + default: + self.gotpkt = true + self.pkt.Data = packet + self.timestamp = timestamp } return @@ -940,6 +1017,68 @@ func (self *Client) Close() (err error) { return self.conn.Conn.Close() } +func (self *Client) handleBlock(block []byte) (pkt av.Packet, ok bool, err error) { + _, blockno, _ := self.parseBlockHeader(block) + if blockno%2 != 0 { + return + } + + i := blockno/2 + if i >= len(self.streams) { + err = fmt.Errorf("rtsp: block no=%d invalid", blockno) + return + } + stream := self.streams[i] + + if err = stream.handleRtpPacket(block[4:]); err != nil { + return + } + + if stream.gotpkt { + timeScale := stream.Sdp.TimeScale + + /* + TODO: https://tools.ietf.org/html/rfc3550 + A receiver can then synchronize presentation of the audio and video packets by relating + their RTP timestamps using the timestamp pairs in RTCP SR packets. + */ + if stream.firsttimestamp == 0 { + stream.firsttimestamp = stream.timestamp + } + stream.timestamp -= stream.firsttimestamp + + if timeScale == 0 { + /* + https://tools.ietf.org/html/rfc5391 + The RTP timestamp clock frequency is the same as the default sampling frequency: 16 kHz. + */ + timeScale = 16000 + } + + ok = true + pkt = stream.pkt + pkt.Time = time.Duration(stream.timestamp)*time.Second / time.Duration(timeScale) + pkt.Idx = int8(self.setupMap[i]) + + if pkt.Time < stream.lasttime { + err = fmt.Errorf("rtp: stream#%d time=%v < lasttime=%v", pkt.Time, stream.lasttime) + } + stream.lasttime = pkt.Time + + if self.DebugRtp { + fmt.Println("rtp: pktin", pkt.Idx, pkt.Time, len(pkt.Data)) + } + if self.DebugRtp { + fmt.Println("rtp: pktout", pkt.Idx, pkt.Time, len(pkt.Data)) + } + + stream.pkt = av.Packet{} + stream.gotpkt = false + } + + return +} + func (self *Client) ReadPacket() (pkt av.Packet, err error) { if !self.setupCalled { if err = self.setupAll(); err != nil { @@ -952,53 +1091,27 @@ func (self *Client) ReadPacket() (pkt av.Packet, err error) { } } + if err = self.SendRtpKeepalive(); err != nil { + return + } + for { var res Response - if res, err = self.ReadResponse(); err != nil { + for { + if res, err = self.poll(); err != nil { + return + } + if len(res.Block) > 0 { + break + } + } + + var ok bool + if pkt, ok, err = self.handleBlock(res.Block); err != nil { return } - if res.BlockLength > 0 { - var i int - if i, err = self.parseBlock(res.BlockNo, res.Block); err != nil { - return - } - stream := self.streams[i] - if stream.gotpkt { - timeScale := stream.Sdp.TimeScale - - /* - TODO: https://tools.ietf.org/html/rfc3550 - A receiver can then synchronize presentation of the audio and video packets by relating - their RTP timestamps using the timestamp pairs in RTCP SR packets. - */ - if stream.firsttimestamp == 0 { - stream.firsttimestamp = stream.timestamp - } - stream.timestamp -= stream.firsttimestamp - - if timeScale == 0 { - /* - https://tools.ietf.org/html/rfc5391 - The RTP timestamp clock frequency is the same as the default sampling frequency: 16 kHz. - */ - timeScale = 16000 - } - - pkt = stream.pkt - pkt.Time = time.Duration(stream.timestamp)*time.Second / time.Duration(timeScale) - pkt.Idx = int8(self.setupMap[i]) - - if self.DebugRtp { - fmt.Println("rtp: pktin", pkt.Idx, pkt.Time, len(pkt.Data)) - } - if self.DebugRtp { - fmt.Println("rtp: pktout", pkt.Idx, pkt.Time, len(pkt.Data)) - } - - stream.pkt = av.Packet{} - stream.gotpkt = false - return - } + if ok { + return } } diff --git a/stream.go b/stream.go index 44f8cc6..4f43d7a 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package rtsp import ( "github.com/nareix/av" "github.com/nareix/rtsp/sdp" + "time" ) type Stream struct { @@ -21,5 +22,7 @@ type Stream struct { pkt av.Packet timestamp uint32 firsttimestamp uint32 + + lasttime time.Duration }