I'm attempting to write a simple async https proxy server in c#.
I would like to know how I should detect/handle when the request is complete, and how to exit my bActive loop, assuming a loop like this is appropriate.
Would really appreciate some pointers on if my approach is correct and what I could do to improve the logic.
The issue I seem to be running into is that the time it takes for an endpoint to respond along with the network delay means I DataAvailable
doenst always have data but there may still be some sending. Requiring a sleep and another attmempt which in turn causes the long completion time in requests.
Listen for TCP connection
Extract CONNECT header and open a connection to the requested server
Copy the requestStream to proxyStream
Copy the proxyStream to the requestStream
Sleep waiting for data and repeat 3 - 4 until no data is avaiable on both streams. Then break out of the loop and close connection.
public async Task Start()
{
listener.Start();
while (listen)
{
if (listener.Pending())
{
HandleClient(await listener.AcceptTcpClientAsync());
}
else
{
await Task.Delay(100); //<--- timeout
}
}
}
private static async Task HandleClient(TcpClient clt)
{
var bytes = new byte[clt.ReceiveBufferSize];
var hostHeaderAvailable = 0;
NetworkStream requestStream = null;
int count;
const string connectText = "connect";
const string hostText = "Host: ";
bool bActive = true;
List<Task> tasks = new List<Task>();
try
{
using (NetworkStream proxyStream = clt.GetStream())
using (TcpClient requestClient = new TcpClient())
{
proxyStream.ReadTimeout = 100;
proxyStream.WriteTimeout = 100;
while (bActive)
{
if (proxyStream.DataAvailable && hostHeaderAvailable == 0)
{
count = await proxyStream.ReadAsync(bytes, 0, bytes.Length);
var text = Encoding.UTF8.GetString(bytes);
Console.WriteLine(text);
if (text.ToLower().StartsWith(connectText))
{
// extract the url and port
var host = text.Remove(0, connectText.Length + 1);
var hostIndex = host.IndexOf(" ", StringComparison.Ordinal);
var hostEntry = host.Remove(hostIndex).Split(new[] { ":" }, StringSplitOptions.None);
// connect to the url and prot supplied
await requestClient.ConnectAsync(hostEntry[0], Convert.ToInt32(hostEntry[1]));
requestStream = requestClient.GetStream();
requestStream.ReadTimeout = 100;
requestStream.WriteTimeout = 100;
// send 200 response to proxyStream
const string sslResponse = "HTTP/1.0 200 Connection established\r\n\r\n";
var sslResponseBytes = Encoding.UTF8.GetBytes(sslResponse);
await proxyStream.WriteAsync(sslResponseBytes, 0, sslResponseBytes.Length);
// delay here seems to prevent the following proxyStream.read from failing as data is not yet avaiable
// without it the loop runs and has to timeout before running again
await Task.Delay(1);
}
}
hostHeaderAvailable++;
if (requestStream == null || !requestClient.Connected || !clt.Connected)
{
bActive = false;
break;
}
Console.WriteLine(proxyStream.DataAvailable || requestStream.DataAvailable);
if (proxyStream.DataAvailable || requestStream.DataAvailable)
{
Task task = proxyStream.CopyToAsync(requestStream);
Task task2 = requestStream.CopyToAsync(proxyStream);
tasks.Add(task);
tasks.Add(task2);
await Task.WhenAll(tasks).ConfigureAwait(false);
bActive = false;
break;
}
await Task.Delay(10);
}
}
}
catch (Exception e)
{
Console.WriteLine(e.ToString());
}
clt.Close();
}
An older attempt that used ReadAsync
/WriteAsync
too longer to response and still had the timeout issue.
Listen for TCP connection
Extract CONNECT header and open a connection to the requested server
Read data from requestStream and copy to proxyStream
Wait checking if data is avaiable on either stream
If data avaiable read from proxyStream and write to requestStream
If data avaiable read from requestStream and write to proxyStream
Sleep waiting for data and repeat 5 - 6 until no data is avaiable on eitboth streams. Then break out of the loop and close connection.
private static TcpListener listener = new TcpListener(IPAddress.Parse("192.168.0.25"), 13000);
private static bool listen = true;
public async Task Start()
{
listener.Start();
while (listen)
{
if (listener.Pending())
{
await HandleClient(await listener.AcceptTcpClientAsync());
}
else
{
await Task.Delay(100);
}
}
}
private static async Task HandleClient(TcpClient clt)
{
var bytes = new byte[clt.ReceiveBufferSize];
var hostHeaderAvailable = 0;
NetworkStream requestStream = null;
int count;
const string connectText = "connect";
const string hostText = "Host: ";
bool bActive = true;
try
{
using (NetworkStream proxyStream = clt.GetStream())
using (TcpClient requestClient = new TcpClient())
{
while (bActive)
{
while (proxyStream.DataAvailable)
{
// handle connect
if (hostHeaderAvailable == 0)
{
count = await proxyStream.ReadAsync(bytes, 0, bytes.Length);
var text = Encoding.UTF8.GetString(bytes);
Console.WriteLine(text);
if (text.ToLower().StartsWith(connectText))
{
// extract the url and port
var host = text.Remove(0, connectText.Length + 1);
var hostIndex = host.IndexOf(" ", StringComparison.Ordinal);
var hostEntry = host.Remove(hostIndex).Split(new[] { ":" }, StringSplitOptions.None);
// connect to the url and prot supplied
await requestClient.ConnectAsync(hostEntry[0], Convert.ToInt32(hostEntry[1]));
requestStream = requestClient.GetStream();
// send 200 response to proxyStream
const string sslResponse = "HTTP/1.0 200 Connection established\r\n\r\n";
var sslResponseBytes = Encoding.UTF8.GetBytes(sslResponse);
await proxyStream.WriteAsync(sslResponseBytes, 0, sslResponseBytes.Length);
// delay here seems to prevent the following proxyStream.read from failing as data is not yet avaiable
// without it the loop runs and has to timeout before running again
await Task.Delay(20);
}
}
hostHeaderAvailable++;
if (requestClient.Connected && hostHeaderAvailable > 1)
{
count = await proxyStream.ReadAsync(bytes, 0, bytes.Length);
await requestStream.WriteAsync(bytes, 0, count);
}
}
while (requestStream.DataAvailable)
{
count = await requestStream.ReadAsync(bytes, 0, bytes.Length);
await proxyStream.WriteAsync(bytes, 0, count);
}
// attempt to detect a timeout / end of data avaiable
var timeout = 0;
while (!proxyStream.DataAvailable && !requestStream.DataAvailable)
{
if (timeout > 5)
{
bActive = false;
break;
}
await Task.Delay(10);
timeout++;
}
}
}
}
catch (Exception e)
{
Console.WriteLine(e.ToString());
}
}
UPDATE
As per AgentFire's answer I have now come to the following working code:
public static async Task HandleDisconnect(TcpClient tcp, TcpClient tcp2, CancellationToken cancellationToken)
{
while (true)
{
if (tcp.Client.Poll(0, SelectMode.SelectRead))
{
byte[] buff = new byte[1];
if (tcp.Client.Receive(buff, SocketFlags.Peek) == 0)
{
// Client disconnected
Console.WriteLine("The requesting client has dropped its connection.");
cancellationToken = new CancellationToken(true);
break;
}
}
if (tcp2.Client.Poll(0, SelectMode.SelectRead))
{
byte[] buff = new byte[1];
if (tcp2.Client.Receive(buff, SocketFlags.Peek) == 0)
{
// Server disconnected
Console.WriteLine("The destination client has dropped its connection.");
cancellationToken = new CancellationToken(true);
break;
}
}
await Task.Delay(1);
}
}
private static async Task HandleClient(TcpClient clt)
{
List<Task> tasks = new List<Task>();
var bytes = new byte[clt.ReceiveBufferSize];
var hostHeaderAvailable = 0;
NetworkStream requestStream = null;
const string connectText = "connect";
try
{
using (NetworkStream proxyStream = clt.GetStream())
using (TcpClient requestClient = new TcpClient())
{
proxyStream.ReadTimeout = 100;
proxyStream.WriteTimeout = 100;
if (proxyStream.DataAvailable && hostHeaderAvailable == 0)
{
await proxyStream.ReadAsync(bytes, 0, bytes.Length);
var text = Encoding.UTF8.GetString(bytes);
Console.WriteLine(text);
if (text.ToLower().StartsWith(connectText))
{
// extract the url and port
var host = text.Remove(0, connectText.Length + 1);
var hostIndex = host.IndexOf(" ", StringComparison.Ordinal);
var hostEntry = host.Remove(hostIndex).Split(new[] { ":" }, StringSplitOptions.None);
// connect to the url and prot supplied
await requestClient.ConnectAsync(hostEntry[0], Convert.ToInt32(hostEntry[1]));
requestStream = requestClient.GetStream();
requestStream.ReadTimeout = 100;
requestStream.WriteTimeout = 100;
// send 200 response to proxyStream
const string sslResponse = "HTTP/1.0 200 Connection established\r\n\r\n";
var sslResponseBytes = Encoding.UTF8.GetBytes(sslResponse);
await proxyStream.WriteAsync(sslResponseBytes, 0, sslResponseBytes.Length);
}
}
hostHeaderAvailable++;
CancellationToken cancellationToken = new CancellationToken(false);
Task task = proxyStream.CopyToAsync(requestStream, cancellationToken);
Task task2 = requestStream.CopyToAsync(proxyStream, cancellationToken);
Task handleConnection = HandleDisconnect(clt, requestClient, cancellationToken);
tasks.Add(task);
tasks.Add(task2);
tasks.Add(handleConnection);
await Task.WhenAll(tasks).ConfigureAwait(false);
// close conenctions
clt.Close();
clt.Dispose();
requestClient.Close();
requestClient.Dispose();
}
}
catch (Exception e)
{
Console.WriteLine(e.ToString());
}
}
UPDATE
Attempt at using CancellationTokenSource
CancellationTokenSource source = new CancellationTokenSource();
CancellationToken cancellationToken = source.Token;
TaskFactory factory = new TaskFactory(cancellationToken);
tasks.Add(factory.StartNew(() => {proxyStream.CopyToAsync(requestStream);}, cancellationToken));
tasks.Add(factory.StartNew(() => {requestStream.CopyToAsync(proxyStream);}, cancellationToken));
tasks.Add(factory.StartNew(async () => {
//wait for this to retur, then cancel the token
await HandleDisconnect(clt, requestClient);
source.Cancel();
}, cancellationToken));
try
{
await factory.ContinueWhenAll(tasks.ToArray(),
(results) =>
{
Console.WriteLine("Tasks complete");
}, cancellationToken);
}
catch (AggregateException ae)
{
foreach (Exception e in ae.InnerExceptions)
{
if (e is TaskCanceledException)
Console.WriteLine("Unable to compute mean: {0}",
((TaskCanceledException)e).Message);
else
Console.WriteLine("Exception: " + e.GetType().Name);
}
}
finally
{
source.Dispose();
}
UPDATE
public static class extensionTcpClient{
public static bool CheckIfDisconnected(this TcpClient tcp)
{
if (tcp.Client.Poll(0, SelectMode.SelectRead))
{
byte[] buff = new byte[1];
if (tcp.Client.Receive(buff, SocketFlags.Peek) == 0)
{
// Client disconnected
return false;
}
}
return true;
}
}
class ProxyMaintainer
{
private static TcpListener listener = new TcpListener(IPAddress.Parse("192.168.0.25"), 13000);
public ProxyMaintainer()
{
}
public async Task Start()
{
Console.WriteLine("###############################");
Console.WriteLine("Listening on 192.168.0.25:13000");
Console.WriteLine("###############################\n");
listener.Start();
while (listen)
{
if (listener.Pending())
{
HandleClient(await listener.AcceptTcpClientAsync());
}
else
{
await Task.Delay(100); //<--- timeout
}
}
}
private static async Task Transport(NetworkStream from, NetworkStream to, Func<bool> isAlivePoller, CancellationToken token)
{
byte[] buffer = new byte[4096];
while (isAlivePoller())
{
while (from.DataAvailable)
{
int read = await from.ReadAsync(buffer, 0, buffer.Length, token).ConfigureAwait(false);
await to.WriteAsync(buffer, 0, read, token);
}
// Relieve the CPU a bit.
await Task.Delay(10, token).ConfigureAwait(false);
}
}
private static async Task HandleClient(TcpClient clientFrom)
{
var hostHeaderAvailable = 0;
int count;
var bytes = new byte[clientFrom.ReceiveBufferSize];
const string connectText = "connect";
NetworkStream toStream = null;
using (var fromStream = clientFrom.GetStream())
using(TcpClient clientTo = new TcpClient())
using (var manualStopper = new CancellationTokenSource())
{
count = await fromStream.ReadAsync(bytes, 0, bytes.Length);
var text = Encoding.UTF8.GetString(bytes);
Console.WriteLine(text);
if (text.ToLower().StartsWith(connectText))
{
// extract the url and port
var host = text.Remove(0, connectText.Length + 1);
var hostIndex = host.IndexOf(" ", StringComparison.Ordinal);
var hostEntry = host.Remove(hostIndex).Split(new[] { ":" }, StringSplitOptions.None);
// connect to the url and prot supplied
await clientTo.ConnectAsync(hostEntry[0], Convert.ToInt32(hostEntry[1]));
toStream = clientTo.GetStream();
// send 200 response to proxyStream
const string sslResponse = "HTTP/1.0 200 Connection established\r\n\r\n";
var sslResponseBytes = Encoding.UTF8.GetBytes(sslResponse);
await fromStream.WriteAsync(sslResponseBytes, 0, sslResponseBytes.Length);
}
bool Poller() => clientFrom.CheckIfDisconnected() && clientTo.CheckIfDisconnected();
Task one = Transport(fromStream, toStream, Poller, manualStopper.Token);
Task two = Transport(toStream, fromStream, Poller, manualStopper.Token);
await Task.WhenAll(one, two).ConfigureAwait(false);
//await one; await two; // To get exceptions if you want them and there are any.
// Alternatively, you can use Task.WhenAll to get exceptions aggregated for you.
}
Console.WriteLine("Closing connection");
}
}