home / connvers / src / main / java / avividi / connvers / server / ClientConnectionHandler.java

ClientConnectionHandler.java



package avividi.connvers.server;

import avividi.connvers.protocol.Message;
import avividi.connvers.protocol.MessageEncoder;
import avividi.connvers.protocol.MessageTopic;
import avividi.connvers.protocol.ValidationException;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import static avividi.connvers.Connvers.HANDSHAKE_TIMEOUT_MS;
import static avividi.connvers.Connvers.HEARTBEAT_TIMEOUT_MS;
import static avividi.connvers.Connvers.SERVER_USER;
import static avividi.connvers.Connvers.printSystemLine;

public class ClientConnectionHandler implements Runnable {

  private final String ip;
  private final Socket clientSocket;
  private final ConnectionPool pool;
  private final MessageEncoder messageEncoder;
  private String user;
  DataInputStream inStream;
  DataOutputStream outStream;

  public ClientConnectionHandler(Socket clientSocket, ConnectionPool pool, MessageEncoder messageEncoder) {
    this.clientSocket = Objects.requireNonNull(clientSocket);
    this.ip = ((InetSocketAddress) clientSocket.getRemoteSocketAddress()).getHostName();
    this.pool = Objects.requireNonNull(pool);
    this.messageEncoder = messageEncoder;
  }

  @Override
  public void run() {
    printSystemLine(String.format("New client connecting (%s) ...",  this.ip));

    try {
      clientSocket.setSoTimeout(HEARTBEAT_TIMEOUT_MS);
      inStream = new DataInputStream(clientSocket.getInputStream());
      outStream = new DataOutputStream(clientSocket.getOutputStream());

      if (checkBanned()) return;

      handshake();

      while (true) {
        String rawMessage = inStream.readUTF();
        Message message = messageEncoder.decode(rawMessage);

        if (message.topic() == MessageTopic.heartbeat) {
//          printSystemLine("received heartbeat from <" + message.user() + ">");
        }
        else if (message.topic() == MessageTopic.say) {
          pool.broadcast(message.withTime(Instant.now()));
        }
        else {
          throw new RuntimeException("illegal client message topic '" + message.topic() + "'");
        }
      }
    }
    catch (ValidationException e) {
      cast(new Message(SERVER_USER, Instant.now(), MessageTopic.error, e.getMessage()));
      reportParseError(e);
    }
    catch (EOFException e) {
      // socket closed normally
    }
    catch (IOException e) {
      reportDisconnect();
    }
    catch (Exception e) {
      e.printStackTrace();
    }
    finally {
      closeAndUnregister();
    }
  }

  private boolean checkBanned() throws IOException {
    if (pool.isBannedIp(ip)) {
      cast(new Message(
          SERVER_USER,
          Instant.now(),
          MessageTopic.error,
          "banned")
      );
      printSystemLine("... User banned.");
      reportDisconnect();
      close();
      return true;
    }
    return false;
  }

  private void handshake() throws IOException, ValidationException {
    final AtomicBoolean isShook = new AtomicBoolean(false);
    timeoutHandshake(isShook);

    String rawMessage = inStream.readUTF();
    isShook.set(true);
    Message message = messageEncoder.decode(rawMessage);

    if (message.topic() != MessageTopic.join) {
      cast(new Message(
          SERVER_USER,
          Instant.now(),
          MessageTopic.error,
          "not joined")
      );
      close();
      return;
    }
    else if (pool.has(message.user())) {
      cast(new Message(
          SERVER_USER,
          Instant.now(),
          MessageTopic.error,
          String.format("user <%s> already exists", message.user()))
      );
      close();
      return;
    }
    user = message.user();
    cast(new Message(SERVER_USER, Instant.now(), MessageTopic.welcome, null));
    pool.register(this);
  }

  private void timeoutHandshake(AtomicBoolean isShook) {
    Executors.newSingleThreadScheduledExecutor().schedule(() -> {
      try {
        if (!isShook.get()) {
          printSystemLine("... Handshake timeout.");
          clientSocket.close();
        }
      } catch (IOException e) {
        e.printStackTrace();
      }
    }, HANDSHAKE_TIMEOUT_MS, TimeUnit.MILLISECONDS);
  }

  public synchronized void cast(Message message) {
    try {
      String raw = messageEncoder.encode(message);
      outStream.writeUTF(raw);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  public String getUser() {
    return Objects.requireNonNull(user);
  }

  public String getIp() {
    return ip;
  }

  public boolean isClosed() {
    return clientSocket.isClosed();
  }

  public void kill() {
    if (!isClosed()) {
      closeAndUnregister();
    }
  }


  private void reportParseError(ValidationException e) {
    printSystemLine(String.format("<%s> (%s) error '%s'", user == null ? "unregistered user" : user, ip, e.getMessage()));
  }

  private void reportDisconnect() {
    printSystemLine(String.format("<%s> (%s) disconnected abruptly.", user == null ? "unregistered user" : user, ip));
  }

  private void close() {
    try {
      clientSocket.close();
    }
    catch (IOException e) {
      e.printStackTrace();
    }
  }

  private void closeAndUnregister() {
    pool.unregister(this);
    close();
  }
}