Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions actix-http/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

- When configured, gracefully close HTTP/1 connections after early responses to unread request bodies. [#3967]

[#3967]: https://github.com/actix/actix-web/issues/3967

## 3.12.1

**Notice: This release contains a security fix. Users are encouraged to update to this version ASAP.**
Expand Down
219 changes: 177 additions & 42 deletions actix-http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
config::ServiceConfig,
error::{DispatchError, ParseError, PayloadError},
service::HttpFlow,
Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode,
ConnectionType, Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode,
};

const LW_BUFFER_SIZE: usize = 1024;
Expand All @@ -58,6 +58,9 @@ bitflags! {

/// Set if write-half is disconnected.
const WRITE_DISCONNECT = 0b0010_0000;

/// Set while gracefully closing a connection after an early response.
const LINGER = 0b0100_0000;
}
}

Expand Down Expand Up @@ -361,6 +364,65 @@ where
io.poll_flush(cx)
}

fn enter_linger(mut self: Pin<&mut Self>) {
let this = self.as_mut().project();
this.flags.remove(Flags::KEEP_ALIVE);
this.flags.insert(Flags::LINGER | Flags::FINISHED);
}

fn ensure_linger_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
let this = self.as_mut().project();

if matches!(this.shutdown_timer, TimerState::Active { .. }) {
return true;
}

if let Some(deadline) = this.config.client_disconnect_deadline() {
this.shutdown_timer
.set_and_init(cx, sleep_until(deadline.into()), line!());
true
} else {
false
}
}

fn poll_linger(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Result<Poll<()>, DispatchError> {
if self.as_mut().poll_flush(cx)?.is_pending() {
return Ok(Poll::Pending);
}

if !self.as_mut().ensure_linger_timer(cx) {
let this = self.as_mut().project();
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::SHUTDOWN);
return Ok(Poll::Ready(()));
}

loop {
let should_disconnect = self.as_mut().read_available(cx)?;
let this = self.as_mut().project();
let mut progressed = false;

if !this.read_buf.is_empty() {
this.read_buf.clear();
progressed = true;
}

if should_disconnect {
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::READ_DISCONNECT | Flags::SHUTDOWN);
return Ok(Poll::Ready(()));
}

if !progressed {
return Ok(Poll::Pending);
}
}
}

fn send_response_inner(
self: Pin<&mut Self>,
res: Response<()>,
Expand All @@ -385,54 +447,90 @@ where

fn send_response(
mut self: Pin<&mut Self>,
res: Response<()>,
mut res: Response<()>,
body: B,
) -> Result<(), DispatchError> {
let close_after_response = {
let this = self.as_mut().project();
should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
};

if close_after_response {
res.head_mut().set_connection_type(ConnectionType::Close);
}

let size = self.as_mut().send_response_inner(res, &body)?;
let mut this = self.project();
this.state.set(match size {
match size {
BodySize::None | BodySize::Sized(0) => {
let payload_unfinished = this.payload.is_some();
let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
let this = self.as_mut().project();

if payload_unfinished && !drain_payload {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
if close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else {
this.flags.insert(Flags::FINISHED);
}

State::None
self.as_mut().project().state.set(State::None);
}
_ => State::SendPayload { body },
});
_ => self
.as_mut()
.project()
.state
.set(State::SendPayload { body }),
}

Ok(())
}

fn send_error_response(
mut self: Pin<&mut Self>,
res: Response<()>,
mut res: Response<()>,
body: BoxBody,
) -> Result<(), DispatchError> {
let close_after_response = {
let this = self.as_mut().project();
should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
};

if close_after_response {
res.head_mut().set_connection_type(ConnectionType::Close);
}

let size = self.as_mut().send_response_inner(res, &body)?;
let mut this = self.project();
this.state.set(match size {
match size {
BodySize::None | BodySize::Sized(0) => {
let payload_unfinished = this.payload.is_some();
let drain_payload = this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
let this = self.as_mut().project();

if payload_unfinished && !drain_payload {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
if close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else {
this.flags.insert(Flags::FINISHED);
}

State::None
self.as_mut().project().state.set(State::None);
}
_ => State::SendErrorPayload { body },
});
_ => self
.as_mut()
.project()
.state
.set(State::SendErrorPayload { body }),
}

Ok(())
}
Expand Down Expand Up @@ -534,18 +632,26 @@ where
// this.payload was the payload for the request we just finished
// responding to. We can check to see if we finished reading it
// yet, and if not, shutdown the connection.
let payload_unfinished = this.payload.is_some();
let drain_payload =
this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
let close_after_response = should_close_after_response(
this.payload.as_ref(),
*this.payload_drainable,
);
let not_pipelined = this.messages.is_empty();

// payload stream finished.
// set state to None and handle next message
this.state.set(State::None);

if not_pipelined && payload_unfinished && !drain_payload {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
if not_pipelined && close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else {
this.flags.insert(Flags::FINISHED);
}
Expand Down Expand Up @@ -588,18 +694,26 @@ where
// this.payload was the payload for the request we just finished
// responding to. We can check to see if we finished reading it
// yet, and if not, shutdown the connection.
let payload_unfinished = this.payload.is_some();
let drain_payload =
this.payload.as_ref().is_some_and(|pl| pl.is_dropped())
&& *this.payload_drainable;
let close_after_response = should_close_after_response(
this.payload.as_ref(),
*this.payload_drainable,
);
let not_pipelined = this.messages.is_empty();

// payload stream finished.
// set state to None and handle next message
this.state.set(State::None);

if not_pipelined && payload_unfinished && !drain_payload {
this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
if not_pipelined && close_after_response {
if this.config.client_disconnect_deadline().is_some() {
drop(this);
self.as_mut().enter_linger();
} else {
self.as_mut()
.project()
.flags
.insert(Flags::SHUTDOWN | Flags::FINISHED);
}
} else {
this.flags.insert(Flags::FINISHED);
}
Expand Down Expand Up @@ -960,14 +1074,20 @@ where
let this = self.as_mut().project();
if let TimerState::Active { timer } = this.shutdown_timer {
debug_assert!(
this.flags.contains(Flags::SHUTDOWN),
"shutdown flag should be set when timer is active",
this.flags.intersects(Flags::LINGER | Flags::SHUTDOWN),
"shutdown or linger flag should be set when timer is active",
);

// timed-out during shutdown; drop connection
if timer.as_mut().poll(cx).is_ready() {
trace!("timed-out during shutdown");
return Err(DispatchError::DisconnectTimeout);
if this.flags.contains(Flags::LINGER) {
trace!("timed-out during linger; shutting down connection");
this.flags.remove(Flags::LINGER);
this.flags.insert(Flags::SHUTDOWN);
this.shutdown_timer.clear(line!());
} else {
trace!("timed-out during shutdown");
return Err(DispatchError::DisconnectTimeout);
}
}
}

Expand Down Expand Up @@ -1133,7 +1253,15 @@ where

inner.as_mut().poll_timers(cx)?;

let poll = if inner.flags.contains(Flags::SHUTDOWN) {
let poll = if inner.flags.contains(Flags::LINGER) {
match inner.as_mut().poll_linger(cx)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
} else if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(()))
} else {
Expand Down Expand Up @@ -1281,7 +1409,7 @@ where
inner_p.shutdown_timer,
);

if inner_p.flags.contains(Flags::SHUTDOWN) {
if inner_p.flags.intersects(Flags::LINGER | Flags::SHUTDOWN) {
cx.waker().wake_by_ref();
}
Poll::Pending
Expand All @@ -1295,6 +1423,13 @@ where
}
}

fn should_close_after_response(payload: Option<&PayloadSender>, payload_drainable: bool) -> bool {
let payload_unfinished = payload.is_some();
let drain_payload = payload.is_some_and(|pl| pl.is_dropped()) && payload_drainable;

payload_unfinished && !drain_payload
}

#[allow(dead_code)]
fn trace_timer_states(
label: &str,
Expand Down
Loading
Loading