home /
connvers /
src /
main /
java /
avividi /
connvers /
server /
ClientConnectionHandler.java
ClientConnectionHandler.java
package avividi.connvers.server;
import avividi.connvers.protocol.Message;
import avividi.connvers.protocol.MessageEncoder;
import avividi.connvers.protocol.ValidationException;
import avividi.connvers.protocol.MessageTopic;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static avividi.connvers.Connvers.HANDSHAKE_TIMEOUT_MS;
import static avividi.connvers.Connvers.HEARTBEAT_TIMEOUT_MS;
import static avividi.connvers.Connvers.SERVER_USER;
import static avividi.connvers.Connvers.printSystemLine;
public class ClientConnectionHandler implements Runnable {
private final String ip;
private final Socket clientSocket;
private final ConnectionPool pool;
private final MessageEncoder messageEncoder;
private String user;
DataInputStream inStream;
DataOutputStream outStream;
public ClientConnectionHandler(Socket clientSocket, ConnectionPool pool, MessageEncoder messageEncoder) {
this.clientSocket = Objects.requireNonNull(clientSocket);
this.ip = ((InetSocketAddress) clientSocket.getRemoteSocketAddress()).getHostName();
this.pool = Objects.requireNonNull(pool);
this.messageEncoder = messageEncoder;
}
@Override
public void run() {
printSystemLine(String.format("New client connecting (%s) ...", this.ip));
try {
clientSocket.setSoTimeout(HEARTBEAT_TIMEOUT_MS);
inStream = new DataInputStream(clientSocket.getInputStream());
outStream = new DataOutputStream(clientSocket.getOutputStream());
if (checkBanned()) return;
handshake();
while (true) {
String rawMessage = inStream.readUTF();
Message message = messageEncoder.decode(rawMessage);
if (message.topic() == MessageTopic.heartbeat) {
// printSystemLine("received heartbeat from <" + message.user() + ">");
}
else if (message.topic() == MessageTopic.say) {
pool.broadcast(message.withTime(Instant.now()));
}
else {
throw new RuntimeException("illegal client message topic '" + message.topic() + "'");
}
}
}
catch (ValidationException e) {
cast(new Message(SERVER_USER, Instant.now(), MessageTopic.error, e.getMessage()));
reportParseError(e);
}
catch (EOFException e) {
// socket closed normally
}
catch (IOException e) {
reportDisconnect();
}
catch (Exception e) {
e.printStackTrace();
}
finally {
closeAndUnregister();
}
}
private boolean checkBanned() throws IOException {
if (pool.isBannedIp(ip)) {
cast(new Message(
SERVER_USER,
Instant.now(),
MessageTopic.error,
"banned")
);
printSystemLine("... User banned.");
reportDisconnect();
close();
return true;
}
return false;
}
private void handshake() throws IOException, ValidationException {
final AtomicBoolean isShook = new AtomicBoolean(false);
timeoutHandshake(isShook);
String rawMessage = inStream.readUTF();
isShook.set(true);
Message message = messageEncoder.decode(rawMessage);
if (message.topic() != MessageTopic.join) {
cast(new Message(
SERVER_USER,
Instant.now(),
MessageTopic.error,
"not joined")
);
close();
return;
}
else if (pool.has(message.user())) {
cast(new Message(
SERVER_USER,
Instant.now(),
MessageTopic.error,
String.format("user <%s> already exists", message.user()))
);
close();
return;
}
user = message.user();
cast(new Message(SERVER_USER, Instant.now(), MessageTopic.welcome, null));
pool.register(this);
}
private void timeoutHandshake(AtomicBoolean isShook) {
Executors.newSingleThreadScheduledExecutor().schedule(() -> {
try {
if (!isShook.get()) {
printSystemLine("... Handshake timeout.");
clientSocket.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}, HANDSHAKE_TIMEOUT_MS, TimeUnit.MILLISECONDS);
}
public synchronized void cast(Message message) {
try {
String raw = messageEncoder.encode(message);
outStream.writeUTF(raw);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String getUser() {
return Objects.requireNonNull(user);
}
public String getIp() {
return ip;
}
public boolean isClosed() {
return clientSocket.isClosed();
}
public void kill() {
if (!isClosed()) {
closeAndUnregister();
}
}
private void reportParseError(ValidationException e) {
printSystemLine(String.format("<%s> (%s) error '%s'", user == null ? "unregistered user" : user, ip, e.getMessage()));
}
private void reportDisconnect() {
printSystemLine(String.format("<%s> (%s) disconnected abruptly.", user == null ? "unregistered user" : user, ip));
}
private void close() {
try {
clientSocket.close();
}
catch (IOException e) {
e.printStackTrace();
}
}
private void closeAndUnregister() {
pool.unregister(this);
close();
}
}