diff --git a/client.go b/client.go index ae6d663..b5c8509 100644 --- a/client.go +++ b/client.go @@ -29,15 +29,15 @@ type Client struct { Headers []string RtspTimeout time.Duration - RtpFirstReadTimeout time.Duration - RtpReadTimeout time.Duration + RtpTimeout time.Duration RtpKeepAliveTimeout time.Duration + rtpKeepaliveTimer time.Time setupCalled bool playCalled bool url *url.URL - conn net.Conn + conn *connWithTimeout rconn io.Reader requestUri string cseq uint @@ -84,9 +84,11 @@ func DialTimeout(uri string, timeout time.Duration) (self *Client, err error) { u2 := *URL u2.User = nil + connt := &connWithTimeout{Conn: conn} + self = &Client{ - conn: conn, - rconn: conn, + conn: connt, + rconn: connt, url: URL, requestUri: u2.String(), } @@ -112,7 +114,21 @@ func (self *Client) writeLine(line string) (err error) { return } +func (self *Client) sendRtpKeepalive() (err error) { + if self.RtpKeepAliveTimeout > 0 { + if !self.rtpKeepaliveTimer.IsZero() && time.Now().Sub(self.rtpKeepaliveTimer) > self.RtpKeepAliveTimeout { + if err = self.Options(); err != nil { + return + } + } + self.rtpKeepaliveTimer = time.Now() + } + return +} + func (self *Client) WriteRequest(req Request) (err error) { + self.conn.Timeout = self.RtspTimeout + self.cseq++ req.Header = append(req.Header, self.Headers...) req.Header = append(req.Header, fmt.Sprintf("CSeq: %d", self.cseq)) @@ -142,6 +158,8 @@ func (self *Client) ReadResponse() (res Response, err error) { self.rconn = io.MultiReader(bytes.NewReader(buf), self.rconn) } if res.StatusCode == 200 { + self.conn.Timeout = self.RtspTimeout + if res.ContentLength > 0 { res.Body = make([]byte, res.ContentLength) if _, err = io.ReadFull(self.rconn, res.Body); err != nil { @@ -149,6 +167,10 @@ func (self *Client) ReadResponse() (res Response, err error) { } } } else if res.BlockLength > 0 { + if err = self.sendRtpKeepalive(); err != nil { + return + } + self.conn.Timeout = self.RtpTimeout res.Block = make([]byte, res.BlockLength) if _, err = io.ReadFull(self.rconn, res.Block); err != nil { return @@ -156,6 +178,8 @@ func (self *Client) ReadResponse() (res Response, err error) { } }() + self.conn.Timeout = self.RtspTimeout + var h [4]byte if _, err = io.ReadFull(self.rconn, h[:]); err != nil { return @@ -173,6 +197,8 @@ func (self *Client) ReadResponse() (res Response, err error) { // RTSP 200 OK self.rconn = io.MultiReader(bytes.NewReader(h[:]), self.rconn) } else { + self.conn.Timeout = self.RtpTimeout + for { for { var b [1]byte diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..621895d --- /dev/null +++ b/conn.go @@ -0,0 +1,26 @@ +package rtsp + +import ( + "net" + "time" +) + +type connWithTimeout struct { + Timeout time.Duration + net.Conn +} + +func (self connWithTimeout) Read(p []byte) (n int, err error) { + if self.Timeout > 0 { + self.Conn.SetReadDeadline(time.Now().Add(self.Timeout)) + } + return self.Conn.Read(p) +} + +func (self connWithTimeout) Write(p []byte) (n int, err error) { + if self.Timeout > 0 { + self.Conn.SetWriteDeadline(time.Now().Add(self.Timeout)) + } + return self.Conn.Write(p) +} +