Federated averaging is a technique for training machine learning models when data is spread across many servers or devices. It helps preserve data privacy, security, and locality by enabling model training without sharing the raw data.
In conventional centralized machine learning, data is collected and stored on a central server, and a single model is trained on this consolidated data. In federated averaging, by contrast, the data remains decentralized, residing on many client or edge devices, and training is carried out locally on each device.
The process of federated averaging involves the following steps:
Initialization: The central server initializes a global model, typically with random parameters.
Client selection: A subset of clients is chosen to take part in the training round. This selection may be random or based on specific criteria, such as device availability.
Model distribution: The chosen clients receive the global model; each client gets its own copy.
Local training: Each client trains the model on its local data. This training can run for multiple iterations or epochs to improve the model's performance.
Model aggregation: The updated models from each client are sent back to the central server after local training.
Model averaging: The central server aggregates the models received from the clients by averaging their parameters. This averaging lets the global model benefit from what each client has learned while the raw data stays local (the averaging formula is shown after this list).
Repeat: Steps 2 (Client selection) to 6 (Model averaging) are repeated for multiple training rounds until the global model converges or a predefined number of rounds is reached.
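For reference, the averaging step in the original federated averaging (FedAvg) algorithm of McMahan et al. (2017) weights each client's parameters by its share of the training data:

w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_k^{t+1}

where K is the number of selected clients, w_k^{t+1} is the model trained locally on client k in round t, n_k is the number of samples held by client k, and n = \sum_k n_k is the total sample count. When every client holds the same amount of data, this reduces to a plain (unweighted) average, which is what the code example below uses.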
Note: The steps above describe a general workflow. The actual implementation of federated averaging may vary based on the specific federated learning framework or algorithm used.
import numpy as np
import copy

# Define a sample model for demonstration
class Model:
    def __init__(self):
        self.fc = np.random.randn(10, 1)

    def forward(self, x):
        return np.dot(x, self.fc)

# Client-side training function
def train_local_model(client_data, model):
    num_epochs = 10
    learning_rate = 0.1
    for epoch in range(num_epochs):
        inputs, labels = client_data
        inputs = inputs.reshape(1, -1)
        outputs = model.forward(inputs)
        loss = np.mean((outputs - labels) ** 2)
        # Gradient of the squared-error loss, scaled by the feature count for stable steps
        grad = 2 * np.dot(inputs.T, outputs - labels) / inputs.shape[1]
        model.fc -= learning_rate * grad
    return model

# Server-side federated averaging
def federated_averaging(global_model, client_data, num_rounds):
    for round_num in range(num_rounds):
        # Client selection: randomly pick 3 client indices
        selected_clients = np.random.choice(len(client_data), size=3, replace=False)
        # Model distribution: each client gets its own copy of the global model
        client_models = [copy.deepcopy(global_model) for _ in selected_clients]
        # Local training on each selected client's data
        for i, client_index in enumerate(selected_clients):
            client_models[i] = train_local_model(client_data[client_index], client_models[i])
        # Model aggregation: sum the clients' weights, starting from zeros
        aggregated_weights = np.zeros_like(global_model.fc)
        for client_model in client_models:
            aggregated_weights += client_model.fc
        # Model averaging: the new global model is the mean of the client models
        global_model.fc = aggregated_weights / len(client_models)
        print(f"Round {round_num + 1} - Global Model Parameters:")
        print(global_model.fc)
        print()

# Example usage
global_model = Model()
client_data = [(np.random.randn(10), np.random.randn(1)),
               (np.random.randn(10), np.random.randn(1)),
               (np.random.randn(10), np.random.randn(1))]  # Dummy client data
num_rounds = 2
federated_averaging(global_model, client_data, num_rounds)
Lines 5–10: Defines a sample model called Model. The model has a single fully connected layer represented by the fc attribute, initialized with random values. The forward method performs a forward pass by taking input x and multiplying it with the fc weights using the dot product.
Lines 13–24: The function train_local_model performs local training on the client side.
Takes client_data, which consists of input samples and corresponding labels, and model as inputs.
Trains the model for a fixed number of epochs using gradient descent with a fixed learning rate.
Calculates the loss by comparing the model's outputs with the labels and updates the model's weights using the gradient of the loss with respect to the weights (the gradient is derived just after this list).
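For completeness, the gradient applied on line 22 follows from the squared-error loss for a single sample. With input row vector x, weight vector w, and label y:

L(w) = (xw - y)^2, \qquad \nabla_w L = 2\, x^{\top} (xw - y)

The code additionally divides the gradient by the number of features; this only rescales the step (equivalent to using a smaller learning rate) and keeps the updates stable in this demo.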
Lines 27–44: The function federated_averaging performs server-side federated averaging. It takes global_model as the initial model, client_data as a list of client data, and num_rounds as the number of federated learning rounds. In each round, the following steps are performed:
Randomly selects a subset of clients from the available client data.
Provides each selected client with its own deep copy of the global model, so local updates don't overwrite each other.
Each selected client trains its local model using the train_local_model function.
Initializes a zero-valued weight array and aggregates each client's weights by summing them into it.
Updates the global model's weights by averaging the aggregated weights across all clients.
Displays the global model's weights for each round.
Lines 47–52: Initializes a global_model, creates a list of dummy client_data consisting of input samples and labels, and defines the number of federated learning rounds, num_rounds. It then calls the federated_averaging function with these inputs to perform federated learning.
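The demo above gives every client equal weight in the average. Below is a minimal sketch of the weighted variant from the FedAvg formula earlier, assuming each client also reports its local sample count; the weighted_average helper and the example values are hypothetical, not part of the code above:

import numpy as np

# Weighted model averaging: each client's parameters count in proportion
# to its share of the total training data (n_k / n in the FedAvg formula).
def weighted_average(client_weights, client_sizes):
    total = sum(client_sizes)
    aggregated = np.zeros_like(client_weights[0])
    for w, n_k in zip(client_weights, client_sizes):
        aggregated += (n_k / total) * w
    return aggregated

# Example: three clients with different amounts of local data
weights = [np.full((10, 1), 1.0), np.full((10, 1), 2.0), np.full((10, 1), 4.0)]
sizes = [100, 200, 100]
print(weighted_average(weights, sizes))  # Every entry equals (100*1 + 200*2 + 100*4) / 400 = 2.25

Weighting by sample count keeps clients with lots of data from being drowned out by clients that trained on only a handful of examples.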
The federated averaging technique has the following benefits:
Allows multiple parties to collaborate on training without sharing raw data, preserving privacy.
Reduces communication costs because only model updates are sent between clients and the server.
Scalable for large-scale machine learning applications.
Minimizes the need for data transfer, reducing network bandwidth requirements and latency.
Optimizes resource utilization, as it distributes the computational load across multiple devices or servers.
Federated averaging has the following drawbacks:
Heterogeneous (non-IID) data distributions across devices, which can slow or destabilize convergence.
Lack of centralized control over the training process.
Limited access to the data kept on individual servers or devices, which makes debugging and data-quality checks harder.
Potential privacy and security threats if necessary precautions, such as secure aggregation or differential privacy, are not taken.
The federated averaging technique is used in many different fields. It applies to:
Healthcare: Trains models on distributed patient data while maintaining privacy.
Finance: Allows financial firms to collaborate without disclosing confidential customer data.
Edge computing: Lets devices with limited connectivity, such as IoT devices, smartphones, or edge servers, contribute to model training.
Federated averaging is a promising way to address privacy and data decentralization concerns in machine learning. By allowing models to be trained locally on individual devices or servers while keeping user data private, it strikes a balance between data protection and model performance.