# server.py
import socket
import pickle
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
import struct
from shared_model import SimpleCNN
from dataloader import FLDataset
from torch.utils.data import DataLoader, Dataset

class FederatedServer:
    """Parameter server for federated training.

    Hands out the global model together with one training batch per request
    and applies the gradients that clients send back (see ``start`` for the
    wire protocol).
    """

    def __init__(self, host='localhost', port=12723, lr=0.01):
        """Create the global model, its SGD optimizer and the loss function.

        Args:
            host: interface the server binds to.
            port: TCP port the server listens on.
            lr: SGD learning rate used when applying client gradients.
        """
        self.host = host
        self.port = port
        self.model = SimpleCNN()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

        # Evaluation history, one entry per call to ``evaluate``.
        self.test_loss_list = []
        self.test_acc_list = []
        self.test_step_list = []

    def send_data(self, conn, data):
        """Pickle ``data`` and send it as a 4-byte big-endian length + payload frame."""
        serialized_data = pickle.dumps(data)
        conn.sendall(struct.pack('>I', len(serialized_data)))
        conn.sendall(serialized_data)

    @staticmethod
    def _recv_exact(conn, num_bytes):
        """Read exactly ``num_bytes`` bytes from ``conn``.

        Raises:
            ConnectionError: if the peer closes the connection before
                ``num_bytes`` bytes have arrived.
        """
        chunks = []
        remaining = num_bytes
        while remaining:
            packet = conn.recv(remaining)
            if not packet:
                raise ConnectionError(
                    f"connection closed after {num_bytes - remaining} of {num_bytes} bytes"
                )
            chunks.append(packet)
            remaining -= len(packet)
        return b"".join(chunks)

    def receive_data(self, conn):
        """Receive one frame written by ``send_data`` and return the unpickled object.

        Raises:
            ConnectionError: if the peer disconnects mid-frame.
        """
        # recv() may return fewer bytes than requested, even for the 4-byte header.
        (data_length,) = struct.unpack('>I', self._recv_exact(conn, 4))
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload; only expose this server to trusted clients.
        return pickle.loads(self._recv_exact(conn, data_length))

    def evaluate(self, test_dataloader, step_count):
        """Evaluate the model on ``test_dataloader`` and record the metrics.

        Appends the average per-sample loss, the accuracy in percent and
        ``step_count`` to the history lists and prints them. An empty loader is
        reported and skipped (nothing is recorded).
        """
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for data, target in test_dataloader:
                    outputs = self.model(data)
                    batch_size = target.size(0)
                    # The criterion averages over its batch (default 'mean'
                    # reduction); re-weight so a short final batch counts
                    # per sample rather than as a full batch.
                    total_loss += self.criterion(outputs, target).item() * batch_size
                    correct += (outputs.argmax(dim=1) == target).sum().item()
                    total += batch_size
        finally:
            self.model.train(was_training)

        if total == 0:
            print(f"{step_count}: Test set is empty, skipping evaluation")
            return

        accuracy = 100 * correct / total
        average_loss = total_loss / total
        self.test_loss_list.append(average_loss)
        self.test_acc_list.append(accuracy)
        self.test_step_list.append(step_count)
        print(f"{step_count}: Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

    def _apply_gradients(self, gradients):
        """Add client ``gradients`` (one per model parameter) and take one optimizer step.

        ``None`` entries are skipped. Raises ValueError if the number of
        gradients does not match the number of model parameters.
        """
        params = list(self.model.parameters())
        gradients = list(gradients)
        if len(gradients) != len(params):
            raise ValueError(
                f"expected {len(params)} gradient tensors, got {len(gradients)}"
            )
        for param, grad in zip(params, gradients):
            if grad is None:
                continue
            if param.grad is None:
                param.grad = grad
            else:
                param.grad += grad
        self.optimizer.step()
        self.optimizer.zero_grad()

    def start(self, dataloader, test_dataloader, num_epochs=5, eval_every=200):
        """Serve clients forever, one connection at a time.

        Protocol (every message is a ``send_data`` frame):
          * ``"get_model"``: reply ``{"model": state_dict, "data": x, "target": y}``.
            ``data`` and ``target`` are ``None`` once ``num_epochs`` full passes
            over ``dataloader`` have been handed out.
          * ``"update_model"``: followed by a frame holding one gradient per
            model parameter. The server applies one optimizer step and replies
            ``"update_complete"``.
          * ``"disconnect"``: the server closes the connection.

        Args:
            dataloader: iterable of ``(data, target)`` training batches.
            test_dataloader: iterable of ``(data, target)`` batches for ``evaluate``.
            num_epochs: number of passes over ``dataloader`` to hand out.
            eval_every: evaluate before every update whose step is a multiple of this.
        """
        data_iter = iter(dataloader)
        cur_epoch = 0
        step_count = 0

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
            # Allow an immediate restart while the previous socket sits in TIME_WAIT.
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind((self.host, self.port))
            server_socket.listen(5)
            print(f"Server listening on {self.host}:{self.port}")

            while True:
                conn, addr = server_socket.accept()
                print(f"Connected by {addr}")
                try:
                    while True:
                        command = self.receive_data(conn)

                        if command == "get_model":
                            data, target = None, None
                            # Once the epoch budget is spent this loop is skipped,
                            # so every later request keeps getting None instead of
                            # silently starting an extra epoch.
                            while cur_epoch < num_epochs:
                                try:
                                    data, target = next(data_iter)
                                    break
                                except StopIteration:
                                    data_iter = iter(dataloader)
                                    cur_epoch += 1
                                    print(f"Getting data failed, reload data, epoch: {cur_epoch}")
                            model_state = self.model.state_dict()
                            self.send_data(conn, {"model": model_state, "data": data, "target": target})

                        elif command == "update_model":
                            if step_count % eval_every == 0:
                                print(f"Step {step_count}: Evaluating test accuracy...")
                                self.evaluate(test_dataloader, step_count)

                            gradients = self.receive_data(conn)
                            self._apply_gradients(gradients)
                            step_count += 1

                            self.send_data(conn, "update_complete")

                        elif command == "disconnect":
                            print(f"Client {addr} disconnecting")
                            break

                        else:
                            print(f"Ignoring unknown command from {addr}: {command!r}")

                except Exception as e:
                    # Connection-level boundary: report and keep serving other clients.
                    print(f"Error with client {addr}: {e}")
                finally:
                    print(self.test_step_list)
                    print(self.test_acc_list)
                    print(self.test_loss_list)
                    conn.close()
                    print(f"Connection closed with {addr}")
                

def run_server(num_epochs=5, batch_size=64):
    """Build the train/test loaders and run the federated server (blocks forever).

    Args:
        num_epochs: number of passes over the training set handed out to clients.
        batch_size: batch size used for both the training and the test loader.
    """
    train_dataset = FLDataset()
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = FLDataset(split="test")
    # Evaluation aggregates over the whole test set, so shuffling it only
    # adds work and run-to-run variation.
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    server = FederatedServer()
    server.start(dataloader, test_dataloader, num_epochs=num_epochs)


if __name__ == "__main__":
    run_server()