package info.u_team.u_team_core.impl;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

import info.u_team.u_team_core.api.Platform.Environment;
import info.u_team.u_team_core.api.network.NetworkContext;
import info.u_team.u_team_core.api.network.NetworkEnvironment;
import info.u_team.u_team_core.api.network.NetworkHandler;
import info.u_team.u_team_core.util.CastUtil;
import info.u_team.u_team_core.util.EnvironmentUtil;
import net.fabricmc.fabric.api.client.networking.v1.ClientLoginNetworking;
import net.fabricmc.fabric.api.client.networking.v1.ClientPlayNetworking;
import net.fabricmc.fabric.api.networking.v1.PacketByteBufs;
import net.fabricmc.fabric.api.networking.v1.ServerLoginConnectionEvents;
import net.fabricmc.fabric.api.networking.v1.ServerLoginNetworking;
import net.fabricmc.fabric.api.networking.v1.ServerPlayNetworking;
import net.minecraft.class_1255;
import net.minecraft.class_1657;
import net.minecraft.class_2540;
import net.minecraft.class_2561;
import net.minecraft.class_2960;
import net.minecraft.class_3222;

/**
 * Fabric implementation of {@link NetworkHandler}.
 * <p>
 * Each handler owns one base channel. During login the server sends its protocol version on that channel. The client
 * checks it with the client acceptor and answers with its own version, which the server then checks with the server
 * acceptor. A failed check disconnects the connection. Play messages are sent on sub-channels built from the base
 * channel plus {@code "/" + index}.
 */
public class FabricNetworkHandler implements NetworkHandler {
	
	/**
	 * Placeholder version the server tests with its acceptor when the client did not understand the login query on
	 * this channel.
	 * <p>
	 * NOTE(review): Java unicode escapes consume exactly four hex digits. This literal is therefore the character
	 * U+1F64 followed by '0' (twice), not the U+1F640 emoji it seems to intend. The value is kept unchanged because it
	 * is a public compile-time constant that other code may have inlined.
	 */
	public static final String NOT_ON_CLIENT = "\u1F640\u1F640MissingVersion";
	
	/** Protocol version this side sends to the peer during login. */
	private final String protocolVersion;
	
	/** Decides on the client whether the server's protocol version is acceptable. Defaults to exact equality. */
	private Predicate<String> clientAcceptedVersions;
	/** Decides on the server whether the client's protocol version is acceptable. Defaults to exact equality. */
	private Predicate<String> serverAcceptedVersions;
	
	/** Base channel used for the login handshake and as the prefix for every message sub-channel. */
	private final class_2960 channel;
	/** Registered message types, keyed by the message class used to look them up when sending. */
	private final Map<Class<?>, MessagePacket<?>> messages;
	
	/**
	 * Creates the handler and registers the login handshake receivers for {@code channel}.
	 *
	 * @param protocolVersion version sent to the peer during login; also the default accepted version on both sides
	 * @param channel base channel of this handler
	 */
	FabricNetworkHandler(String protocolVersion, class_2960 channel) {
		this.protocolVersion = protocolVersion;
		clientAcceptedVersions = protocolVersion::equals;
		serverAcceptedVersions = protocolVersion::equals;
		
		this.channel = channel;
		messages = new HashMap<>();
		
		// Client login receiver is registered through the nested Client class, only in the CLIENT environment
		// (presumably so client-only Fabric classes are never loaded on a dedicated server)
		EnvironmentUtil.runWhen(Environment.CLIENT, () -> () -> Client.registerLoginReceiver(this, channel));
		ServerLoginNetworking.registerGlobalReceiver(channel, (server, packetListener, understood, buffer, synchronizer, responseSender) -> {
			if (!understood) {
				// The client does not know this channel: let the server acceptor decide using the placeholder version
				buffer = PacketByteBufs.create().method_10814(NOT_ON_CLIENT);
			}
			acceptProtocolVersion(buffer, serverAcceptedVersions, packetListener::method_14380);
		});
		// Query every connecting client on this channel, sending the server's protocol version
		ServerLoginConnectionEvents.QUERY_START.register((packetListener, server, sender, synchronizer) -> {
			sender.sendPacket(channel, PacketByteBufs.create().method_10814(protocolVersion));
		});
	}
	
	/**
	 * Registers a message type on the sub-channel {@code channel + "/" + index}.
	 * <p>
	 * A server receiver is registered unless the message is restricted to {@link NetworkEnvironment#CLIENT}. A client
	 * receiver is registered unless the message is restricted to {@link NetworkEnvironment#SERVER}, and only when
	 * running in the client environment.
	 *
	 * @throws IllegalArgumentException if {@code clazz} was already registered; the existing registration is kept
	 */
	@Override
	public <M> void registerMessage(int index, Class<M> clazz, BiConsumer<M, class_2540> encoder, Function<class_2540, M> decoder, BiConsumer<M, NetworkContext> messageConsumer, Optional<NetworkEnvironment> handlerEnvironment) {
		final class_2960 location = channel.method_48331("/" + index);
		// putIfAbsent leaves the original mapping intact when a duplicate class is rejected
		// (a plain put would overwrite it before the exception is thrown)
		final MessagePacket<?> existingPacket = messages.putIfAbsent(clazz, new MessagePacket<>(location, encoder, handlerEnvironment));
		if (existingPacket != null) {
			throw new IllegalArgumentException("Packet class " + clazz + " was already registered");
		}
		
		if (validNetworkEnvironment(NetworkEnvironment.SERVER, handlerEnvironment)) {
			// Register client -> server handler. Decoding happens inside the receiver callback;
			// the consumer can hop to the main thread via NetworkContext#executeOnMainThread
			ServerPlayNetworking.registerGlobalReceiver(location, (server, player, packetListener, byteBuf, responseSender) -> {
				messageConsumer.accept(decodeMessage(decoder, byteBuf), new FabricNetworkContext(NetworkEnvironment.SERVER, player, server));
			});
		}
		
		if (validNetworkEnvironment(NetworkEnvironment.CLIENT, handlerEnvironment)) {
			// Register server -> client handler
			EnvironmentUtil.runWhen(Environment.CLIENT, () -> () -> Client.registerReceiver(this, location, decoder, messageConsumer));
		}
	}
	
	/**
	 * Encodes {@code message} and sends it to {@code player}.
	 *
	 * @throws IllegalArgumentException if the message type is unregistered or is not handled on the client
	 */
	@Override
	public <M> void sendToPlayer(class_3222 player, M message) {
		final EncodedMessage encodedMessage = encodeMessage(message, NetworkEnvironment.CLIENT);
		ServerPlayNetworking.send(player, encodedMessage.location, encodedMessage.byteBuf);
	}
	
	/**
	 * Encodes {@code message} and sends it to the server. Does nothing outside the client environment.
	 */
	@Override
	public <M> void sendToServer(M message) {
		EnvironmentUtil.runWhen(Environment.CLIENT, () -> () -> Client.send(this, message));
	}
	
	/** Returns the protocol version this handler sends during login. */
	@Override
	public String getProtocolVersion() {
		return protocolVersion;
	}
	
	/**
	 * Replaces the predicates used to accept the peer's protocol version during login.
	 *
	 * @param clientAcceptedVersions tested on the client against the server's version
	 * @param serverAcceptedVersions tested on the server against the client's version (or {@link #NOT_ON_CLIENT})
	 * @throws NullPointerException if either predicate is {@code null}; neither field is changed in that case
	 */
	@Override
	public void setProtocolAcceptor(Predicate<String> clientAcceptedVersions, Predicate<String> serverAcceptedVersions) {
		// Validate both before assigning so a bad call cannot leave a half-updated acceptor pair,
		// and fail here rather than with an NPE later during a login handshake
		Objects.requireNonNull(clientAcceptedVersions, "clientAcceptedVersions");
		Objects.requireNonNull(serverAcceptedVersions, "serverAcceptedVersions");
		this.clientAcceptedVersions = clientAcceptedVersions;
		this.serverAcceptedVersions = serverAcceptedVersions;
	}
	
	/**
	 * Looks up the registration for {@code message}'s class and encodes it into a fresh buffer.
	 *
	 * @param expectedHandler side that will handle the message; must be allowed by the registration
	 * @throws IllegalArgumentException if the class is unregistered or not handled on {@code expectedHandler}
	 */
	private <M> EncodedMessage encodeMessage(M message, NetworkEnvironment expectedHandler) {
		final MessagePacket<M> packet = CastUtil.uncheckedCast(messages.get(message.getClass()));
		if (packet == null) {
			throw new IllegalArgumentException("Message " + message.getClass() + " was not registered");
		}
		if (!validNetworkEnvironment(expectedHandler, packet.handlerEnvironment)) {
			throw new IllegalArgumentException("Message " + message.getClass() + " cannot be used to send to " + expectedHandler);
		}
		final class_2540 buffer = PacketByteBufs.create();
		packet.encoder.accept(message, buffer);
		return new EncodedMessage(packet.location, buffer);
	}
	
	/** Applies {@code decoder} to {@code buffer}. */
	private <M> M decodeMessage(Function<class_2540, M> decoder, class_2540 buffer) {
		return decoder.apply(buffer);
	}
	
	/**
	 * Returns {@code true} when a message registered for {@code handlerEnvironment} may be handled on
	 * {@code expected}. An empty environment means "both sides".
	 */
	private boolean validNetworkEnvironment(NetworkEnvironment expected, Optional<NetworkEnvironment> handlerEnvironment) {
		return handlerEnvironment.map(environment -> environment == expected).orElse(true);
	}
	
	/**
	 * Reads a protocol version string from {@code buffer} and tests it with {@code predicate}.
	 *
	 * @param disconnectMessage invoked with an explanatory text when the version is rejected
	 * @return {@code true} if the version was accepted
	 */
	private boolean acceptProtocolVersion(class_2540 buffer, Predicate<String> predicate, Consumer<class_2561> disconnectMessage) {
		final String receivedProtocolVersion = buffer.method_19772();
		if (!predicate.test(receivedProtocolVersion)) {
			disconnectMessage.accept(class_2561.method_43470("Protocol version for channel " + channel + " does not match. Expected: " + protocolVersion + ", received: " + receivedProtocolVersion));
			return false;
		}
		return true;
	}
	
	/**
	 * Client-only code paths. They are reached only through
	 * {@code EnvironmentUtil.runWhen(Environment.CLIENT, ...)}.
	 */
	private static class Client {
		
		/**
		 * Answers the server's login query: checks the server's version with the client acceptor and replies with this
		 * side's version. On rejection the connection is disconnected and a {@code null} response is returned.
		 */
		public static void registerLoginReceiver(FabricNetworkHandler handler, class_2960 location) {
			ClientLoginNetworking.registerGlobalReceiver(location, (client, packetListener, buffer, listenerAdder) -> {
				if (handler.acceptProtocolVersion(buffer, handler.clientAcceptedVersions, packetListener.field_3707::method_10747)) {
					return CompletableFuture.completedFuture(PacketByteBufs.create().method_10814(handler.protocolVersion));
				} else {
					return CompletableFuture.completedFuture(null);
				}
			});
		}
		
		/** Encodes {@code message} for a server-side handler and sends it to the server. */
		public static <M> void send(FabricNetworkHandler handler, M message) {
			final EncodedMessage encodedMessage = handler.encodeMessage(message, NetworkEnvironment.SERVER);
			ClientPlayNetworking.send(encodedMessage.location, encodedMessage.byteBuf);
		}
		
		/**
		 * Registers a server -> client receiver on {@code location}. The context's player is the client's current
		 * player field ({@code field_1724}) and its executor is the client instance.
		 */
		public static <M> void registerReceiver(FabricNetworkHandler handler, class_2960 location, Function<class_2540, M> decoder, BiConsumer<M, NetworkContext> messageConsumer) {
			ClientPlayNetworking.registerGlobalReceiver(location, (client, packetListener, byteBuf, responseSender) -> {
				messageConsumer.accept(handler.decodeMessage(decoder, byteBuf), new FabricNetworkContext(NetworkEnvironment.CLIENT, client.field_1724, client));
			});
		}
	}
	
	/**
	 * Registration data for one message type.
	 *
	 * @param location sub-channel the message is sent on
	 * @param encoder writes the message into a buffer
	 * @param handlerEnvironment side that handles the message; empty means both sides
	 */
	private record MessagePacket<M>(class_2960 location, BiConsumer<M, class_2540> encoder, Optional<NetworkEnvironment> handlerEnvironment) {
	}
	
	/** An encoded message ready to send: target sub-channel plus payload buffer. */
	private record EncodedMessage(class_2960 location, class_2540 byteBuf) {
	}
	
	/** {@link NetworkContext} passed to message consumers on either side. */
	public static class FabricNetworkContext implements NetworkContext {
		
		private final NetworkEnvironment environment;
		private final class_1657 player;
		private final class_1255<?> executor;
		
		/**
		 * @param environment side the message is being handled on
		 * @param player player associated with the message on this side
		 * @param executor executor used by {@link #executeOnMainThread(Runnable)}
		 */
		FabricNetworkContext(NetworkEnvironment environment, class_1657 player, class_1255<?> executor) {
			this.environment = environment;
			this.player = player;
			this.executor = executor;
		}
		
		@Override
		public NetworkEnvironment getEnvironment() {
			return environment;
		}
		
		@Override
		public class_1657 getPlayer() {
			return player;
		}
		
		/**
		 * Runs {@code runnable} directly when {@code executor.method_18854()} is true (presumably "already on the
		 * executor's thread"). Otherwise it is handed to {@code executor.method_5382} to be run by the executor.
		 */
		@Override
		public void executeOnMainThread(Runnable runnable) {
			if (!executor.method_18854()) {
				executor.method_5382(runnable);
			} else {
				runnable.run();
			}
		}
	}
	
	/** Creates {@link FabricNetworkHandler} instances. */
	public static class Factory implements NetworkHandler.Factory {
		
		@Override
		public NetworkHandler create(String protocolVersion, class_2960 location) {
			return new FabricNetworkHandler(protocolVersion, location);
		}
	}
}
