# server.py
import socket
import pickle
import struct

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from shared_model import SimpleCNN
from dataloader import FLDataset


# Federated learning server: owns the global model and the optimizer, streams
# training batches to connected clients, and applies the gradients they return.
class FederatedServer:
    def __init__(self, host='localhost', port=12723):
        self.host = host
        self.port = port
        self.model = SimpleCNN()  # global model shared with every client
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01)
        self.criterion = nn.CrossEntropyLoss()
        # Evaluation history recorded by evaluate().
        self.test_loss_list = []
        self.test_acc_list = []
        self.test_step_list = []

    # Send a Python object as a pickled payload with a 4-byte length prefix.
    def send_data(self, conn, data):
        serialized_data = pickle.dumps(data)
        conn.sendall(struct.pack('>I', len(serialized_data)))
        conn.sendall(serialized_data)

    # Receive a length-prefixed pickled object from the socket.
    def receive_data(self, conn):
        # Read the 4-byte length header (recv may return fewer bytes than asked).
        header = b""
        while len(header) < 4:
            chunk = conn.recv(4 - len(header))
            if not chunk:
                raise ConnectionError("Socket closed while reading length header")
            header += chunk
        data_length = struct.unpack('>I', header)[0]
        received_data = b""
        while len(received_data) < data_length:
            packet = conn.recv(data_length - len(received_data))
            if not packet:
                raise ConnectionError("Socket closed while reading payload")
            received_data += packet
        return pickle.loads(received_data)
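
    # Wire format shared by send_data/receive_data: every message is a 4-byte
    # big-endian unsigned length header (struct.pack('>I', ...)) followed by a
    # pickle payload of exactly that many bytes. Since pickle will execute
    # arbitrary code on load, this framing is only safe with trusted clients.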

    def evaluate(self, test_dataloader, step_count):
        self.model.eval()  # Switch model to evaluation mode
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():  # Disable gradient computation for evaluation
            for data, target in test_dataloader:
                outputs = self.model(data)  # Get model predictions
                loss = self.criterion(outputs, target)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        self.model.train()
        accuracy = 100 * correct / total
        average_loss = total_loss / len(test_dataloader)
        self.test_loss_list.append(average_loss)
        self.test_acc_list.append(accuracy)
        self.test_step_list.append(step_count)
        print(f"{step_count}: Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

    def start(self, dataloader, test_dataloader, num_epochs=5):
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.bind((self.host, self.port))
        server_socket.listen(5)
        print(f"Server listening on {self.host}:{self.port}")
        data_iter = iter(dataloader)
        cur_epoch = 0
        step_count = 0
        while True:
            conn, addr = server_socket.accept()
            print(f"Connected by {addr}")
            try:
                while True:
                    command = self.receive_data(conn)
                    if command == "get_model":
                        # Fetch the next training batch, restarting the dataloader
                        # at the end of each epoch until num_epochs is reached.
                        data, target = None, None
                        while cur_epoch < num_epochs:
                            try:
                                data, target = next(data_iter)
                                break
                            except StopIteration:
                                data_iter = iter(dataloader)
                                cur_epoch += 1
                                print(f"Epoch {cur_epoch} finished, restarting dataloader")
                        # After num_epochs, data/target stay None so the client can
                        # detect the end of training.
                        model_state = self.model.state_dict()
                        self.send_data(conn, {"model": model_state, "data": data, "target": target})
                    elif command == "update_model":
                        # Every 200 updates, evaluate the global model on the test set.
                        if step_count % 200 == 0:
                            print(f"Step {step_count}: Evaluating test accuracy...")
                            self.evaluate(test_dataloader, step_count)
                        # Receive the client's gradients (one tensor per parameter, in
                        # model.parameters() order) and apply a single optimizer step.
                        gradients = self.receive_data(conn)
                        for param, grad in zip(self.model.parameters(), gradients):
                            if param.grad is None:
                                param.grad = grad
                            else:
                                param.grad += grad
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        step_count += 1
                        self.send_data(conn, "update_complete")
                    elif command == "disconnect":
                        print(f"Client {addr} disconnecting")
                        break
            except Exception as e:
                print(f"Error with client {addr}: {e}")
            finally:
                # Dump the evaluation history gathered so far before closing.
                print(self.test_step_list)
                print(self.test_acc_list)
                print(self.test_loss_list)
                conn.close()
                print(f"Connection closed with {addr}")

def run_server():
    # Build train and test dataloaders from the project's FLDataset wrapper.
    train_dataset = FLDataset()
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_dataset = FLDataset(split="test")
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    server = FederatedServer()
    server.start(dataloader, test_dataloader, num_epochs=5)


if __name__ == "__main__":
    run_server()
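

# Illustrative only: a minimal sketch of one client-side round matching the
# protocol above. The helper name _example_client_round is hypothetical; the
# real client is assumed to live in a separate script, and the server never
# calls this function.
def _example_client_round(host='localhost', port=12723):
    def recv_exact(conn, n):
        buf = b""
        while len(buf) < n:
            chunk = conn.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("Socket closed")
            buf += chunk
        return buf

    def send(conn, obj):
        payload = pickle.dumps(obj)
        conn.sendall(struct.pack('>I', len(payload)) + payload)

    def recv(conn):
        length = struct.unpack('>I', recv_exact(conn, 4))[0]
        return pickle.loads(recv_exact(conn, length))

    with socket.create_connection((host, port)) as conn:
        # Ask for the current global model and one training batch.
        send(conn, "get_model")
        payload = recv(conn)
        model = SimpleCNN()
        model.load_state_dict(payload["model"])
        # Compute gradients locally and return them in parameter order.
        loss = nn.CrossEntropyLoss()(model(payload["data"]), payload["target"])
        loss.backward()
        send(conn, "update_model")
        send(conn, [p.grad for p in model.parameters()])
        assert recv(conn) == "update_complete"
        send(conn, "disconnect")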