rewrite ugly ReadResponse

This commit is contained in:
nareix 2016-06-21 16:59:43 +08:00
parent d62f2cef96
commit 6ba0534fb9
2 changed files with 140 additions and 136 deletions

209
client.go
View File

@ -9,7 +9,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/nareix/av" "github.com/nareix/av"
"github.com/nareix/av/pktque"
"github.com/nareix/codec" "github.com/nareix/codec"
"github.com/nareix/codec/aacparser" "github.com/nareix/codec/aacparser"
"github.com/nareix/codec/h264parser" "github.com/nareix/codec/h264parser"
@ -53,7 +52,6 @@ type Client struct {
streamsintf []av.CodecData streamsintf []av.CodecData
session string session string
body io.Reader body io.Reader
corrector *pktque.TimeCorrector
} }
type Request struct { type Request struct {
@ -177,126 +175,66 @@ func (self *Client) WriteRequest(req Request) (err error) {
return return
} }
func (self *Client) ReadResponse() (res Response, err error) { func (self *Client) probeBlockHeader(h []byte) (length int, no int, valid bool) {
var br *bufio.Reader length = int(h[2])<<8 + int(h[3])
no = int(h[1])
if no/2 >= len(self.streams) {
return
}
if no%2 != 0 {
return
}
if length < 8 {
return
}
stream := self.streams[no/2]
if int(h[5]&0x7f) != stream.Sdp.PayloadType {
return
}
valid = true
return
}
defer func() { func (self *Client) readBlock(h []byte) (res Response, err error) {
if br != nil {
buf, _ := br.Peek(br.Buffered())
self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn)
}
if res.StatusCode == 200 {
self.conn.Timeout = self.RtspTimeout
if res.ContentLength > 0 {
res.Body = make([]byte, res.ContentLength)
if _, err = io.ReadFull(self.rconn, res.Body); err != nil {
return
}
}
} else if res.BlockLength > 0 {
self.conn.Timeout = self.RtpTimeout
res.Block = make([]byte, res.BlockLength)
if _, err = io.ReadFull(self.rconn, res.Block); err != nil {
return
}
if err = self.SendRtpKeepalive(); err != nil {
return
}
}
}()
self.conn.Timeout = self.RtspTimeout
var h [4]byte
if _, err = io.ReadFull(self.rconn, h[:]); err != nil {
return
}
if h[0] == 36 {
// $
res.BlockLength = int(h[2])<<8 + int(h[3])
res.BlockNo = int(h[1])
if self.DebugRtp {
fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo)
}
// TODO: if invalid need relocate also
return
} else if h[0] == 82 && h[1] == 84 && h[2] == 83 && h[3] == 80 {
// RTSP 200 OK
self.rconn = io.MultiReader(bytes.NewReader(h[:]), self.rconn)
} else {
self.conn.Timeout = self.RtpTimeout self.conn.Timeout = self.RtpTimeout
for { for {
var valid bool
if res.BlockLength, res.BlockNo, valid = self.probeBlockHeader(h); valid {
break
}
if self.DebugRtp { if self.DebugRtp {
fmt.Println("rtp: block: relocate try") fmt.Println("rtp: block: relocate try")
} }
for { for {
var b [1]byte if _, err = self.rconn.Read(h[:1]); err != nil {
if _, err = self.rconn.Read(b[:]); err != nil {
return return
} }
if b[0] == 36 { if h[0] == 36 {
break break
} }
} }
if _, err = io.ReadFull(self.rconn, h[1:4]); err != nil {
return
}
res.BlockLength = int(h[2])<<8 + int(h[3]) if _, err = io.ReadFull(self.rconn, h[1:]); err != nil {
res.BlockNo = int(h[1]) return
if res.BlockNo/2 < len(self.streams) {
break
} }
} }
if self.DebugRtp { if self.DebugRtp {
fmt.Println("rtp: block: relocate done")
fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo) fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo)
} }
res.Block = make([]byte, res.BlockLength)
copy(res.Block[:len(h)-4], h[4:])
if _, err = io.ReadFull(self.rconn, res.Block[len(h)-4:]); err != nil {
return return
} }
br = bufio.NewReader(self.rconn)
tp := textproto.NewReader(br)
var line string
if line, err = tp.ReadLine(); err != nil {
return return
} }
if self.DebugRtsp {
fmt.Println("<", line)
}
fline := strings.SplitN(line, " ", 3) func (self *Client) handle401(header textproto.MIMEHeader) (err error) {
if len(fline) < 2 {
err = fmt.Errorf("rtsp: malformed response line")
return
}
if res.StatusCode, err = strconv.Atoi(fline[1]); err != nil {
return
}
var header textproto.MIMEHeader
if header, err = tp.ReadMIMEHeader(); err != nil {
return
}
if self.DebugRtsp {
for k, s := range header {
fmt.Println(k, s)
}
fmt.Println()
}
if res.StatusCode != 200 && res.StatusCode != 401 {
err = fmt.Errorf("rtsp: StatusCode=%d invalid", res.StatusCode)
return
}
if res.StatusCode == 401 {
/* /*
RTSP/1.0 401 Unauthorized RTSP/1.0 401 Unauthorized
CSeq: 2 CSeq: 2
@ -349,15 +287,90 @@ func (self *Client) ReadResponse() (res Response, err error) {
} }
} }
} }
return
}
func (self *Client) ReadResponse() (res Response, err error) {
if err = self.SendRtpKeepalive(); err != nil {
return
} }
self.conn.Timeout = self.RtspTimeout
h := make([]byte, 16)
if _, err = io.ReadFull(self.rconn, h); err != nil {
return
}
if h[0] == 82 && h[1] == 84 && h[2] == 83 && h[3] == 80 {
// RTSP 200 OK
self.rconn = io.MultiReader(bytes.NewReader(h), self.rconn)
} else {
return self.readBlock(h)
}
br := bufio.NewReader(self.rconn)
tp := textproto.NewReader(br)
var line string
if line, err = tp.ReadLine(); err != nil {
return
}
if self.DebugRtsp {
fmt.Println("<", line)
}
fline := strings.SplitN(line, " ", 3)
if len(fline) < 2 {
err = fmt.Errorf("rtsp: malformed response line")
return
}
if res.StatusCode, err = strconv.Atoi(fline[1]); err != nil {
return
}
var header textproto.MIMEHeader
if header, err = tp.ReadMIMEHeader(); err != nil {
return
}
if self.DebugRtsp {
for k, s := range header {
fmt.Println(k, s)
}
fmt.Println()
}
switch res.StatusCode {
case 401:
if err = self.handle401(header); err != nil {
return
}
case 200:
default:
err = fmt.Errorf("rtsp: StatusCode=%d invalid", res.StatusCode)
return
}
res.ContentLength, _ = strconv.Atoi(header.Get("Content-Length"))
if sess := header.Get("Session"); sess != "" && self.session == "" { if sess := header.Get("Session"); sess != "" && self.session == "" {
if fields := strings.Split(sess, ";"); len(fields) > 0 { if fields := strings.Split(sess, ";"); len(fields) > 0 {
self.session = fields[0] self.session = fields[0]
} }
} }
res.ContentLength, _ = strconv.Atoi(header.Get("Content-Length")) buf, _ := br.Peek(br.Buffered())
self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn)
if res.ContentLength > 0 {
res.Body = make([]byte, res.ContentLength)
if _, err = io.ReadFull(self.rconn, res.Body); err != nil {
return
}
}
return return
} }
@ -425,7 +438,6 @@ func (self *Client) initstructs() {
for i := range self.setupIdx { for i := range self.setupIdx {
self.streamsintf[i] = self.streams[self.setupIdx[i]].CodecData self.streamsintf[i] = self.streams[self.setupIdx[i]].CodecData
} }
self.corrector = pktque.NewTimeCorrector(self.streamsintf)
} }
func (self *Client) Describe() (streams []av.CodecData, err error) { func (self *Client) Describe() (streams []av.CodecData, err error) {
@ -979,7 +991,6 @@ func (self *Client) ReadPacket() (pkt av.Packet, err error) {
if self.DebugRtp { if self.DebugRtp {
fmt.Println("rtp: pktin", pkt.Idx, pkt.Time, len(pkt.Data)) fmt.Println("rtp: pktin", pkt.Idx, pkt.Time, len(pkt.Data))
} }
self.corrector.Correct(&pkt)
if self.DebugRtp { if self.DebugRtp {
fmt.Println("rtp: pktout", pkt.Idx, pkt.Time, len(pkt.Data)) fmt.Println("rtp: pktout", pkt.Idx, pkt.Time, len(pkt.Data))
} }

View File

@ -23,10 +23,3 @@ type Stream struct {
firsttimestamp uint32 firsttimestamp uint32
} }
func (self Stream) IsAudio() bool {
return self.Sdp.AVType == "audio"
}
func (self Stream) IsVideo() bool {
return self.Sdp.AVType == "video"
}