Skip to content
Snippets Groups Projects
Unverified Commit 8f3a2659 authored by Alice Ryhl's avatar Alice Ryhl Committed by GitHub
Browse files

net: introduce owned split on TcpStream (#2270)

parent 800574b4
No related branches found
No related tags found
Loading
...@@ -9,5 +9,8 @@ pub use incoming::Incoming; ...@@ -9,5 +9,8 @@ pub use incoming::Incoming;
mod split; mod split;
pub use split::{ReadHalf, WriteHalf}; pub use split::{ReadHalf, WriteHalf};
mod split_owned;
pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError};
pub(crate) mod stream; pub(crate) mod stream;
pub(crate) use stream::TcpStream; pub(crate) use stream::TcpStream;
...@@ -25,8 +25,8 @@ pub struct ReadHalf<'a>(&'a TcpStream); ...@@ -25,8 +25,8 @@ pub struct ReadHalf<'a>(&'a TcpStream);
/// Write half of a `TcpStream`. /// Write half of a `TcpStream`.
/// ///
/// Note that in the `AsyncWrite` implemenation of `TcpStreamWriteHalf`, /// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will
/// `poll_shutdown` actually shuts down the TCP stream in the write direction. /// shut down the TCP stream in the write direction.
#[derive(Debug)] #[derive(Debug)]
pub struct WriteHalf<'a>(&'a TcpStream); pub struct WriteHalf<'a>(&'a TcpStream);
......
//! `TcpStream` owned split support.
//!
//! A `TcpStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf`
//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split has no associated overhead and enforces all invariants at the type
//! level.
use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite};
use crate::net::TcpStream;
use bytes::Buf;
use std::error::Error;
use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, io};
/// Owned read half of a [`TcpStream`], created by [`into_split`].
///
/// [`TcpStream`]: TcpStream
/// [`into_split`]: TcpStream::into_split()
#[derive(Debug)]
pub struct OwnedReadHalf {
inner: Arc<TcpStream>,
}
/// Owned write half of a [`TcpStream`], created by [`into_split`].
///
/// Note that in the `AsyncWrite` implemenation of this type, `poll_shutdown` will
/// shut down the TCP stream in the write direction.
///
/// Dropping the write half will close the TCP stream in both directions.
///
/// [`TcpStream`]: TcpStream
/// [`into_split`]: TcpStream::into_split()
#[derive(Debug)]
pub struct OwnedWriteHalf {
inner: Arc<TcpStream>,
shutdown_on_drop: bool,
}
pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
let arc = Arc::new(stream);
let read = OwnedReadHalf {
inner: Arc::clone(&arc),
};
let write = OwnedWriteHalf {
inner: arc,
shutdown_on_drop: true,
};
(read, write)
}
pub(crate) fn reunite(
read: OwnedReadHalf,
write: OwnedWriteHalf,
) -> Result<TcpStream, ReuniteError> {
if Arc::ptr_eq(&read.inner, &write.inner) {
write.forget();
// This unwrap cannot fail as the api does not allow creating more than two Arcs,
// and we just dropped the other half.
Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc"))
} else {
Err(ReuniteError(read, write))
}
}
/// Error indicating two halves were not from the same socket, and thus could
/// not be reunited.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
impl fmt::Display for ReuniteError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}
impl Error for ReuniteError {}
impl OwnedReadHalf {
/// Attempts to put the two halves of a `TcpStream` back together and
/// recover the original socket. Succeeds only if the two halves
/// originated from the same call to [`into_split`].
///
/// [`into_split`]: TcpStream::into_split()
pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
reunite(self, other)
}
/// Attempt to receive data on the socket, without removing that data from
/// the queue, registering the current task for wakeup if data is not yet
/// available.
///
/// See the [`TcpStream::poll_peek`] level documenation for more details.
///
/// # Examples
///
/// ```no_run
/// use tokio::io;
/// use tokio::net::TcpStream;
///
/// use futures::future::poll_fn;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let stream = TcpStream::connect("127.0.0.1:8000").await?;
/// let (mut read_half, _) = stream.into_split();
/// let mut buf = [0; 10];
///
/// poll_fn(|cx| {
/// read_half.poll_peek(cx, &mut buf)
/// }).await?;
///
/// Ok(())
/// }
/// ```
///
/// [`TcpStream::poll_peek`]: TcpStream::poll_peek
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.inner.poll_peek2(cx, buf)
}
/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// See the [`TcpStream::peek`] level documenation for more details.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use tokio::prelude::*;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let stream = TcpStream::connect("127.0.0.1:8080").await?;
/// let (mut read_half, _) = stream.into_split();
///
/// let mut b1 = [0; 10];
/// let mut b2 = [0; 10];
///
/// // Peek at the data
/// let n = read_half.peek(&mut b1).await?;
///
/// // Read the data
/// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
/// assert_eq!(&b1[..n], &b2[..n]);
///
/// Ok(())
/// }
/// ```
///
/// [`TcpStream::peek`]: TcpStream::peek
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| self.poll_peek(cx, buf)).await
}
}
impl AsyncRead for OwnedReadHalf {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false
}
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.inner.poll_read_priv(cx, buf)
}
}
impl OwnedWriteHalf {
/// Attempts to put the two halves of a `TcpStream` back together and
/// recover the original socket. Succeeds only if the two halves
/// originated from the same call to [`into_split`].
///
/// [`into_split`]: TcpStream::into_split()
pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
reunite(other, self)
}
/// Destroy the write half, but don't close the stream until the read half
/// is dropped. If the read half has already been dropped, this closes the
/// stream.
pub fn forget(mut self) {
self.shutdown_on_drop = false;
drop(self);
}
}
impl Drop for OwnedWriteHalf {
fn drop(&mut self) {
if self.shutdown_on_drop {
let _ = self.inner.shutdown(Shutdown::Both);
}
}
}
impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.inner.poll_write_priv(cx, buf)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
self.inner.poll_write_buf_priv(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
Poll::Ready(Ok(()))
}
// `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner.shutdown(Shutdown::Write).into()
}
}
impl AsRef<TcpStream> for OwnedReadHalf {
fn as_ref(&self) -> &TcpStream {
&*self.inner
}
}
impl AsRef<TcpStream> for OwnedWriteHalf {
fn as_ref(&self) -> &TcpStream {
&*self.inner
}
}
use crate::future::poll_fn; use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite, PollEvented}; use crate::io::{AsyncRead, AsyncWrite, PollEvented};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::ToSocketAddrs; use crate::net::ToSocketAddrs;
use bytes::Buf; use bytes::Buf;
...@@ -614,10 +615,26 @@ impl TcpStream { ...@@ -614,10 +615,26 @@ impl TcpStream {
/// Splits a `TcpStream` into a read half and a write half, which can be used /// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently. /// to read and write the stream concurrently.
///
/// This method is more efficient than [`into_split`], but the halves cannot be
/// moved into independently spawned tasks.
///
/// [`into_split`]: TcpStream::into_split()
pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
split(self) split(self)
} }
/// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
/// Unlike [`split`], the owned halves can be moved to separate tasks, however
/// this comes at the cost of a heap allocation.
///
/// [`split`]: TcpStream::split()
pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
split_owned(self)
}
// == Poll IO functions that takes `&self` == // == Poll IO functions that takes `&self` ==
// //
// They are not public because (taken from the doc of `PollEvented`): // They are not public because (taken from the doc of `PollEvented`):
......
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use std::io::{Error, ErrorKind, Result};
use std::io::{Read, Write};
use std::sync::{Arc, Barrier};
use std::{net, thread};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::try_join;
#[tokio::test]
async fn split() -> Result<()> {
const MSG: &[u8] = b"split";
let mut listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let (stream1, (mut stream2, _)) = try_join! {
TcpStream::connect(&addr),
listener.accept(),
}?;
let (mut read_half, mut write_half) = stream1.into_split();
let ((), (), ()) = try_join! {
async {
let len = stream2.write(MSG).await?;
assert_eq!(len, MSG.len());
let mut read_buf = vec![0u8; 32];
let read_len = stream2.read(&mut read_buf).await?;
assert_eq!(&read_buf[..read_len], MSG);
Result::Ok(())
},
async {
let len = write_half.write(MSG).await?;
assert_eq!(len, MSG.len());
Ok(())
},
async {
let mut read_buf = vec![0u8; 32];
let peek_len1 = read_half.peek(&mut read_buf[..]).await?;
let peek_len2 = read_half.peek(&mut read_buf[..]).await?;
assert_eq!(peek_len1, peek_len2);
let read_len = read_half.read(&mut read_buf[..]).await?;
assert_eq!(peek_len1, read_len);
assert_eq!(&read_buf[..read_len], MSG);
Ok(())
},
}?;
Ok(())
}
#[tokio::test]
async fn reunite() -> Result<()> {
let listener = net::TcpListener::bind("127.0.0.1:0")?;
let addr = listener.local_addr()?;
let handle = thread::spawn(move || {
drop(listener.accept().unwrap());
drop(listener.accept().unwrap());
});
let stream1 = TcpStream::connect(&addr).await?;
let (read1, write1) = stream1.into_split();
let stream2 = TcpStream::connect(&addr).await?;
let (_, write2) = stream2.into_split();
let read1 = match read1.reunite(write2) {
Ok(_) => panic!("Reunite should not succeed"),
Err(err) => err.0,
};
read1.reunite(write1).expect("Reunite should succeed");
handle.join().unwrap();
Ok(())
}
/// Test that dropping the write half actually closes the stream.
#[tokio::test]
async fn drop_write() -> Result<()> {
const MSG: &[u8] = b"split";
let listener = net::TcpListener::bind("127.0.0.1:0")?;
let addr = listener.local_addr()?;
let barrier = Arc::new(Barrier::new(2));
let barrier2 = barrier.clone();
let handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
stream.write(MSG).unwrap();
let mut read_buf = [0u8; 32];
let res = match stream.read(&mut read_buf) {
Ok(0) => Ok(()),
Ok(len) => Err(Error::new(
ErrorKind::Other,
format!("Unexpected read: {} bytes.", len),
)),
Err(err) => Err(err),
};
barrier2.wait();
drop(stream);
res
});
let stream = TcpStream::connect(&addr).await?;
let (mut read_half, write_half) = stream.into_split();
let mut read_buf = [0u8; 32];
let read_len = read_half.read(&mut read_buf[..]).await?;
assert_eq!(&read_buf[..read_len], MSG);
// drop it while the read is in progress
std::thread::spawn(move || {
thread::sleep(std::time::Duration::from_millis(50));
drop(write_half);
});
match read_half.read(&mut read_buf[..]).await {
Ok(0) => {}
Ok(len) => panic!("Unexpected read: {} bytes.", len),
Err(err) => panic!("Unexpected error: {}.", err),
}
barrier.wait();
handle.join().unwrap().unwrap();
Ok(())
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment