import socket
import pickle
import torch.nn as nn
import struct
from shared_model import SimpleCNN

learning_rate = 1  # Scale factor applied to gradients before they are sent to the server

class FederatedClient:
    def __init__(self, host='localhost', port=12723):
        self.host = host
        self.port = port
        self.model = SimpleCNN()
        self.criterion = nn.CrossEntropyLoss()
        self.socket = None

    def connect(self):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.connect((self.host, self.port))

    def disconnect(self):
        if self.socket:
            self.send_data("disconnect")
            self.socket.close()

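    # Wire protocol: every message is pickled and framed with a 4-byte
    # big-endian length prefix so the receiver knows how many bytes to read.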
    def send_data(self, data):
        serialized_data = pickle.dumps(data)
        self.socket.sendall(struct.pack('>I', len(serialized_data)))
        self.socket.sendall(serialized_data)

    def _recv_exact(self, num_bytes):
        # recv() may return fewer bytes than requested, so keep reading
        # until the full chunk (or an EOF) arrives.
        buffer = b""
        while len(buffer) < num_bytes:
            packet = self.socket.recv(num_bytes - len(buffer))
            if not packet:
                raise ConnectionError("Socket closed before the full message arrived")
            buffer += packet
        return buffer

    def receive_data(self):
        # Read the 4-byte length prefix, then exactly that many payload bytes.
        data_length = struct.unpack('>I', self._recv_exact(4))[0]
        return pickle.loads(self._recv_exact(data_length))

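    # Request the current global model together with one batch of training
    # data; the server is expected to reply with a dict containing "model",
    # "data", and "target".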
    def get_model_params_and_data(self):
        self.send_data("get_model")
        received = self.receive_data()
        model_state = received["model"]
        self.model.load_state_dict(model_state)
        return received["data"], received["target"]

    def train_and_send_gradients(self):
        self.model.train()

        # Repeatedly fetch the current global model and a training batch from
        # the server, compute gradients locally, and send them back.
        while True:
            data, target = self.get_model_params_and_data()  # Get model weights and a training batch

            output = self.model(data)  # Forward pass: get predictions
            loss = self.criterion(output, target)  # Compute the loss
            loss.backward()  # Backpropagation: compute gradients
            print(f"loss: {loss.item():.4f}")  # Log the loss for debugging

            # Scale the gradients before sending them to the server
            gradients = [learning_rate * param.grad.clone() for param in self.model.parameters()]

            self.send_data("update_model")  # Notify the server that gradient updates follow
            self.send_data(gradients)  # Send the scaled gradients

            response = self.receive_data()  # Wait for server confirmation
            assert response == "update_complete"  # Ensure the update was applied

            self.model.zero_grad()  # Reset gradients for the next iteration
        
def run_client(host='localhost', port=12723):
    client = FederatedClient(host, port)
    client.connect()

    try:
        client.train_and_send_gradients()
    finally:
        client.disconnect()

if __name__ == "__main__":
    run_client()