// Repository path: connvers/src/main/java/avividi/connvers/server/ConnectionPool.java




package avividi.connvers.server;

import avividi.connvers.Broadcaster;
import avividi.connvers.protocol.Message;
import avividi.connvers.protocol.MessageTopic;

import java.time.Instant;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static avividi.connvers.Connvers.SERVER_USER;

/**
 * Registry of the clients currently connected to the server, together with an IP ban list.
 *
 * <p>{@link #broadcast(Message)} calls {@code message.print()}, sends the message to every
 * registered connection and pushes it to {@link Storage}. A newly registered connection first
 * receives every message returned by {@code storage.getMessages()}.
 *
 * <p>Both internal maps are {@link ConcurrentHashMap}s, so registration, removal, lookups and
 * iteration may run on different threads without external locking. Iteration is weakly
 * consistent: a client that joins or leaves while a broadcast is running may or may not
 * receive that particular message.
 */
public class ConnectionPool implements Broadcaster {

  /** User name reserved for the administrator; {@link #has(String)} always reports it. */
  private final String adminUser;
  private final Storage storage;
  /** Registered connections keyed by {@code connection.getUser()}. */
  private final Map<String, ClientConnectionHandler> connections = new ConcurrentHashMap<>();
  /** Banned IP addresses, keyed by the user name that was banned. */
  private final Map<String, String> ipBlacklist = new ConcurrentHashMap<>();


  public ConnectionPool(String adminUser, Storage storage) {
    this.adminUser = adminUser;
    this.storage = storage;
  }

  /**
   * Prints {@code message}, casts it to every registered connection and pushes it to storage.
   *
   * @param message the message to distribute
   */
  public void broadcast(Message message) {
    message.print();
    // ConcurrentHashMap iterators are weakly consistent and never throw
    // ConcurrentModificationException, unlike unsynchronized iteration over a
    // Collections.synchronizedMap view.
    connections.values().forEach(c -> c.cast(message));
    storage.push(message);
  }

  /**
   * Adds {@code connection} to the pool, replays the stored history to it and announces the join.
   *
   * @param connection the handler to register, keyed by {@code connection.getUser()}
   * @throws IllegalStateException if a connection is already registered for that user
   */
  void register(ClientConnectionHandler connection) {
    String user = connection.getUser();
    // putIfAbsent makes the duplicate check and the insertion a single atomic step.
    if (connections.putIfAbsent(user, connection) != null) {
      throw new IllegalStateException("User <" + user + "> is already registered");
    }

    storage.getMessages().forEach(connection::cast);
    broadcast(serverMessage(String.format("User <%s> joined.", user)));
  }

  /**
   * Removes {@code connection} from the pool and announces that its user left.
   *
   * @param connection the handler to remove
   * @throws IllegalStateException if {@code connection} is not the handler registered for its user
   */
  void unregister(ClientConnectionHandler connection) {
    String user = connection.getUser();
    // Remove only when this exact handler is the registered one, so a handler whose
    // registration was rejected cannot evict the live connection of the same user.
    if (!connections.remove(user, connection)) {
      throw new IllegalStateException("User <" + user + "> is not registered");
    }
    broadcast(serverMessage(String.format("User <%s> left.", user)));
  }

  /** Builds a {@code MessageTopic.say} message from {@code SERVER_USER}, stamped with now. */
  private static Message serverMessage(String text) {
    return new Message(SERVER_USER, Instant.now(), MessageTopic.say, text);
  }

  /**
   * Reports whether {@code user} is registered or is the administrator.
   *
   * @param user the user name to check; must not be {@code null}
   * @return {@code true} if the name is taken
   */
  public boolean has(String user) {
    return connections.containsKey(user) || user.equals(adminUser);
  }

  /** A user name paired with an IP address (or the label {@code "admin"} for the administrator). */
  public record UserAndIp(String user, String ip) {}

  /**
   * Lists the administrator first, followed by every registered connection sorted by user name.
   *
   * @return a new list of users and their IPs
   */
  public List<UserAndIp> listUsers() {
    return Stream.concat(
            Stream.of(new UserAndIp(adminUser, "admin")),
            connections.values().stream()
                .map(conn -> new UserAndIp(conn.getUser(), conn.getIp()))
                .sorted(Comparator.comparing(UserAndIp::user)))
        .collect(Collectors.toList());
  }


  /**
   * Reports whether {@code ip} belongs to a banned user.
   *
   * @param ip the address to check; {@code null} is never banned
   * @return {@code true} if any ban entry has this IP
   */
  public boolean isBannedIp(String ip) {
    // ConcurrentHashMap rejects null lookups; a null IP was never banned anyway.
    return ip != null && ipBlacklist.containsValue(ip);
  }

  /**
   * Records the IP of {@code user}'s open connection in the ban list and kills that connection.
   *
   * @param user the user to ban
   * @return {@code false} if the user has no open connection, {@code true} otherwise
   */
  public boolean banUser(String user) {
    var connection = liveConnection(user);
    if (connection == null) {
      return false;
    }
    ipBlacklist.put(user, connection.getIp());
    connection.kill();
    return true;
  }

  /**
   * Removes the ban entry recorded for {@code user}.
   *
   * @param user the previously banned user name
   * @return {@code true} if an entry was removed
   */
  public boolean unbanUser(String user) {
    return user != null && ipBlacklist.remove(user) != null;
  }

  /**
   * Lists ban entries sorted by user name.
   *
   * @return a new list of banned users and their recorded IPs
   */
  public List<UserAndIp> listBanned() {
    return ipBlacklist.entrySet().stream()
        .map(entry -> new UserAndIp(entry.getKey(), entry.getValue()))
        .sorted(Comparator.comparing(UserAndIp::user))
        .collect(Collectors.toList());
  }

  /**
   * Kills {@code user}'s open connection without banning its IP.
   *
   * @param user the user to disconnect
   * @return {@code false} if the user has no open connection, {@code true} otherwise
   */
  public boolean kickUser(String user) {
    var connection = liveConnection(user);
    if (connection == null) {
      return false;
    }
    connection.kill();
    return true;
  }

  /** Returns {@code user}'s registered, not-yet-closed connection, or {@code null} if none. */
  private ClientConnectionHandler liveConnection(String user) {
    if (user == null) {
      return null; // ConcurrentHashMap rejects null keys; no user is registered under null.
    }
    var connection = connections.get(user);
    return connection == null || connection.isClosed() ? null : connection;
  }
}