diff --git a/LaunchServer/source/response/ResponseThread.java b/LaunchServer/source/response/ResponseThread.java index 5f6d9a5..bfcf87f 100644 --- a/LaunchServer/source/response/ResponseThread.java +++ b/LaunchServer/source/response/ResponseThread.java @@ -33,6 +33,7 @@ // Instance private final LaunchServer server; private final Socket socket; + private boolean handshakePassed = false; public ResponseThread(LaunchServer server, Socket socket) throws SocketException { this.server = server; @@ -46,6 +47,8 @@ try (InputStream is = socket.getInputStream(); OutputStream os = socket.getOutputStream(); HInput input = new HInput(is); HOutput output = new HOutput(os)) { readHandshake(input, output); + + // Start response try { respond(input, output); } catch (RequestException e) { @@ -54,8 +57,12 @@ } catch (Exception e) { LogHelper.error(e); } finally { - server.serverSocketHandler.onDisconnected(socket); IOHelper.close(socket); + + // Invoke disconnect listener + if (handshakePassed) { + server.serverSocketHandler.onDisconnected(socket); + } } } @@ -87,6 +94,10 @@ LogHelper.subDebug("Type: " + type.name()); } + // Invoke connect listener + server.serverSocketHandler.onConnected(socket, type); + handshakePassed = true; + // Choose response based on type Response response; switch (type) { diff --git a/LaunchServer/source/response/ServerSocketHandler.java b/LaunchServer/source/response/ServerSocketHandler.java index e1db91b..a4c380d 100644 --- a/LaunchServer/source/response/ServerSocketHandler.java +++ b/LaunchServer/source/response/ServerSocketHandler.java @@ -10,14 +10,15 @@ import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Predicate; import launcher.LauncherAPI; import launcher.helper.CommonHelper; import launcher.helper.IOHelper; import launcher.helper.LogHelper; import launcher.helper.VerifyHelper; +import launcher.request.Request; import launcher.serialize.HInput; import launcher.serialize.HOutput; import launchserver.LaunchServer; @@ -32,7 +33,7 @@ // API private final Map customResponses = new ConcurrentHashMap<>(2); - private volatile Predicate connectListener; + private volatile BiConsumer connectListener; private volatile Consumer disconnectListener; public ServerSocketHandler(LaunchServer server) { @@ -69,14 +70,7 @@ // Listen for incoming connections while (serverSocket.isBound()) { - Socket socket = serverSocket.accept(); - if (connectListener != null && !connectListener.test(socket)) { - IOHelper.close(socket); - continue; - } - - // Filter passed - threadPool.execute(new ResponseThread(server, socket)); + threadPool.execute(new ResponseThread(server, serverSocket.accept())); } } catch (IOException e) { // Ignore error after close/rebind @@ -101,7 +95,7 @@ } @LauncherAPI - public void setConnectListener(Predicate connectListener) { + public void setConnectListener(BiConsumer connectListener) { this.connectListener = connectListener; } @@ -110,6 +104,12 @@ this.disconnectListener = disconnectListener; } + /*package*/ void onConnected(Socket socket, Request.Type type) { + if (connectListener != null) { + connectListener.accept(socket, type); + } + } + /*package*/ void onDisconnected(Socket socket) { if (disconnectListener != null) { disconnectListener.accept(socket);