1. Introduction
In the industrial sector, IoT devices continuously collect production data and upload it to the cloud for model training. However, storing data in the cloud incurs significant transmission overhead and raises concerns about potential data privacy violations [1]. Federated learning (FL) has emerged as a promising distributed machine learning approach that enables collaborative learning while offering strong privacy protection [2]. In FL, the server sends the global model to participating devices, which train the model locally on their own data. These IoT devices then upload their updated model parameters to the server, where they are aggregated to refine the global model. This process is repeated until the global model converges [3]. Substantial research has been conducted in this area. To improve the convergence speed of the model, Zhong et al. [4] designed an adaptive graph imputation generator to explore potential links between subgraphs, and combined a multi-function evaluator with a negative sampling mechanism to capture the global information flow, yielding a multi-edge collaborative framework called SpreadFGL. To improve resource utilization efficiency, Chen et al. [5] addressed the inability of existing resource allocation schemes to adapt to personalized dynamic environments. They improved classical deep reinforcement learning with proximal terms and avoided action dispersion by reducing the variance of action–value estimates through less frequent network updates. On this basis, they designed a computation offloading and resource allocation method based on personalized deep reinforcement learning that achieves a higher task execution success rate across different scenarios. To ensure the robustness of the collaborative system, Chen et al. [6] took into account the inefficient collaboration among multiple edge nodes and the risk of attacks in the collaborative caching mechanism, and considered differences in multi-dimensional user characteristics, activities, and memory to achieve more accurate content recommendations. To mitigate posterior collapse, they proposed the Discrete-Categorical Variational Auto-Encoder and designed a robust federated deep learning framework. These methods enhance the cooperation efficiency of edge devices from different perspectives. It is also important to note that, in distributed training frameworks, the server typically deploys small-scale model structures on clients, which can lead to prediction errors [7]. Increasing the complexity and number of parameters of deep learning models is often a practical way to improve performance in distributed applications [8]. However, IoT devices are often constrained by limited memory, storage, and computing resources, which hinders them from training large-scale models locally [9,10]. Moreover, transmitting large-scale model parameters between clients and servers can introduce substantial communication delays [11].
To train high-performance models on resource-constrained devices, federated split learning (FSL) [10] splits model training between clients and servers. FSL divides the model into client-side and server-side components at a cutting layer, shifting part of the training to the resource-rich server [11]. This approach reduces the client's computational load, enabling large models to be deployed on lightweight devices and improving training efficiency [12]. For instance, in the SplitFed framework [13], clients perform forward propagation on their models using local data and upload the intermediate outputs, together with the labels, to the server. The server then performs forward and backward propagation on its model and sends the gradients back to the clients, which complete backpropagation and send their part of the model to the server for aggregation [10]. However, the client and server must communicate every time the model is updated, resulting in high network overhead. He et al. [14] introduced a training strategy that replaces the traditional global loss function with a local loss, so that the client model can be updated without receiving backpropagated gradients from the server model. While this approach reduces network overhead, it depends heavily on the server's computing resources. Nguyen et al. [11] divided a large-scale deep learning model into several small sub-models, trained them in parallel on multiple devices, and designed a framework in which multiple client clusters cooperate. This allows clients within the same cluster to learn from each other, improving overall model performance. However, during training, all devices in a cluster exchange a large number of model parameters, resulting in high communication costs. A direct way to reduce these costs is to compress the data exchanged within the FL framework. Sparse gradient strategies reduce the amount of data transmitted in each communication without significantly reducing accuracy [15,16], and can therefore improve training efficiency and resource utilization on edge devices. For example, Thonglek et al. [15] calculated the absolute difference between the local model parameters before and after training and used an upper quantile of these differences to determine which updated parameters are exchanged between the client and the server, thereby reducing the communication cost. Sun et al. [16] introduced Top-k sparsification into a secure aggregation protocol, reducing both the privacy risk and the communication overhead. Inspired by this literature, this paper proposes a sparse gradient strategy to reduce the transmitted data and introduces gradient dequantization to minimize the loss of model accuracy.
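The magnitude-based selection shared by [15] and [16] can be illustrated with a short NumPy sketch. It is a generic illustration under our own assumptions, not the exact procedure of either work: parameters are flat arrays, the hypothetical `quantile_select` keeps the updates whose absolute change is at or above the q-th quantile, and the hypothetical `topk_select` keeps the k largest-magnitude entries.

```python
import numpy as np

def quantile_select(w_before, w_after, q=0.9):
    """Keep the updates whose absolute change is at or above the q-th quantile.

    Returns a Boolean mask over the parameters and the retained update values.
    """
    delta = w_after - w_before
    threshold = np.quantile(np.abs(delta), q)
    mask = np.abs(delta) >= threshold
    return mask, delta[mask]

def topk_select(delta, k):
    """Keep the k entries of an update with the largest magnitude.

    Returns their flat indices (in no particular order) and their values.
    """
    flat = delta.ravel()
    idx = np.argpartition(np.abs(flat), -k)[-k:]
    return idx, flat[idx]

# Usage: 10 parameters whose changes are 0, 1, ..., 9.
before = np.zeros(10)
after = np.arange(10, dtype=float)
mask, kept = quantile_select(before, after, q=0.9)  # threshold 8.1, keeps only the change 9
idx, vals = topk_select(after - before, k=2)         # keeps the changes 8 and 9
```

A quantile threshold keeps roughly a (1 − q) fraction of the entries (more if ties occur at the threshold), whereas Top-k fixes the number of transmitted entries exactly.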
Although FSL performs well under resource constraints, it faces significant challenges due to data heterogeneity [17]. In real-world scenarios, the diverse data sources of IoT devices often result in substantial distributional discrepancies among clients, leading to significant heterogeneity in the datasets [18,19]. This heterogeneity can cause gradient divergence, which degrades the performance of the aggregated global model [8]. To reduce the impact of non-independent and identically distributed (non-IID) data on model performance, many existing solutions apply techniques such as local normalization to data-sensitive layers [20]. However, while effective in reducing data discrepancies, this approach hinders knowledge sharing between clients, which limits the overall model's performance and collaborative potential. To improve collaboration among clients while addressing data heterogeneity, similarity evaluation mechanisms have been introduced to encourage clients with similar data distributions to train models collaboratively [21,22]. Although this approach yields personalized models for clients, it may increase the imbalance between clients and overlook the influence of client models on the global model. To tackle this issue, this paper measures the difference between each local model and the global model using the Euclidean distance and applies adaptive aggregation weights to optimize the overall performance and efficiency of the federated learning system. This paper aims to reduce communication overhead, address data heterogeneity, and improve the training efficiency of resource-constrained devices. To this end, we propose a novel federated collaborative framework that combines a sparse gradient strategy with adaptive aggregation weights. This combination improves the global model's accuracy and training efficiency without compromising the performance of individual clients. First, a sparse gradient strategy is introduced to reduce the communication required for each update between clients. Second, an adaptive aggregation weight strategy is designed to mitigate the impact of data heterogeneity on model performance. Finally, these strategies are integrated into the federated collaboration framework to ensure the robustness of the training process.
The main contributions of this paper are as follows:
(1) To address the high communication costs of existing federated collaboration frameworks, a sparse gradient strategy based on a position Mask is designed to reduce data transmission. At the same time, gradient dequantization is introduced to restore the original dense gradient tensor, minimizing the negative impact on model performance while maintaining communication efficiency.
(2) To deal with data heterogeneity between clients, an adaptive aggregation weight strategy based on the Euclidean distance is proposed. The weight of each client is dynamically adjusted according to the difference between its local model and the global model, which reduces the impact of data heterogeneity and enhances collaboration between clients.
(3) To enable resource-constrained devices to participate efficiently in distributed training, a new federated collaboration framework based on sparse gradients is designed. This framework not only reduces communication overhead but also improves the efficiency of collaborative training tasks on resource-constrained devices.
The remainder of this paper is organized as follows. Section 2 describes the related work. Section 3 describes the traditional FL framework and the federated divide and collaborative framework. Section 4 introduces the sparse gradient quantization strategy and the adaptive weighting strategy and presents the novel federated collaborative framework. Section 5 presents the experimental results. Finally, Section 6 concludes the article with a summary.
3. Preliminaries
3.1. Traditional FL Framework
In FL, each client trains on its local dataset, and the model parameters are aggregated on the server. We consider K clients participating in the training, indexed by k = 1, 2, …, K. Each client k updates its local model weight wk using its local data Dk, with the aim of improving the performance of the global model. Thus, the objective of FL can be formulated as follows [21]:
where ℓ(·) is the loss function, w denotes the global model weight parameters, wk denotes the model weight parameters of client k, and f(·) represents the model architecture. The dataset Dk = (xk, yk) represents the client's local data, where xk is the input data and yk is the corresponding label.
After local training, each client sends its final updated local model wk to the server. The server then updates the global model parameters by weighted average aggregation [31], as follows:
When the next communication round begins, the server transmits the updated global model to all clients. This process continues until either the preset number of communication rounds T is reached or the model performance meets the requirements.
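The weighted average aggregation step can be sketched as follows, assuming the common FedAvg rule in which each client's weight is proportional to its local sample count |Dk|. That weighting rule, the representation of a model as a dict of NumPy arrays, and the function name `fedavg` are illustrative assumptions.

```python
import numpy as np

def fedavg(local_models, num_samples):
    """Sample-weighted average of client models (FedAvg-style aggregation).

    local_models: list of dicts mapping a layer name to a NumPy array; every
                  client must use the same layer names and shapes.
    num_samples:  list with the local dataset size |D_k| of each client.
    """
    total = float(sum(num_samples))
    return {
        name: sum((n / total) * model[name] for model, n in zip(local_models, num_samples))
        for name in local_models[0]
    }

# Usage: the second client holds twice as much data, so it receives weight 2/3.
clients = [{"w": np.array([0.0, 0.0])}, {"w": np.array([3.0, 3.0])}]
global_model = fedavg(clients, num_samples=[100, 200])  # {"w": array([2., 2.])}
```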
3.2. Federated Divide and Collaborative Framework
The traditional FL framework typically assumes that all participating devices can independently train the whole ML model. However, lightweight edge devices have limited computing power, communication bandwidth, and memory, so they are often unable to train large-scale network models. Therefore, a federated divide and collaborative framework was proposed in [11] to support the training of complex convolutional neural networks on resource-constrained devices.
Before training begins, a subset of S clients c1, c2, …, cS is randomly selected from the K devices to form a cluster C. For each cluster C, the original model W is then divided, according to its network parameters, into a set of S sub-models Wen = {W1, W2, …, WS}. Next, each sub-model Wi is split at the cutting layer into a lower sub-model, which extracts an abstract representation of the input data, and an upper sub-model, which is responsible for prediction. The server then sends each part of the global model to the participating clients, enabling the devices within each cluster to update the ensemble model in parallel using only their local data.
During training, the first client in each cluster C is designated as the main client, while the others act as proxy clients. The server sends the lower sub-models to the main client in each cluster and distributes the i-th upper sub-model (i = 1, …, S) to the i-th client. The main client's sample data are augmented to produce S versions, which are passed through the lower sub-models for forward propagation, yielding S cutting-layer activations. The main client then sends S − 1 of these activations to the S − 1 proxy clients and retains one abstract representation for itself. Each client ci ∈ {c1, c2, …, cS} uses its activation to perform forward propagation on its upper sub-model and then sends its prediction pi to the server. The server collects the predictions p1, p2, …, pS from all devices, constructs a regularization term based on the Jensen–Shannon divergence, and combines it with the cross-entropy classification loss to form the loss function, which facilitates mutual learning among devices. The loss calculated by the server is sent to each client, where it is backpropagated through the corresponding upper sub-model to the cutting layer. The gradients at the cutting layer are sent to the main client, which updates its lower sub-models after completing backpropagation. Finally, the main client transfers the updated lower sub-models to the next designated main client, and this process continues until every client in the cluster has completed training as the main client.
After training, the last main client cS in each cluster C sends the updated ensemble of the global lower sub-models to the server, while each device ci sends its updated upper sub-model to the server, forming the upper ensemble model. The server then merges the lower and upper ensemble models to create an ensemble model WC for each cluster. Finally, the server aggregates the ensemble models from all clusters to obtain a new global ensemble model.
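To make the order of communication in this framework concrete, the sketch below enumerates the messages exchanged in one pass over a cluster. It is a bookkeeping illustration under stated assumptions, not an implementation: clusters are assumed to be disjoint random groups of S clients (so K is assumed to be divisible by S), each main-client turn is assumed to process a single batch, and the names `form_clusters` and `cluster_messages` are hypothetical.

```python
import random

def form_clusters(num_clients, cluster_size, seed=0):
    """Randomly partition client ids 0..K-1 into disjoint clusters of S clients."""
    assert num_clients % cluster_size == 0, "sketch assumes K is divisible by S"
    ids = list(range(num_clients))
    random.Random(seed).shuffle(ids)
    return [ids[i:i + cluster_size] for i in range(0, num_clients, cluster_size)]

def cluster_messages(cluster):
    """List (sender, receiver, payload) tuples for one pass over a cluster.

    Every client takes one turn as main client; one batch per turn is assumed.
    """
    msgs = []
    for turn, main in enumerate(cluster):
        proxies = [c for c in cluster if c != main]
        msgs += [(main, p, "cut-layer activation") for p in proxies]
        for c in cluster:  # every client, including the main client, predicts
            msgs.append((c, "server", "prediction"))
            msgs.append(("server", c, "loss"))
        msgs += [(p, main, "cut-layer gradient") for p in proxies]
        if turn + 1 < len(cluster):  # hand the lower sub-models to the next main client
            msgs.append((main, cluster[turn + 1], "lower sub-models"))
    msgs.append((cluster[-1], "server", "lower ensemble"))
    msgs += [(c, "server", "upper sub-model") for c in cluster]
    return msgs

clusters = form_clusters(num_clients=6, cluster_size=3)
msgs = cluster_messages(clusters[0])  # 36 messages for S = 3
```

Under these assumptions a cluster of S clients exchanges 4S² messages per pass, and S(S − 1) of them are cut-layer gradients, which are the tensors targeted by the sparse gradient strategy in Section 4.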
4. Efficient Federated Collaborative Learning Methodology
To address the problem of heterogeneous data on resource-constrained clients, this paper proposes a federated collaborative learning framework with sparse gradients (FedCS), as shown in Figure 1.
The main steps are as follows. First, in each cluster, the model parameters are divided into S sub-models, and each sub-model is split at the cutting layer into a lower and an upper sub-model. The server distributes the lower sub-model parameters to the main client and the upper sub-model parameters to all clients. Second, the main client uses its local data to perform forward propagation on the lower sub-model up to the cutting layer and sends the activations of this layer to each client. Each client then propagates the activations through its upper sub-model and sends its predictions back to the server, which calculates the loss function and returns the gradients to each client. Each client performs backpropagation on its upper sub-model down to the cutting layer, sparsifies the resulting gradient, and sends the sparse gradient to the main client. The main client receives the sparse gradients, recovers the dense gradients using the gradient dequantization strategy, performs backpropagation on the lower sub-model, and sends the updated lower sub-model to the next main client. After every device in the cluster has been trained as the main client, the network models of all devices are sent to the server. The server then integrates the upper sub-models using the adaptive weighting strategy and combines them with the lower sub-model to form the model for the cluster. Finally, the server performs weighted averaging across clusters to generate an updated global model. This training process continues until the specified stopping criteria are met.
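The exchange across the cutting layer is ordinary backpropagation split between two parties. The NumPy sketch below assumes a toy lower sub-model (one linear layer with ReLU) and a toy upper sub-model (one linear layer with a mean squared error loss) in place of the convolutional networks used here, and folds the loss computation into the upper sub-model for brevity, whereas in the framework the server computes the loss. It shows that the main client can finish backpropagation from the cutting-layer gradient alone. All function names are hypothetical.

```python
import numpy as np

def lower_forward(W_low, x):
    """Lower sub-model: linear layer + ReLU up to the cutting layer.

    Returns the cut-layer activation h and the pre-activation z needed for backprop.
    """
    z = x @ W_low
    return np.maximum(z, 0.0), z

def upper_forward_backward(W_up, h, y):
    """Upper sub-model: linear layer followed by a mean squared error loss.

    Returns the loss, the gradient w.r.t. W_up, and the gradient w.r.t. the
    cut-layer activation h (the tensor sent back to the main client).
    """
    diff = h @ W_up - y
    loss = 0.5 * np.mean(np.sum(diff ** 2, axis=1))
    d_pred = diff / len(y)
    return loss, h.T @ d_pred, d_pred @ W_up.T

def lower_backward(x, z, grad_h):
    """Main client: backprop from the cut-layer gradient through ReLU and the linear layer."""
    return x.T @ (grad_h * (z > 0))

rng = np.random.default_rng(0)
x, y = rng.normal(size=(4, 3)), rng.normal(size=(4, 2))
W_low, W_up = rng.normal(size=(3, 5)), rng.normal(size=(5, 2))
h, z = lower_forward(W_low, x)                        # main client
loss, g_up, g_h = upper_forward_backward(W_up, h, y)  # client holding the upper sub-model
g_low = lower_backward(x, z, g_h)                     # main client, using only g_h
```

Only `g_h`, whose shape equals that of the cut-layer activation, has to travel back to the main client; this is the tensor that the sparse gradient strategy compresses.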
4.1. Sparse Gradient Quantization Strategy
Within the federated collaborative learning framework, gradients or model parameters must frequently be synchronized between each proxy client and the main client. Because the gradients and parameters of large network models are large, training such models on resource-constrained clients can incur significant communication and computational costs. To address this, sparse gradients are employed to retain only the most important gradient elements, thereby reducing both computation and communication overhead [32]. Although sparse gradients alleviate these costs, the degree of sparsification must be balanced carefully to reduce overhead without compromising model accuracy. To further mitigate accuracy loss, gradient dequantization is used to preserve model performance as much as possible while saving storage space and computing resources. By integrating sparse gradients with gradient dequantization, an efficient communication strategy for large-scale distributed training is developed that optimizes both performance and resource usage.
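The trade-off can be quantified with a back-of-the-envelope calculation. Assume, for illustration only, that a cut-layer gradient has n float32 elements, that a fraction ρ of them survives sparsification, and that positions are sent either as a one-bit-per-element mask or as a list of 32-bit indices; these encoding choices are assumptions rather than part of the method description.

```latex
\begin{aligned}
B_{\text{dense}} &= 32n \ \text{bits}, \\
B_{\text{mask}}  &= 32\rho n + n \ \text{bits}
  \;\Rightarrow\; \frac{B_{\text{mask}}}{B_{\text{dense}}} = \rho + \tfrac{1}{32}, \\
B_{\text{index}} &= 32\rho n + 32\rho n = 64\rho n \ \text{bits}
  \;\Rightarrow\; \frac{B_{\text{index}}}{B_{\text{dense}}} = 2\rho .
\end{aligned}
```

For ρ = 0.1, the mask encoding needs about 13.1% of the dense payload (roughly 7.6 times smaller), versus 20% for index encoding. The two break even at ρ = 1/32; below that density, index encoding is cheaper (at ρ = 0.01 it needs 2% of the dense payload versus about 4.1% for the mask).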
In each communication round, the server transmits part of the global ensemble model to the respective devices in each cluster. The main client uses data augmentation to obtain S augmented versions of each sample, passes them through its lower sub-models to obtain S cutting-layer activations, and sends these activations to the S − 1 proxy clients for forward propagation. Following the approach of [11], the server collects the predictions p1, p2, …, pS of all devices, uses the Jensen–Shannon divergence to construct a regularization term, and combines it with the cross-entropy classification loss to form the loss function, which is calculated as follows:
where pi (i = 1, …, S) represents the output vector of client ci, ℒcot is the cooperative training loss, λcot is its weight factor, and ℒce is the cross-entropy classification loss. The combined loss is then sent to all devices, each of which performs backpropagation on its upper sub-model.
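A minimal NumPy sketch of such a co-training loss is given below. It assumes the generalized Jensen–Shannon divergence over S distributions, JSD(p1, …, pS) = H(mean of the pi) − mean of H(pi), and a total loss equal to the sum of the per-client cross-entropy terms plus λcot times this divergence; the exact combination used in [11] and in this paper may differ, and the function names are hypothetical.

```python
import numpy as np

EPS = 1e-12  # avoids log(0)

def entropy(p):
    """Shannon entropy (natural log) along the last axis."""
    return -np.sum(p * np.log(p + EPS), axis=-1)

def js_divergence(preds):
    """Generalized Jensen-Shannon divergence of S predictions.

    preds: array of shape (S, batch, classes); each row is a probability vector.
    Returns H(mean_i p_i) - mean_i H(p_i), averaged over the batch.
    """
    return float(np.mean(entropy(preds.mean(axis=0)) - entropy(preds).mean(axis=0)))

def cross_entropy(p, labels):
    """Mean cross-entropy of probabilities p (batch, classes) against integer labels."""
    return float(-np.mean(np.log(p[np.arange(len(labels)), labels] + EPS)))

def cotraining_loss(preds, labels, lam_cot=0.5):
    """Assumed form: sum of per-client cross-entropy losses + lam_cot * JSD."""
    ce = sum(cross_entropy(p, labels) for p in preds)
    return ce + lam_cot * js_divergence(preds)

labels = np.array([0])
agree = np.array([[[0.9, 0.1]], [[0.9, 0.1]]])     # two clients, identical predictions
disagree = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])  # two clients, opposite predictions
```

When the clients agree, the divergence term vanishes; when they disagree it grows, up to ln S (ln 2 in the two-client example above), which is what pushes the S upper sub-models toward consistent predictions.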
When backpropagation reaches the cutting layer, the high-dimensional gradients must be transmitted to the main client, which incurs a significant time cost. Given the gradient tensor G = (g1, g2, …, gn) of the model, the sparse gradient can be expressed as follows:
where gi represents the i-th element of the gradient vector, whose absolute value is compared against the sparse threshold. To record which gradient values are retained and which are zeroed, a Mask is generated, defined as follows:
where Mask is a Boolean mask that marks the gradient elements whose absolute value reaches the sparse threshold. According to Equation (6), only gradient elements whose absolute value is greater than or equal to the sparse threshold are retained, and the rest are set to zero. The main client receives the sparse gradients from the cutting layer and dequantizes them to restore the original dense gradient tensor:
where Mask ensures that only the sparse gradient terms are involved in the operation.
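The sparsification and dequantization steps can be sketched in NumPy as follows. The sketch assumes that the proxy client transmits the retained values plus the Boolean Mask packed into bits with `np.packbits`; that wire format, the threshold value, and the function names `sparsify` and `dequantize` are illustrative assumptions. Dequantization is implemented here as scattering the retained values back into a zero tensor of the original shape, so the dropped entries come back as zeros.

```python
import numpy as np

def sparsify(grad, threshold):
    """Proxy client: keep elements with |g_i| >= threshold.

    Returns the position Mask packed into bytes, the retained values, and the shape.
    """
    mask = np.abs(grad) >= threshold
    return np.packbits(mask.ravel()), grad[mask], grad.shape

def dequantize(packed_mask, values, shape):
    """Main client: rebuild the dense gradient from the Mask and the retained values.

    Positions where the Mask is False are filled with zeros.
    """
    n = int(np.prod(shape))
    mask = np.unpackbits(packed_mask, count=n).astype(bool).reshape(shape)
    dense = np.zeros(shape, dtype=values.dtype)
    dense[mask] = values
    return dense

g = np.array([[0.5, -0.01], [0.002, -0.8]], dtype=np.float32)
packed, vals, shape = sparsify(g, threshold=0.1)
restored = dequantize(packed, vals, shape)  # [[0.5, 0.0], [0.0, -0.8]]
payload = packed.nbytes + vals.nbytes       # 1 + 8 = 9 bytes versus g.nbytes = 16
```

Because boolean indexing and boolean assignment both traverse the array in the same (row-major) order, the retained values land back at exactly the positions recorded in the Mask.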
4.2. Adaptive Weighting Strategy
Because data distributions vary significantly across clients, aggregating client models on the server with a simple weighted average can overlook each client's unique contributions and lead to suboptimal performance of the global model. To address this issue, the server calculates the Euclidean distance between each client's local model and the global model after training. This distance serves as an evaluation of each client's performance in the classification task. Based on these distances, the server dynamically assigns aggregation weights to clients, thereby adjusting their influence on the global model update. This adaptive weighting approach improves the overall performance and efficiency of the FL system.
After each training round, the final main client cS in each cluster C sends the updated ensemble of the global lower sub-models to the server. Simultaneously, each device ci in the cluster sends its updated upper sub-model to the server, which then constructs the upper ensemble model. During the ensemble process, the server calculates the Euclidean distance between each client model and the global model and allocates weights based on this distance:
where the two parameter terms denote the k-th parameters of the global model and of the client model, respectively, and ‖·‖₂ denotes the L2 norm. The proxy client models are aggregated by a weighted average whose weights are the adaptive weights defined above:
where wi is the weight of client i, and the remaining terms denote the global model parameters and the k-th parameter of the i-th proxy client. The server merges the lower and upper ensemble models into an ensemble model WC for each cluster and then aggregates the ensemble models from all clusters to obtain a new global ensemble model:
where N(C) is the cardinality of cluster C, representing the number of clients in the cluster.
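The adaptive aggregation can be sketched as follows. The sketch assumes, purely for illustration, that each weight is inversely proportional to the Euclidean distance between the client model and the global model and that the weights are normalized to sum to one; this is one possible mapping, not necessarily the rule used in FedCS. Models are flattened into NumPy vectors, and the function names are hypothetical.

```python
import numpy as np

def adaptive_weights(global_params, client_params, eps=1e-8):
    """ASSUMED mapping: weight_i proportional to 1 / (||theta_g - theta_i||_2 + eps).

    global_params: flat NumPy vector of the global model.
    client_params: list of flat vectors, one per (proxy) client.
    Returns the normalized weights and the raw Euclidean distances.
    """
    dists = np.array([np.linalg.norm(global_params - p) for p in client_params])
    inv = 1.0 / (dists + eps)
    return inv / inv.sum(), dists

def weighted_average(params, weights):
    """Aggregate parameter vectors with the given weights (assumed to sum to 1)."""
    return sum(w * p for w, p in zip(weights, params))

theta_g = np.zeros(2)
clients = [np.ones(2), 3 * np.ones(2)]     # distances sqrt(2) and 3*sqrt(2)
w, d = adaptive_weights(theta_g, clients)  # w is approximately [0.75, 0.25]
upper = weighted_average(clients, w)       # approximately [1.5, 1.5]

# Cluster-level step: weights proportional to N(C), here 1 and 2.
sizes = np.array([1.0, 2.0])
global_model = weighted_average([np.ones(2), 4 * np.ones(2)], sizes / sizes.sum())  # approximately [3., 3.]
```

Under this assumed mapping, clients that drift far from the global model receive less influence; a mapping that favors larger distances would only change how `inv` is computed.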
4.3. FedCS Algorithm
In what follows, we design the novel federated collaborative framework for heterogeneous data on resource-constrained devices (FedCS), detailed in Algorithm 1. Before training begins, for each cluster C containing S devices, the original model W is divided into S sub-models according to the network parameters. Each sub-model Wi is further split into two parts at the cutting layer. The server assigns these sub-models to the participating clients for training, and the clients in each cluster update their models in parallel using local data.
In the r-th training round, the first device in the cluster is treated as the main client. The main client applies data augmentation to its samples and passes the S augmented versions through its lower sub-models to obtain S cutting-layer activations. Next, the main client sends these activations to the S − 1 proxy clients in the cluster. Each client ci ∈ {c1, c2, …, cS} completes forward propagation on its upper sub-model in parallel, obtaining its prediction for the sample. The server then collects the predictions p1, p2, …, pS from all devices, calculates the loss using the designated loss function, and propagates the gradients back to the cutting layer of each client's sub-model. Each client applies the sparse gradient strategy according to the Mask and sends the sparse gradient to the main client. The main client then dequantizes the gradients, updates its lower sub-models via backpropagation, and sends the updated lower sub-models to the next main client in sequence. This process continues until each device in the cluster has taken its turn as the main client.
Algorithm 1: FedCS
Input: the set of all devices S, the batch size B, the number of global rounds R, the model WC of each cluster C, the number of data samples N(C), the cutting layer L, the learning rate η
After the r-th training round, the last main client sends the updated ensemble of the global lower sub-models to the server. Each device ci sends its updated upper sub-model to the server, which then applies the adaptive weighting strategy to obtain the upper ensemble model. The server merges the lower and upper ensemble models to create an ensemble model for each cluster and aggregates the ensemble models from all clusters using a weighted average to generate a new global ensemble model.
6. Conclusions
In this paper, a federated collaboration framework with sparse gradients is designed to address communication overhead, data heterogeneity, and the training efficiency of resource-constrained devices in distributed learning environments. First, the model is partitioned across different devices, enabling large models to be trained in parallel on resource-constrained devices. Second, to improve training efficiency, a sparse gradient strategy based on a position Mask is constructed to reduce data transmission, and a gradient dequantization strategy is introduced to recover the original dense gradient and reduce the negative impact on model performance. Then, the Euclidean distance between each client model and the global model is used to measure the performance of each client in the classification task, and an adaptive weight strategy is designed to assign each client an appropriate aggregation weight. Finally, a new federated collaboration algorithm is designed by combining the sparse gradient quantization method with the adaptive weight strategy. The performance of the proposed algorithm is evaluated using different network models and datasets. Experimental results show that on the CIFAR-10 dataset, our method improves training time by about 35% and accuracy by about 13%. On the CIFAR-100 dataset, training time is improved by approximately 8%, while accuracy increases by about 20%.
Although the proposed method efficiently supports distributed training on resource-constrained devices, there are still some directions worth further research:
(1) Because the federated collaborative framework uses model splitting to assign training tasks to multiple devices, the shared data may expose sensitive user information. Future work could therefore combine the framework with existing privacy-preserving technologies, such as differential privacy and multi-party computation, to protect the transmitted data.
(2) Given the potential for malicious attacks in real-world scenarios, future work could improve the robustness of the system by introducing methods such as credibility evaluation algorithms and anomaly detection to identify and isolate untrusted clients, and by designing mechanisms to defend against adversarial attacks.
(3) Considering the communication bottleneck and scale of the federated collaboration framework, future work could focus on optimizing data transmission protocols to reduce the communication load, designing more lightweight server-side computation methods, and developing simpler weight allocation strategies tailored to different types of tasks to reduce the computational burden.