From 6ba0534fb9893b035d47c0cab102a612671ec0fa Mon Sep 17 00:00:00 2001 From: nareix Date: Tue, 21 Jun 2016 16:59:43 +0800 Subject: [PATCH] rewrite ugly ReadResponse --- client.go | 269 ++++++++++++++++++++++++++++-------------------------- stream.go | 7 -- 2 files changed, 140 insertions(+), 136 deletions(-) diff --git a/client.go b/client.go index 12a98ea..0f804e6 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,6 @@ import ( "encoding/hex" "fmt" "github.com/nareix/av" - "github.com/nareix/av/pktque" "github.com/nareix/codec" "github.com/nareix/codec/aacparser" "github.com/nareix/codec/h264parser" @@ -53,7 +52,6 @@ type Client struct { streamsintf []av.CodecData session string body io.Reader - corrector *pktque.TimeCorrector } type Request struct { @@ -177,89 +175,142 @@ func (self *Client) WriteRequest(req Request) (err error) { return } -func (self *Client) ReadResponse() (res Response, err error) { - var br *bufio.Reader - - defer func() { - if br != nil { - buf, _ := br.Peek(br.Buffered()) - self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn) - } - if res.StatusCode == 200 { - self.conn.Timeout = self.RtspTimeout - - if res.ContentLength > 0 { - res.Body = make([]byte, res.ContentLength) - if _, err = io.ReadFull(self.rconn, res.Body); err != nil { - return - } - } - } else if res.BlockLength > 0 { - self.conn.Timeout = self.RtpTimeout - res.Block = make([]byte, res.BlockLength) - if _, err = io.ReadFull(self.rconn, res.Block); err != nil { - return - } - if err = self.SendRtpKeepalive(); err != nil { - return - } - } - }() - - self.conn.Timeout = self.RtspTimeout - var h [4]byte - if _, err = io.ReadFull(self.rconn, h[:]); err != nil { +func (self *Client) probeBlockHeader(h []byte) (length int, no int, valid bool) { + length = int(h[2])<<8 + int(h[3]) + no = int(h[1]) + if no/2 >= len(self.streams) { return } - - if h[0] == 36 { - // $ - res.BlockLength = int(h[2])<<8 + int(h[3]) - res.BlockNo = int(h[1]) - if self.DebugRtp { - fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo) - } - // TODO: if invalid need relocate also + if no%2 != 0 { return - } else if h[0] == 82 && h[1] == 84 && h[2] == 83 && h[3] == 80 { - // RTSP 200 OK - self.rconn = io.MultiReader(bytes.NewReader(h[:]), self.rconn) - } else { - self.conn.Timeout = self.RtpTimeout + } + if length < 8 { + return + } + stream := self.streams[no/2] + if int(h[5]&0x7f) != stream.Sdp.PayloadType { + return + } + valid = true + return +} + +func (self *Client) readBlock(h []byte) (res Response, err error) { + self.conn.Timeout = self.RtpTimeout + + for { + var valid bool + if res.BlockLength, res.BlockNo, valid = self.probeBlockHeader(h); valid { + break + } + if self.DebugRtp { + fmt.Println("rtp: block: relocate try") + } for { - if self.DebugRtp { - fmt.Println("rtp: block: relocate try") - } - - for { - var b [1]byte - if _, err = self.rconn.Read(b[:]); err != nil { - return - } - if b[0] == 36 { - break - } - } - if _, err = io.ReadFull(self.rconn, h[1:4]); err != nil { + if _, err = self.rconn.Read(h[:1]); err != nil { return } - - res.BlockLength = int(h[2])<<8 + int(h[3]) - res.BlockNo = int(h[1]) - if res.BlockNo/2 < len(self.streams) { + if h[0] == 36 { break } } - if self.DebugRtp { - fmt.Println("rtp: block: relocate done") - fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo) + if _, err = io.ReadFull(self.rconn, h[1:]); err != nil { + return } + } + + if self.DebugRtp { + fmt.Println("rtp: block: len", res.BlockLength, "no", res.BlockNo) + } + + res.Block = make([]byte, res.BlockLength) + copy(res.Block[:len(h)-4], h[4:]) + if _, err = io.ReadFull(self.rconn, res.Block[len(h)-4:]); err != nil { return } - br = bufio.NewReader(self.rconn) + return +} + +func (self *Client) handle401(header textproto.MIMEHeader) (err error) { + /* + RTSP/1.0 401 Unauthorized + CSeq: 2 + Date: Wed, May 04 2016 10:10:51 GMT + WWW-Authenticate: Digest realm="LIVE555 Streaming Media", nonce="c633aaf8b83127633cbe98fac1d20d87" + */ + authval := header.Get("WWW-Authenticate") + hdrval := strings.SplitN(authval, " ", 2) + var realm, nonce string + + if len(hdrval) == 2 { + for _, field := range strings.Split(hdrval[1], ",") { + field = strings.Trim(field, ", ") + if keyval := strings.Split(field, "="); len(keyval) == 2 { + key := keyval[0] + val := strings.Trim(keyval[1], `"`) + switch key { + case "realm": + realm = val + case "nonce": + nonce = val + } + } + } + + if realm != "" { + var username string + var password string + + if self.url.User == nil { + err = fmt.Errorf("rtsp: no username") + return + } + username = self.url.User.Username() + password, _ = self.url.User.Password() + + self.authHeaders = func(method string) []string { + headers := []string{ + fmt.Sprintf(`Authorization: Basic %s`, base64.StdEncoding.EncodeToString([]byte(username+":"+password))), + } + if nonce != "" { + hs1 := md5hash(username + ":" + realm + ":" + password) + hs2 := md5hash(method + ":" + self.requestUri) + response := md5hash(hs1 + ":" + nonce + ":" + hs2) + headers = append(headers, fmt.Sprintf( + `Authorization: Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`, + username, realm, nonce, self.requestUri, response)) + } + return headers + } + } + } + + return +} + +func (self *Client) ReadResponse() (res Response, err error) { + if err = self.SendRtpKeepalive(); err != nil { + return + } + + self.conn.Timeout = self.RtspTimeout + + h := make([]byte, 16) + if _, err = io.ReadFull(self.rconn, h); err != nil { + return + } + + if h[0] == 82 && h[1] == 84 && h[2] == 83 && h[3] == 80 { + // RTSP 200 OK + self.rconn = io.MultiReader(bytes.NewReader(h), self.rconn) + } else { + return self.readBlock(h) + } + + br := bufio.NewReader(self.rconn) tp := textproto.NewReader(br) var line string @@ -291,73 +342,35 @@ func (self *Client) ReadResponse() (res Response, err error) { fmt.Println() } - if res.StatusCode != 200 && res.StatusCode != 401 { + switch res.StatusCode { + case 401: + if err = self.handle401(header); err != nil { + return + } + + case 200: + + default: err = fmt.Errorf("rtsp: StatusCode=%d invalid", res.StatusCode) return } - if res.StatusCode == 401 { - /* - RTSP/1.0 401 Unauthorized - CSeq: 2 - Date: Wed, May 04 2016 10:10:51 GMT - WWW-Authenticate: Digest realm="LIVE555 Streaming Media", nonce="c633aaf8b83127633cbe98fac1d20d87" - */ - authval := header.Get("WWW-Authenticate") - hdrval := strings.SplitN(authval, " ", 2) - var realm, nonce string - - if len(hdrval) == 2 { - for _, field := range strings.Split(hdrval[1], ",") { - field = strings.Trim(field, ", ") - if keyval := strings.Split(field, "="); len(keyval) == 2 { - key := keyval[0] - val := strings.Trim(keyval[1], `"`) - switch key { - case "realm": - realm = val - case "nonce": - nonce = val - } - } - } - - if realm != "" { - var username string - var password string - - if self.url.User == nil { - err = fmt.Errorf("rtsp: no username") - return - } - username = self.url.User.Username() - password, _ = self.url.User.Password() - - self.authHeaders = func(method string) []string { - headers := []string{ - fmt.Sprintf(`Authorization: Basic %s`, base64.StdEncoding.EncodeToString([]byte(username+":"+password))), - } - if nonce != "" { - hs1 := md5hash(username + ":" + realm + ":" + password) - hs2 := md5hash(method + ":" + self.requestUri) - response := md5hash(hs1 + ":" + nonce + ":" + hs2) - headers = append(headers, fmt.Sprintf( - `Authorization: Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`, - username, realm, nonce, self.requestUri, response)) - } - return headers - } - } - } - } - + res.ContentLength, _ = strconv.Atoi(header.Get("Content-Length")) if sess := header.Get("Session"); sess != "" && self.session == "" { if fields := strings.Split(sess, ";"); len(fields) > 0 { self.session = fields[0] } } - res.ContentLength, _ = strconv.Atoi(header.Get("Content-Length")) + buf, _ := br.Peek(br.Buffered()) + self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn) + + if res.ContentLength > 0 { + res.Body = make([]byte, res.ContentLength) + if _, err = io.ReadFull(self.rconn, res.Body); err != nil { + return + } + } return } @@ -425,7 +438,6 @@ func (self *Client) initstructs() { for i := range self.setupIdx { self.streamsintf[i] = self.streams[self.setupIdx[i]].CodecData } - self.corrector = pktque.NewTimeCorrector(self.streamsintf) } func (self *Client) Describe() (streams []av.CodecData, err error) { @@ -979,7 +991,6 @@ func (self *Client) ReadPacket() (pkt av.Packet, err error) { if self.DebugRtp { fmt.Println("rtp: pktin", pkt.Idx, pkt.Time, len(pkt.Data)) } - self.corrector.Correct(&pkt) if self.DebugRtp { fmt.Println("rtp: pktout", pkt.Idx, pkt.Time, len(pkt.Data)) } diff --git a/stream.go b/stream.go index 5715a24..44f8cc6 100644 --- a/stream.go +++ b/stream.go @@ -23,10 +23,3 @@ type Stream struct { firsttimestamp uint32 } -func (self Stream) IsAudio() bool { - return self.Sdp.AVType == "audio" -} - -func (self Stream) IsVideo() bool { - return self.Sdp.AVType == "video" -}