#include "Main.h"

#include <mutex>

#ifndef WIN32
#include <unistd.h> // close()
#endif

// Last known state of one game server.
struct serverInfo
{
	std::string raw; // status text following the heartbeat/statusResponse header (see startRecv)
	time_t beat = 0; // time() of the last heartbeat or status response; 0 until one is recorded
};

// Registered game servers keyed by "ip:port" (as produced by getip()).
// Shared by the auth, master and report receive threads, so access must be serialized.
std::map<std::string, serverInfo> servers;
// Banned player GUIDs; persisted to banlist.txt by saveBanList() and read back by loadBanList().
std::set<std::string> banlist;

// Writes every banned GUID to "banlist.txt" (space-separated, overwriting the file) and logs
// the outcome. The in-memory ban list is not modified.
void saveBanList()
{
	std::ofstream f("banlist.txt");
	if (!f.is_open())
	{
		std::cout << "[ERROR] Ban list cannot be saved" << std::endl;
		return;
	}

	// Bind by reference: iterating by value copied every GUID string.
	for (const std::string& guid : banlist)
	{
		f << guid << ' ';
	}
	f.close();

	// A failed write or flush (e.g. disk full) leaves the stream in a failed state even though
	// the file opened; report it instead of claiming success.
	if (!f)
	{
		std::cout << "[ERROR] Ban list cannot be saved" << std::endl;
		return;
	}
	std::cout << "[INFO] Ban list saved (" << banlist.size() << " entries)" << std::endl;
}

// Loads banned GUIDs from "banlist.txt" (whitespace-separated, as written by saveBanList) and
// replaces the in-memory ban list with them. If the file is missing or a read error occurs,
// the current ban list is kept unchanged.
void loadBanList()
{
	std::ifstream f("banlist.txt");
	if (!f.is_open())
	{
		std::cout << "[INFO] Ban list not found" << std::endl;
		return;
	}

	std::set<std::string> loaded;
	std::string guid;
	// Test the extraction itself: a `while (!f.eof())` loop never terminates if the stream
	// fails for any reason other than reaching end-of-file.
	while (f >> guid)
	{
		loaded.insert(guid);
	}

	if (f.bad())
	{
		std::cout << "[ERROR] Ban list cannot be read" << std::endl;
		return;
	}

	// Replace rather than merge, so GUIDs removed from the file are really unbanned when the
	// list is reloaded at runtime ("reloadbanlist").
	banlist.swap(loaded);
	std::cout << "[INFO] Ban list loaded (" << banlist.size() << " entries)" << std::endl;
}

int main()
{
#ifdef WIN32
	WSADATA wsaData;

	if (WSAStartup(MAKEWORD(2, 2), &wsaData) != NO_ERROR)
	{
		errorLog("WSAStartup() failed");
		return 1;
	}
#endif

	_SOCKET_ auth;
	_SOCKET_ master;
	_SOCKET_ report;

	if ((auth = setup(20800)) == BADSOCKET)
	{
		errorLog("Setup() failed for auth");
		return 1;
	}

	std::cout << "[INFO] Auth running on port 20800" << std::endl;

	if ((master = setup(20810)) == BADSOCKET)
	{
		errorLog("Setup() failed for master");
		return 1;
	}

	std::cout << "[INFO] Master running on port 20810" << std::endl;

	if ((report = setup(6850)) == BADSOCKET)
	{
		errorLog("Setup() failed for report");
		return 1;
	}

	std::cout << "[INFO] Report running on port 6850" << std::endl;

	loadBanList();

#ifdef WIN32
	std::thread authThread(startRecv, &auth, STYPE_AUTH);
	std::thread masterThread(startRecv, &master, STYPE_MASTER);
	startRecv(&report, STYPE_REPORT);
#else
	std::thread authThread(startRecv, auth, STYPE_AUTH);
	std::thread masterThread(startRecv, master, STYPE_MASTER);
	startRecv(report, STYPE_REPORT);
#endif

	// WIN32: closesocket(auth) == SOCKET_ERROR
	// LINUX: close(auth);
}

// Returns the value that follows key `sub` in the server's raw status text, up to the next '\\'
// (or the end of the text). The single character right after the key is skipped as the
// separator. Returns an empty string when the key is absent or nothing follows it.
// NOTE(review): find() matches `sub` anywhere, including inside a longer key or inside a value
// (e.g. "hostname" within "sv_hostname"); anchoring on '\\' delimiters may be needed. Confirm
// against the status format before changing.
std::string getServerAttr(std::map<std::string, serverInfo>::iterator iter, std::string sub)
{
	const std::string& raw = iter->second.raw;

	const std::string::size_type keyPos = raw.find(sub);
	if (keyPos == std::string::npos)
	{
		// npos + length would wrap around to an unrelated position in the text.
		return std::string();
	}

	const std::string::size_type valuePos = keyPos + sub.size() + 1;
	if (valuePos > raw.size())
	{
		// Key sits at the very end of the text; substr() would throw std::out_of_range.
		return std::string();
	}

	const std::string::size_type valueEnd = raw.find('\\', valuePos);
	return raw.substr(valuePos, valueEnd == std::string::npos ? std::string::npos : valueEnd - valuePos);
}

void startRecv(_SOCKETP_ socket, sType type)
{
	sockaddr_in SenderAddr;
	socklen_t SenderAddrSize = sizeof(SenderAddr);

	while (true)
	{
		char RecvBuf[BUFLEN];
		memset(RecvBuf, 0, BUFLEN);
#ifdef WIN32
		if (recvfrom(*socket, RecvBuf, BUFLEN, 0, (sockaddr*)&SenderAddr, &SenderAddrSize) == SOCKET_ERROR)
#else
		if (recvfrom(socket, RecvBuf, BUFLEN, 0, (sockaddr*)&SenderAddr, &SenderAddrSize) == SOCKET_ERROR)
#endif
		{
			errorLog("Recvfrom() failed");
			continue;
		}

		int offset = type == STYPE_REPORT ? 0 : 4;

		char* pos = strchr(RecvBuf, 0);
		int pLen = pos - RecvBuf - offset;

		std::string packet = std::string(RecvBuf).substr(offset, pLen);
		if (type != STYPE_REPORT && packet[packet.length() - 1] == 0x0A)
		{
			packet = packet.substr(0, packet.length() - 1);
		}

		std::cout << "[RECEIVED] " << packet << std::endl;

		if (type == STYPE_MASTER)
		{
			if (startsWith(&packet, "heartbeat COD-4"))
			{
				if (packet == "heartbeat COD-4")
				{
					sendPacket(socket, SenderAddr, "getchallenge 958195462");
					sendPacket(socket, SenderAddr, "getstatus 958195462");
				}
				else if (packet.size() > 16)
				{
					std::string ip = getip(SenderAddr);
					servers[ip].raw = packet.substr(16);
					servers[ip].beat = time(0);
				}
			}
			else if (startsWith(&packet, "statusResponse"))
			{
				if (packet.size() > 16)
				{
					std::string ip = getip(SenderAddr);
					servers[ip].raw = packet.substr(16);
					servers[ip].beat = time(0);
				}
			}
			/*else if (startsWith(&packet, "heartbeat update")) // Not a real 'beat'
			{
				// Splitting
				std::string s = packet.substr(17);
				unsigned int n = s.length();
				std::vector<std::string> tok;
				std::string temp = "";
				for (unsigned int i = 0; i < n; ++i)
				{
					if (s[i] == '=')
					{
						tok.push_back(temp);
						temp = "";
					}
					else temp += s[i];
				}
				if (!temp.empty())
					tok.push_back(temp);

				temp = getip(SenderAddr);
				for (unsigned int i = 0; i < tok.size(); i += 2)
				{
					n = servers[temp].raw.find(tok[i]);
					servers[temp].raw = servers[temp].raw.substr(0, n) + tok[i] + "\\" + tok[i + 1] + servers[temp].raw.substr(servers[temp].raw.find('\\', servers[temp].raw.find('\\', n + 1) + 1));
					std::cout << "NEWRAW: " << servers[temp].raw << std::endl;
				}
			}*/
			else if (packet == "heartbeat flatline")
			{
				servers.erase(getip(SenderAddr));
			}
		}
		else if (type == STYPE_AUTH)
		{
			if (startsWith(&packet, "getIpAuthorize"))
			{
				std::vector<std::string> vec;
				splitStr(packet, vec, ' ');
				if (vec.size() == 7)
				{
					bool allow = vec[6].size() > 2;
					if (allow)
					{
						std::string guid = vec[6].substr(1, vec[6].size() - 2);
						if (banlist.find(guid) != banlist.end())
						{
							allow = false;
						}
					}
					sendPacket(socket, SenderAddr, ("ipAuthorize " + vec[1] + " " + (allow ? "accept KEY_IS_GOOD" : "deny INVALID_CDKEY") + " 0 " + vec[6]).c_str());
				}
			}
		}
		else if (startsWith(&packet, "Carrot:")) // Password
		{
			packet.erase(0, 7);
			if (startsWith(&packet, "report"))
			{
				if (packet.size() > 7)
				{
					std::string srv = packet.substr(7);
				
					if (srv == "all")
					{
						if (!servers.size())
						{
							sendPacket(socket, SenderAddr, "", false);
						}
						else
						{
							time_t t = time(0);
							std::map<std::string, serverInfo>::iterator iter;
							for (iter = servers.begin(); iter != servers.end(); ++iter)
							{
								if (iter->second.beat + 600 < t) // After 10 minutes of no response it is dead probably
									servers.erase(iter->first);
								else
									sendPacket(socket, SenderAddr, (iter->first + '=' + std::to_string(iter->second.beat) + '=' + iter->second.raw).c_str(), false);
							}
						}
					}
					else if (!servers[srv].raw.empty())
					{
						sendPacket(socket, SenderAddr, (std::to_string(servers[srv].beat) + '=' + servers[srv].raw).c_str(), false);
					}
				}
			}
			else if (startsWith(&packet, "ban"))
			{
				if (packet.size() > 4)
				{
					std::string guid = packet.substr(4);
					banlist.insert(guid);
					saveBanList();
					std::cout << "[INFO] Banned " << guid << std::endl;
				}
			}
			else if (startsWith(&packet, "unban"))
			{
				if (packet.size() > 6)
				{
					std::string guid = packet.substr(6);
					banlist.erase(guid);
					saveBanList();
					std::cout << "[INFO] Unbanned " << guid << std::endl;
				}
			}
			else if (packet == "reloadbanlist")
			{
				loadBanList();
			}
			else if (packet == "forcequit")
			{
				exit(EXIT_SUCCESS);
			}
		}
	}
}

// Formats the sender address as "a.b.c.d:port", with the port in host byte order.
std::string getip(sockaddr_in SenderAddr)
{
	const unsigned short port = ntohs(SenderAddr.sin_port);
	std::string endpoint = _ntop(&SenderAddr.sin_addr);
	endpoint += ":";
	endpoint += std::to_string(port);
	return endpoint;
}

// Converts an IPv4 address to dotted-decimal text ("a.b.c.d").
// Returns an empty string if the conversion fails.
std::string _ntop(in_addr* addr)
{
	// NOTE: inet_ntop() would be preferable (inet_ntoa is deprecated), but on Windows it is only
	// available from _WIN32_WINNT >= 0x600. The conditional variant is kept disabled:
	//   #if _WIN32_WINNT >= 0x600
	//   	char str[INET_ADDRSTRLEN];
	//   	inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
	//   	return str;
	//   #endif
	const char* text = inet_ntoa(*addr);
	// Winsock's inet_ntoa returns NULL on error; building a std::string from NULL is undefined.
	return text ? std::string(text) : std::string();
}

// Splits `raw` on `delimiter` and appends each field to `vec`; existing contents are kept.
// Consecutive delimiters produce empty fields. A trailing delimiter does not produce a final
// empty field, and an empty `raw` appends nothing.
void splitStr(std::string raw, std::vector<std::string>& vec, char delimiter)
{
	std::string::size_type fieldStart = 0;
	while (fieldStart < raw.size())
	{
		const std::string::size_type fieldEnd = raw.find(delimiter, fieldStart);
		if (fieldEnd == std::string::npos)
		{
			vec.push_back(raw.substr(fieldStart));
			return;
		}
		vec.push_back(raw.substr(fieldStart, fieldEnd - fieldStart));
		fieldStart = fieldEnd + 1;
	}
}

// Returns true when *str begins with `start`; an empty `start` always matches.
bool startsWith(std::string* str, std::string start)
{
	const std::string& text = *str;
	// compare() limits the compared span to what text actually has, so a text shorter than
	// `start` can never compare equal.
	return text.compare(0, start.size(), start) == 0;
}

// Sends the NUL-terminated text `buf` as one UDP datagram to `recvAddr` and logs it.
// With `lead`, the datagram is four 0xFF bytes + the text + a terminating NUL.
// Without it, exactly the text bytes are sent; an empty `buf` gives a zero-length datagram.
void sendPacket(_SOCKETP_ socket, sockaddr_in recvAddr, const char* buf, bool lead)
{
	// std::string owns the frame, so no manual new[]/delete[] has to be paired on every exit path.
	std::string packet;
	if (lead)
	{
		packet.assign(4, static_cast<char>(0xFF));
		packet += buf;
		packet += '\0';
	}
	else
	{
		packet = buf;
	}
	const int len = static_cast<int>(packet.size());

#ifdef WIN32
	if (sendto(*socket, packet.data(), len, 0, reinterpret_cast<sockaddr*>(&recvAddr), sizeof(recvAddr)) == SOCKET_ERROR)
#else
	if (sendto(socket, packet.data(), len, 0, reinterpret_cast<sockaddr*>(&recvAddr), sizeof(recvAddr)) == SOCKET_ERROR)
#endif
	{
		errorLog("Sendto() failed");
		return;
	}
	std::cout << "[SENT] " << buf << std::endl;
}

// Creates a UDP socket bound to `port` on all local IPv4 interfaces.
// Returns the socket, or BADSOCKET (after logging) if creation or bind fails.
_SOCKET_ setup(unsigned const short port)
{
	_SOCKET_ RecvSocket;
	if ((RecvSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) == INVALID_SOCKET) {
		errorLog("Socket() failed");
		return BADSOCKET;
	}

	// Value-initialise so sin_zero is zeroed on every platform. The old bzero() only ran off-Windows,
	// which left it as stack garbage on Windows.
	sockaddr_in RecvAddr{};
	RecvAddr.sin_family = AF_INET;
	RecvAddr.sin_port = htons(port);
	RecvAddr.sin_addr.s_addr = htonl(INADDR_ANY);

	if (bind(RecvSocket, reinterpret_cast<sockaddr*>(&RecvAddr), sizeof(RecvAddr)) == SOCKET_ERROR)
	{
		errorLog("Bind() failed");
		// Release the descriptor; the caller only receives BADSOCKET and cannot close it.
#ifdef WIN32
		closesocket(RecvSocket);
#else
		close(RecvSocket);
#endif
		return BADSOCKET;
	}

	return RecvSocket;
}

// Writes "[ERROR] <err>" to standard output, followed by a newline and a flush.
void errorLog(const char* err)
{
	std::ostream& log = std::cout;
	log << "[ERROR] ";
	log << err;
	log << std::endl;
}