I', len(serialized_data))) self.socket.sendall(serialized_data) def receive_data(self): data_length = struct.unpack('>I', self.socket.recv(4))[0] received_data = b"" while len(received_data) < data_length: p"> I', len(serialized_data))) self.socket.sendall(serialized_data) def receive_data(self): data_length = struct.unpack('>I', self.socket.recv(4))[0] received_data = b"" while len(received_data) < data_length: p"> I', len(serialized_data))) self.socket.sendall(serialized_data) def receive_data(self): data_length = struct.unpack('>I', self.socket.recv(4))[0] received_data = b"" while len(received_data) < data_length: p">
import socket
import pickle
import torch.nn as nn
import struct
from shared_model import SimpleCNN
# Multiplier applied to every parameter gradient before the client sends it to
# the server (used in FederatedClient.train_and_send_gradients). With the
# value 1 the gradients are sent unscaled.
learning_rate = 1
class FederatedClient:
    """Client side of a simple federated-training protocol over TCP.

    The client repeatedly asks the server for the current model weights plus a
    training batch, computes gradients locally, and sends the scaled gradients
    back. Every message is a pickled Python object framed by a 4-byte
    big-endian unsigned length prefix (see ``send_data`` / ``receive_data``).

    Args:
        host: Server hostname or IP address.
        port: Server TCP port.
    """

    def __init__(self, host='localhost', port=12723):
        self.host = host
        self.port = port
        self.model = SimpleCNN()
        self.criterion = nn.CrossEntropyLoss()
        # Set by connect(); None while not connected.
        self.socket = None

    def connect(self):
        """Open a TCP connection to ``(self.host, self.port)``.

        ``socket.create_connection`` closes its socket itself if connecting
        fails, so a failed connect leaks nothing and ``self.socket`` is left
        unchanged.

        Raises:
            OSError: if the connection cannot be established.
        """
        self.socket = socket.create_connection((self.host, self.port))

    def disconnect(self):
        """Tell the server we are leaving (best effort) and close the socket.

        Safe to call repeatedly and after the connection has already broken:
        the socket is always closed and ``self.socket`` reset to None.
        """
        if self.socket is None:
            return
        try:
            self.send_data("disconnect")
        except OSError:
            # Peer already gone (reset / broken pipe): there is nobody left to
            # notify, but the local socket must still be released below.
            pass
        finally:
            self.socket.close()
            self.socket = None

    def send_data(self, data):
        """Pickle ``data`` and send it as one length-prefixed message.

        Wire format: 4-byte big-endian unsigned length, then the pickle bytes.
        """
        payload = pickle.dumps(data)
        # One sendall for header + payload so the tiny 4-byte header is not
        # held back on its own by Nagle's algorithm / delayed ACKs.
        self.socket.sendall(struct.pack('>I', len(payload)) + payload)

    def _recv_exact(self, num_bytes):
        """Read exactly ``num_bytes`` bytes from the socket.

        ``socket.recv`` may return fewer bytes than requested, so keep reading
        until the full count has arrived.

        Raises:
            ConnectionError: if the peer closes the connection first.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            packet = self.socket.recv(remaining)
            if not packet:
                raise ConnectionError(
                    f"connection closed with {remaining} of {num_bytes} bytes unread"
                )
            chunks.append(packet)
            remaining -= len(packet)
        return b"".join(chunks)

    def receive_data(self):
        """Receive one length-prefixed message and return the unpickled object.

        SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
        payload -- only connect this client to a trusted server.

        Raises:
            ConnectionError: if the connection closes mid-message.
        """
        (data_length,) = struct.unpack('>I', self._recv_exact(4))
        return pickle.loads(self._recv_exact(data_length))

    def get_model_params_and_data(self):
        """Request the current weights and a batch; load the weights locally.

        The server's reply is expected to be a dict with keys ``"model"``
        (a state dict for ``self.model``), ``"data"`` and ``"target"``.

        Returns:
            ``(data, target)`` as sent by the server.
        """
        self.send_data("get_model")
        received = self.receive_data()
        self.model.load_state_dict(received["model"])
        return received["data"], received["target"]

    def train_and_send_gradients(self, max_rounds=None):
        """Run the fetch -> train -> upload loop.

        Each round fetches the latest weights and a batch from the server, and
        computes the loss and gradients. It then sends the gradients multiplied
        by the module-level ``learning_rate`` and waits for the server's
        ``"update_complete"`` acknowledgement.

        Args:
            max_rounds: Number of rounds to run. ``None`` (the default) loops
                until an exception, e.g. a dropped connection, ends it.

        Raises:
            RuntimeError: if the server acknowledges with anything other than
                ``"update_complete"``.
        """
        self.model.train()
        rounds_done = 0
        while max_rounds is None or rounds_done < max_rounds:
            data, target = self.get_model_params_and_data()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            print(loss.item())  # plain float, without the grad_fn noise
            # A parameter that did not take part in the loss has grad None;
            # send zeros for it so the list stays aligned with parameters().
            gradients = [
                learning_rate * (p.grad if p.grad is not None else p.new_zeros(p.shape))
                for p in self.model.parameters()
            ]
            self.send_data("update_model")  # announce that gradients follow
            self.send_data(gradients)
            response = self.receive_data()
            # Explicit check instead of `assert`, which is stripped under -O.
            if response != "update_complete":
                raise RuntimeError(f"unexpected server response: {response!r}")
            self.model.zero_grad()  # don't accumulate into the next round
            rounds_done += 1
def run_client(host='localhost', port=12723):
    """Connect a FederatedClient to ``host:port`` and run its training loop.

    The client is disconnected afterwards whether training returns normally
    or raises.
    """
    session = FederatedClient(host=host, port=port)
    session.connect()
    try:
        return session.train_and_send_gradients()
    finally:
        session.disconnect()
# Script entry point: run the client against run_client's default host/port.
if __name__ == "__main__":
    run_client()