// Copyright © 2017-2018 Atomic Software, LLC. All Rights Reserved.
// See LICENSE.md for full license information.

using Atom.Core.Collections;
using Atom.Core.Diagnostics;
using Atom.Core.Mathematics;
using Atom.Core.Networking.Messages;
using Atom.Core.Networking.Messages.Structures.Misc;
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using Atom.Core.Game;

namespace Atom.Core.Networking
{
    /// <summary>
    /// Wraps one TCP socket: sends raw byte arrays, receives bytes asynchronously,
    /// reassembles them into size-prefixed messages and dispatches each message to
    /// the handler registered for its protocol.
    /// </summary>
    public class Client
    {
        // Exclusive upper bound passed to the random generator when choosing the
        // encryption seed for an accepted connection.
        // NOTE(review): presumably the number of seeds EncryptionTable supports - confirm.
        private const int EncryptionSeedBound = 498;

        // Pause between two outgoing connection attempts.
        private const int ReconnectDelayMilliseconds = 3000;

        /// <summary>Every client that has been constructed and not yet disconnected.</summary>
        public static List ActiveClients = new List();

        /// <summary>Connection state as reported by the underlying socket.</summary>
        public bool IsConnected => Socket.Connected;

        /// <summary>Table applied to every received payload; messages are left untouched while it is null.</summary>
        public EncryptionTable EncryptionTable { get; set; }

        /// <summary>Random 16-bit identifier assigned at construction.</summary>
        public ushort Handle { get; set; }

        /// <summary>Account attached to this connection, if any.</summary>
        public Account Account { get; set; }

        private readonly Socket Socket;

        // Scratch buffer handed to the socket for each asynchronous receive. Only the
        // bytes reported by EndReceive are consumed, so it can be reused between receives.
        private readonly byte[] ReceiveBuffer;

        // Accumulates received bytes until at least one complete message is available.
        private MemoryStream ReceiveStream;

        // 0 while the client is alive, 1 once Disconnect() has started. Socket.Connected
        // cannot be used for this: it already reads false after a socket error, which is
        // precisely when cleanup still has to run.
        private int DisconnectFlag;

        private bool IsDisconnected => Volatile.Read(ref DisconnectFlag) != 0;

        /// <summary>
        /// Wraps an already connected socket, picks a random encryption seed, announces
        /// it with a <see cref="PROTO_NC_MISC_SEED_ACK"/> and starts receiving.
        /// </summary>
        /// <param name="socket">The connected socket to wrap.</param>
        public Client(Socket socket)
        {
            ActiveClients.Add(this);

            Socket = socket;
            ReceiveBuffer = new byte[NetworkMessage.MaxMessageSize];
            ReceiveStream = new MemoryStream();
            Handle = (ushort) MathUtils.Rnd.Next(ushort.MaxValue);

            SetEncryptionTableIndex((ushort) MathUtils.Rnd.Next(EncryptionSeedBound));
            new PROTO_NC_MISC_SEED_ACK(EncryptionTable.Index).Send(this);

            WaitForBytes();
        }

        /// <summary>
        /// Creates a new TCP socket and starts connecting it to the given endpoint.
        /// Failed attempts are retried until one succeeds or the client is disconnected.
        /// </summary>
        /// <param name="targetIP">IP address of the remote endpoint, in textual form.</param>
        /// <param name="targetPort">Port of the remote endpoint.</param>
        public Client(string targetIP, int targetPort)
        {
            ActiveClients.Add(this);

            Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            ReceiveBuffer = new byte[NetworkMessage.MaxMessageSize];
            ReceiveStream = new MemoryStream();
            Handle = (ushort) MathUtils.Rnd.Next(ushort.MaxValue);

            TryConnect(targetIP, targetPort);
        }

        /// <summary>
        /// Sets the index of the encryption table, creating the table first if needed.
        /// </summary>
        /// <param name="seed">Index to apply to the encryption table.</param>
        public void SetEncryptionTableIndex(ushort seed)
        {
            if (EncryptionTable == null)
            {
                EncryptionTable = new EncryptionTable();
            }

            EncryptionTable.SetIndex(seed);
        }

        /// <summary>Starts an asynchronous connection attempt to the given endpoint.</summary>
        private void TryConnect(string targetIP, int targetPort)
        {
            Log.Info($"Attempting to connect to the remote socket at {targetIP}:{targetPort}");

            // The target is carried as async state so a failed attempt can be retried.
            Socket.BeginConnect(new IPEndPoint(IPAddress.Parse(targetIP), targetPort), Connected, new object[] { targetIP, targetPort });
        }

        /// <summary>
        /// Completion callback for <see cref="TryConnect"/>: starts receiving on success,
        /// otherwise waits and retries.
        /// </summary>
        private void Connected(IAsyncResult e)
        {
            try
            {
                Socket.EndConnect(e);
            }
            catch
            {
                if (IsDisconnected)
                {
                    // The socket was closed on purpose; retrying would only throw.
                    return;
                }

                Log.Warning("Remote socket connection attempt failed. Trying again");
                Thread.Sleep(ReconnectDelayMilliseconds);
                RetryConnect((object[]) e.AsyncState);
                return;
            }

            WaitForBytes();
            Log.Info("Connection opened to the remote socket");
        }

        // Re-issues the connection attempt described by the async state of a failed one.
        private void RetryConnect(object[] target)
        {
            try
            {
                TryConnect((string) target[0], (int) target[1]);
            }
            catch (Exception error)
            {
                // This runs on a callback thread where an escaping exception would be
                // unhandled, so give up cleanly instead.
                Log.CriticalError($"Unable to retry the remote socket connection: {error.Message}");
                Disconnect();
            }
        }

        /// <summary>
        /// Sends the whole array over the socket, blocking until every byte has been
        /// handed to the socket. Does nothing when the socket is not connected. The
        /// client is disconnected when the array exceeds
        /// <see cref="NetworkMessage.MaxMessageSize"/> or when the socket reports a failure.
        /// </summary>
        /// <param name="bytes">The bytes to send.</param>
        public void SendBytes(byte[] bytes)
        {
            if (!IsConnected)
            {
                return;
            }

            var total = bytes.Length;

            if (total > NetworkMessage.MaxMessageSize)
            {
                Log.CriticalError("Exceeded max message size while sending data to the client.");
                Disconnect();
                return;
            }

            try
            {
                // Send may accept fewer bytes than requested, so loop until all are sent.
                for (var sent = 0; sent < total;)
                {
                    sent += Socket.Send(bytes, sent, total - sent, SocketFlags.None);
                }
            }
            catch (SocketException)
            {
                Log.Warning("Failed to send data to the remote socket. Disconnecting client");
                Disconnect();
            }
            catch (ObjectDisposedException)
            {
                // The socket was closed while sending.
                Disconnect();
            }
        }

        /// <summary>Starts the next asynchronous receive, or tears the client down if the socket is unusable.</summary>
        private void WaitForBytes()
        {
            if (!IsConnected)
            {
                // Disconnect() is a no-op when cleanup already happened.
                Disconnect();
                return;
            }

            try
            {
                Socket.BeginReceive(ReceiveBuffer, 0, ReceiveBuffer.Length, SocketFlags.None, ReceivedBytes, Socket);
            }
            catch (SocketException)
            {
                Disconnect();
            }
            catch (ObjectDisposedException)
            {
                Disconnect();
            }
        }

        /// <summary>
        /// Completion callback for a receive: feeds the received bytes to the message
        /// parser and queues the next receive. A failed or empty receive disconnects
        /// the client.
        /// </summary>
        private void ReceivedBytes(IAsyncResult e)
        {
            int receivedCount;

            try
            {
                // Always complete the operation, even when the socket already reports
                // itself as disconnected, so that cleanup below is reached.
                receivedCount = Socket.EndReceive(e);
            }
            catch
            {
                Disconnect();
                return;
            }

            if (receivedCount <= 0)
            {
                Disconnect();
                return;
            }

            CreateMessages(ReceiveBuffer, receivedCount);
            WaitForBytes();
        }

        /// <summary>
        /// Appends <paramref name="count"/> bytes from <paramref name="bytes"/> to the
        /// receive stream and dispatches every complete message it now contains.
        /// </summary>
        private void CreateMessages(byte[] bytes, int count)
        {
            if (IsDisconnected)
            {
                return;
            }

            // An unsuccessful parse leaves the cursor right after the size header.
            // Writing from there would overwrite the partial payload already buffered,
            // so always append at the end of the stream.
            ReceiveStream.Seek(0, SeekOrigin.End);
            ReceiveStream.Write(bytes, 0, count);

            while (TryParseMessage())
            {
                // Keep extracting messages until the buffered data runs out.
            }
        }

        /// <summary>
        /// Tries to extract one complete message from the start of the receive stream
        /// and dispatch it.
        /// </summary>
        /// <returns>
        /// True when a message was consumed and parsing may continue; false when the
        /// stream does not hold a complete message or the client was disconnected.
        /// </returns>
        private bool TryParseMessage()
        {
            ReceiveStream.Position = 0;

            if (!TryReadMessageSize(out var messageSize))
            {
                return false;
            }

            // The payload has not fully arrived yet.
            if (ReceiveStream.Length - ReceiveStream.Position < messageSize)
            {
                return false;
            }

            var messageBytes = new byte[messageSize];
            ReceiveStream.Read(messageBytes, 0, messageSize);

            EncryptionTable?.CryptBytes(messageBytes, 0, messageBytes.Length);

            var message = new NetworkMessage(messageBytes) {Client = this};

            if (MessageHandlerLoader.CanGetHandler(message.Protocol, out MessageHandlerDelegate handler))
            {
                Log.Info($"Received: {message}");
                handler(message);
            }
            else
            {
                Log.Warning($"Unk: {message}");
            }

            if (IsDisconnected)
            {
                // The handler disconnected this client; the receive stream is disposed
                // and must not be touched any more.
                return false;
            }

            TrimReceiveStream();
            return true;
        }

        // Reads the size header at the current stream position. A non-zero first byte is
        // the payload size itself; a zero first byte is followed by a two-byte size that
        // is decoded with BitConverter (machine byte order). Returns false when the header
        // is not completely buffered yet.
        private bool TryReadMessageSize(out int messageSize)
        {
            messageSize = 0;

            if (ReceiveStream.Length < 1)
            {
                return false;
            }

            var shortSize = ReceiveStream.ReadByte();

            if (shortSize != 0)
            {
                messageSize = shortSize;
                return true;
            }

            if (ReceiveStream.Length - ReceiveStream.Position < 2)
            {
                return false;
            }

            var sizeBytes = new byte[2];
            ReceiveStream.Read(sizeBytes, 0, 2);
            messageSize = BitConverter.ToUInt16(sizeBytes, 0);
            return true;
        }

        /// <summary>
        /// Replaces the receive stream with a new one that holds only the bytes after
        /// the current position, i.e. drops the message that was just consumed.
        /// </summary>
        private void TrimReceiveStream()
        {
            var leftover = new byte[ReceiveStream.Length - ReceiveStream.Position];
            ReceiveStream.Read(leftover, 0, leftover.Length);

            ReceiveStream.Dispose();
            ReceiveStream = new MemoryStream();
            ReceiveStream.Write(leftover, 0, leftover.Length);
        }

        /// <summary>
        /// Removes the client from <see cref="ActiveClients"/>, closes the socket and
        /// releases the receive stream. Only the first call has an effect, and it runs
        /// whether or not the socket still reports itself as connected.
        /// </summary>
        public void Disconnect()
        {
            // Exchange guarantees the cleanup below runs exactly once.
            if (Interlocked.Exchange(ref DisconnectFlag, 1) != 0)
            {
                return;
            }

            Log.Info("Disconnecting client");

            ActiveClients.Remove(this);
            Socket.Close();
            ReceiveStream.Dispose();
            Account?.Characters?.Clear();

            GC.SuppressFinalize(this);

            Log.Info("Client disconnected");
        }
    }
}