I am attempting to build a basic Proxy Server in java that can handle and forward both HTTP and HTTPS requests.
The current code below is based on https://stackoverflow.com/a/41368670/6346653:
package com.example;
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class ProxyServer extends Thread {
private int port;
private static final int HTTPS = 1;
private static final int HTTP = 2;
public static void main(String[] args) {
// load a properties file
Properties prop = new Properties();
try {
prop.load(ProxyServer.class.getClassLoader().getResourceAsStream("config.properties"));
//If port num is provided as argument
if (args.length > 0) {
int port = Integer.parseInt(args[0]);
(new ProxyServer(port)).run();
}
//Else use default properties
else {
int port = Integer.parseInt(prop.getProperty("port"));
(new ProxyServer(port)).run();
}
} catch (IOException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
}
}
public ProxyServer(int port) {
super("Proxy Server Thread");
this.port = port;
}
@Override
public void run() {
//Proxy ServerSocket listening for incoming connections on static port
try (ServerSocket serverSocket = new ServerSocket(this.port)) {
System.out.println("Proxy Server listening on port: " + this.port);
Socket socket;
try {
while ((socket = serverSocket.accept()) != null) {
(new Handler(socket)).start();
}
} catch (IOException e) {
System.out.println(e.toString());
//e.printStackTrace(); // TODO: implement catch
}
} catch (IOException e) {
System.out.println(e.toString());
//e.printStackTrace(); // TODO: implement catch
return;
}
}
//Handler class for incoming socket connections
public static class Handler extends Thread {
public static final Pattern CONNECT_PATTERN = Pattern.compile("CONNECT (.+):(.+) HTTP/(1\\.[01])",
Pattern.CASE_INSENSITIVE);
public static final Pattern HTTP_PATTERN = Pattern.compile("(HEAD|GET|POST|PUT) http:\\/\\/([^\\/]+)(\\/([^\\s]*))? HTTP\\/(1\\.[01])",
Pattern.CASE_INSENSITIVE);
private final Socket clientSocket;
private boolean previousWasR = false;
public Handler(Socket clientSocket) {
this.clientSocket = clientSocket;
}
@Override
public void run() {
try {
//Extract first line of request and use matchers to determine HTTPS (CONNECT pattern) or HTTP (GET/POST/HEAD/PUT etc)
String request = readLine(clientSocket);
System.out.println(request);
Matcher matcher = CONNECT_PATTERN.matcher(request);
Matcher httpMatcher = HTTP_PATTERN.matcher(request);
//Connection is HTTPS
if (matcher.matches()) {
connectRemote(HTTPS, null, clientSocket, matcher.group(1), Integer.parseInt(matcher.group(2)), matcher.group(3));
}
//Else connection is HTTP
else if (httpMatcher.matches()){
connectRemote(HTTP, request, clientSocket, httpMatcher.group(2), 80, "1.1");
}
} catch (IOException e) {
System.out.println(e.toString());
} finally {
try {
clientSocket.close();
} catch (IOException e) {
System.out.println(e.toString());
}
}
}
//Make socket connection to the remote host and pass data via input/output streams
private void connectRemote(int protocol, String firstLine, Socket clientSocket, String remoteHost, int remotePort, String version)
throws UnsupportedEncodingException, IOException {
//for HTTPS connections, purge the header data below CONNECT from client socket
if (protocol == HTTPS) {
String header;
do {
header = readLine(clientSocket);
} while (!"".equals(header));
}
OutputStreamWriter outputStreamWriter = new OutputStreamWriter(clientSocket.getOutputStream(),
"ISO-8859-1");
final Socket forwardSocket;
try {
forwardSocket = new Socket(remoteHost, remotePort);
} catch (IOException | NumberFormatException e) {
System.out.println(e.toString());
outputStreamWriter.write("HTTP/" + version + " 502 Bad Gateway\r\n");
outputStreamWriter.write("Proxy-agent: Simple/0.1\r\n");
outputStreamWriter.write("\r\n");
outputStreamWriter.flush();
return;
}
try {
if (protocol == HTTPS) {
outputStreamWriter.write("HTTP/" + version + " 200 Connection established\r\n");
outputStreamWriter.write("Proxy-agent: Simple/0.1\r\n");
outputStreamWriter.write("\r\n");
outputStreamWriter.flush();
}
Thread remoteToClient = new Thread() {
@Override
public void run() {
forwardData(null, forwardSocket, clientSocket);
}
};
remoteToClient.start();
try {
if (previousWasR) {
int read = clientSocket.getInputStream().read();
if (read != -1) {
if (read != '\n') {
forwardSocket.getOutputStream().write(read);
}
forwardData(firstLine, clientSocket, forwardSocket);
} else {
if (!forwardSocket.isOutputShutdown()) {
forwardSocket.shutdownOutput();
}
if (!clientSocket.isInputShutdown()) {
clientSocket.shutdownInput();
}
}
} else {
forwardData(firstLine, clientSocket, forwardSocket);
}
} finally {
try {
remoteToClient.join();
} catch (InterruptedException e) {
System.out.println(e.toString());
}
}
} finally {
forwardSocket.close();
}
}
//Data exchange (used for both client to proxy server & proxy server to remote)
private static void forwardData(String firstLine, Socket inputSocket, Socket outputSocket) {
try {
InputStream inputStream = inputSocket.getInputStream();
try {
OutputStream outputStream = outputSocket.getOutputStream();
if (firstLine != null) {
outputStream.write((firstLine).getBytes("ISO-8859-1"));
outputStream.write(System.getProperty("line.separator").getBytes("ISO-8859-1"));
}
try {
byte[] buffer = new byte[4096];
int read;
do {
read = inputStream.read(buffer);
if (read > 0) {
outputStream.write(buffer, 0, read);
if (inputStream.available() < 1) {
outputStream.flush();
}
}
} while (read >= 0);
} finally {
if (!outputSocket.isOutputShutdown()) {
outputSocket.shutdownOutput();
}
}
} finally {
if (!inputSocket.isInputShutdown()) {
inputSocket.shutdownInput();
}
}
} catch (IOException e) {
System.out.println(e.toString());
}
}
//Extract first line of client socket request - used to determine protocol and extract remote host/port
private String readLine(Socket socket) throws IOException {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
int next;
readerLoop:
while ((next = socket.getInputStream().read()) != -1) {
if (previousWasR && next == '\n') {
previousWasR = false;
continue;
}
previousWasR = false;
switch (next) {
case '\r':
previousWasR = true;
break readerLoop;
case '\n':
break readerLoop;
default:
byteArrayOutputStream.write(next);
break;
}
}
return byteArrayOutputStream.toString("ISO-8859-1");
}
}
}
The idea is to read the first line from the Proxy Server's client connection inputstream to determine which protocol is used, and extract the host/port combination to forward to.
If HTTPS (pattern matching "CONNECT" is found), the remaining data from the client's input stream is forwarded via a remote socket connection, with the e.g. "HTTP 1.1 200 Connection established" line manually written back to the client's output stream. This seems to be working well.
For HTTP connections, since the first line is already read from the input stream to determine protocol and host/port, this is being manually added together with a new line separator before the remaining data is forwarded, as per the line below:
if (firstLine != null) {
outputStream.write((firstLine).getBytes("ISO-8859-1"));
outputStream.write(System.getProperty("line.separator").getBytes("ISO-8859-1"));
}
If I run this locally (e.g. on port 8084), and then set my browser's proxy settings to "http://127.0.0.1:8084" this seems to work fine as well, an attempt to access e.g. "http://www.ipdatabase.com" will work fine, and the below line is printed to the system.out:
GET http://www.ipdatabase.com HTTP/1.1
If I run this on a remote server "example.com" and then set browser's proxy settings to "http://example.com:8084", there is always a socket connection reset and socket connection error thrown immediately after the first line is forwarded. Below is what is observed in system.out:
GET http://www.ipdatabase.com HTTP/1.1
java.net.SocketException: Connection reset
(exception logged by the catch clause in forwardData method).
What is causing the connection reset when running on remote servers?