diff --git a/client.go b/client.go index 60add71..7b98380 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,8 @@ type Client struct { setupMap []int playCalled bool + authHeaders func(method string) []string + url *url.URL conn *connWithTimeout rconn io.Reader @@ -45,7 +47,6 @@ type Client struct { cseq uint streams []*Stream session string - authorization []string body io.Reader pktque *pktqueue.Queue } @@ -113,14 +114,6 @@ func (self *Client) Streams() (streams []av.CodecData, err error) { return } -func (self *Client) writeLine(line string) (err error) { - if self.DebugConn { - fmt.Print("> ", line) - } - _, err = fmt.Fprint(self.conn, line) - return -} - func (self *Client) sendRtpKeepalive() (err error) { if self.RtpKeepAliveTimeout > 0 { if self.rtpKeepaliveTimer.IsZero() { @@ -140,24 +133,40 @@ func (self *Client) sendRtpKeepalive() (err error) { func (self *Client) WriteRequest(req Request) (err error) { self.conn.Timeout = self.RtspTimeout - self.cseq++ - req.Header = append(req.Header, fmt.Sprintf("CSeq: %d", self.cseq)) - for _, s := range self.authorization { - req.Header = append(req.Header, "Authorization: "+s) - } - req.Header = append(req.Header, self.Headers...) - if err = self.writeLine(fmt.Sprintf("%s %s RTSP/1.0\r\n", req.Method, req.Uri)); err != nil { - return - } - for _, v := range req.Header { - if err = self.writeLine(fmt.Sprint(v, "\r\n")); err != nil { - return + + buf := &bytes.Buffer{} + + fmt.Fprintf(buf, "%s %s RTSP/1.0\r\n", req.Method, req.Uri) + fmt.Fprintf(buf, "CSeq: %d\r\n", self.cseq) + + if self.authHeaders != nil { + headers := self.authHeaders(req.Method) + for _, s := range headers { + io.WriteString(buf, s) + io.WriteString(buf, "\r\n") } } - if err = self.writeLine("\r\n"); err != nil { + for _, s := range req.Header { + io.WriteString(buf, s) + io.WriteString(buf, "\r\n") + } + for _, s := range self.Headers { + io.WriteString(buf, s) + io.WriteString(buf, "\r\n") + } + io.WriteString(buf, "\r\n") + + bufout := buf.Bytes() + + if self.DebugConn { + fmt.Print("> ", string(bufout)) + } + + if _, err = self.conn.Write(bufout); err != nil { return } + return } @@ -268,7 +277,10 @@ func (self *Client) ReadResponse() (res Response, err error) { } if self.DebugConn { - fmt.Println("<", header) + for k, s := range header { + fmt.Println(k, s) + } + fmt.Println() } if res.StatusCode != 200 && res.StatusCode != 401 { @@ -316,12 +328,15 @@ func (self *Client) ReadResponse() (res Response, err error) { return } hs1 := md5hash(username+":"+realm+":"+password) - hs2 := md5hash("DESCRIBE:"+self.requestUri) - response := md5hash(hs1+":"+nonce+":"+hs2) - self.authorization = []string{ - fmt.Sprintf(`Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`, - username, realm, nonce, self.requestUri, response), - fmt.Sprintf(`Basic %s`, base64.StdEncoding.EncodeToString([]byte(username+":"+password))), + + self.authHeaders = func(method string) []string { + hs2 := md5hash(method+":"+self.requestUri) + response := md5hash(hs1+":"+nonce+":"+hs2) + return []string{ + fmt.Sprintf(`Authorization: Digest username="%s", realm="%s", nonce="%s", uri="%s", response="%s"`, + username, realm, nonce, self.requestUri, response), + fmt.Sprintf(`Authorization: Basic %s`, base64.StdEncoding.EncodeToString([]byte(username+":"+password))), + } } } } @@ -651,6 +666,7 @@ func (self *Stream) handlePacket(timestamp uint32, packet []byte) (err error) { self.pkt.Data = packet self.timestamp = timestamp } + return }