fedbox.optimization.fedavg
'''
FedAvg federated optimization: agents (clients) run local SGD on their data
subsets while a central coordinator (server) aggregates their model updates
into a single shared model.
'''

from ..datasets import utils
from .utils import WeightingScheme, Logger

from collections import OrderedDict
from copy import deepcopy
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm


class Agent:
    '''
    An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset.
    '''

    def __init__(self, subset: utils.FederatedSubset):
        '''
        Initializes the agent with a local `subset` of data samples and labels.

        Parameters
        ----------
        subset: utils.FederatedSubset
            Subset of data samples and labels
        '''

        self.subset = subset

    def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
        '''
        Performs an optimization step on `model` using minibatch (`x`, `y`).

        Parameters
        ----------
        model: torch.nn.Module
            Model that is optimized locally
        x: torch.Tensor
            Data samples in the minibatch
        y: torch.Tensor
            Data labels in the minibatch
        optimizer: optim.Optimizer
            Gradient-based optimizer
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        '''

        prediction = model(x)

        loss = nn.functional.cross_entropy(prediction, y)
        loss.backward()

        # Clip before stepping so the update magnitude is bounded; fail loudly
        # on NaN/inf gradients instead of silently corrupting the model.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)

        optimizer.step()
        optimizer.zero_grad()

    def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
        '''
        Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch).

        Parameters
        ----------
        model: torch.nn.Module
            Model that is locally optimized
        n_steps: int
            Number of local SGD steps, i.e. number of minibatches
        step_size: float
            Step size or learning rate
        l2_penalty: float
            Weight of L2 (Tikhonov) regularization term
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        device: torch.device
            Accelerator to run the code

        Returns
        -------
        torch.nn.Module
            The locally optimized model
        '''

        # One minibatch per local step; max(1, ...) guards against n_steps
        # exceeding the subset size, which would otherwise request batch_size 0.
        loader = data.DataLoader(self.subset, batch_size = max(1, len(self.subset) // n_steps))
        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)

        model.train()

        # Stop after exactly n_steps minibatches: when len(subset) is not a
        # multiple of n_steps the loader yields one extra (partial) batch, and
        # iterating it unconditionally would perform n_steps + 1 local steps.
        for index, (x, y) in enumerate(loader):
            if index >= n_steps:
                break
            self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)

        return model

    def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
        '''
        Runs `n_epochs` stochastic gradient descent epochs on the local dataset.

        Parameters
        ----------
        model: torch.nn.Module
            Model that is locally optimized
        n_epochs: int
            Number of local epochs to pass over the entire local dataset
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Step size or learning rate
        l2_penalty: float
            Weight of L2 (Tikhonov) regularization term
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        device: torch.device
            Accelerator to run the code

        Returns
        -------
        torch.nn.Module
            The locally optimized model

        Note
        ----
        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
        '''

        loader = data.DataLoader(self.subset, batch_size = batch_size)
        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)

        model.train()

        for _ in range(n_epochs):
            for x, y in loader:
                self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)

        return model

    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
        '''
        Evaluates the `model` by computing the average sample loss and the number of correctly classified samples.

        Parameters
        ----------
        model: torch.nn.Module
            Model that is evaluated
        device: torch.device
            Accelerator to run the code

        Returns
        -------
        tuple[float, float]
            Tuple of average sample loss and count of correctly classified
            samples on the local dataset (the coordinator normalizes the
            count by the total number of samples to obtain an accuracy)
        '''

        # The whole local subset is evaluated as one batch.
        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
        x, y = next(iter(loader))
        x = x.to(device)
        y = y.to(device)

        model.eval()

        with torch.no_grad():
            prediction = model(x)
            loss = nn.functional.cross_entropy(prediction, y)
            # Count of correct predictions, not a ratio.
            correct = torch.sum(torch.argmax(prediction, dim = 1) == y)
            return loss.item(), correct.item()


class Coordinator:
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

    Note
    ----
    The agents locally update their models using the FedAvg optimization scheme.
    '''

    def __init__(
        self,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = Logger.default()
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients (required
            despite the None default, which exists only for signature
            compatibility)
        logger: Logger
            Logger instance to save progress during the simulation

        Raises
        ------
        ValueError
            If no weighting `scheme` is provided
        '''

        # Fail early with a clear message instead of an obscure
        # AttributeError on `scheme.weights()` below.
        if scheme is None:
            raise ValueError('a weighting scheme is required to aggregate client updates')

        # NOTE(review): the default `logger` is evaluated once at import time
        # and shared across Coordinator instances — confirm Logger.default()
        # is intended as a singleton.
        self.datasets = datasets
        self.model = model
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
        }
        self.weights = scheme.weights()
        self.logger = logger

    def run(self, n_iterations: int, n_steps: int = None, n_epochs: int = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Raises
        ------
        ValueError
            If neither `n_steps` nor `n_epochs` is provided

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`.
        '''

        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O.
        if n_steps is None and n_epochs is None:
            raise ValueError('either n_steps or n_epochs must be provided')

        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # Optional 1 / t schedule for the step size.
            step_size_updated = step_size / (iteration + 1) if step_size_diminishing else step_size
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
        '''
        Runs a single optimization round with FedAvg algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the updated global model on training and testing clients

        Returns
        -------
        dict[str, float]
            Dictionary of current round's metrics (empty when `evaluate` is False)
        '''

        # k equals the number of training agents, so after the shuffle the
        # slice below is a no-op: every client participates in every round.
        # The machinery is kept so partial participation only needs k < total.
        indices = list(range(0, len(self.agents['training'])))
        k = len(self.agents['training'])

        random.shuffle(indices)

        indices = indices[:k]
        participants = [ self.agents['training'][i] for i in indices ]
        weights = [ self.weights['training'][i] for i in indices ]

        initial_model = deepcopy(self.model)
        # Placeholder entries all alias initial_model; with full participation
        # each slot is overwritten by the client's own update below.
        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

        if n_steps is not None:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.optimize(deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
        else:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.multioptimize(deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        self.average(updates, weights = weights)

        if not evaluate:
            return {}

        return self.evaluate(iteration, device)

    def evaluate(self, iteration: int, device: torch.device):
        '''
        Computes average sample accuracy and loss of the global model on training and testing clients during current `iteration`.

        Parameters
        ----------
        iteration: int
            Current global round of the simulation
        device: torch.device
            Accelerator

        Returns
        -------
        dict
            Nested dictionary with 'loss' (weighted average) and 'accuracy'
            (fraction of correctly classified samples) for the 'training'
            and 'testing' client groups
        '''

        metrics = { 'training': { 'loss': 0, 'accuracy': 0 }, 'testing': { 'loss': 0, 'accuracy': 0 } }

        # Losses are aggregated with the scheme's weights; accuracies are
        # aggregated as raw correct-sample counts and normalized by the total
        # number of samples per group below.
        for agent, weight in tqdm(zip(self.agents['training'], self.weights['training']), total = len(self.agents['training']), desc = 'Evaluation on training agents (iteration {})'.format(iteration)):
            loss, accuracy = agent.evaluate(self.model, device = device)
            metrics['training']['loss'] = metrics['training']['loss'] + weight * loss
            metrics['training']['accuracy'] = metrics['training']['accuracy'] + accuracy

        for agent, weight in tqdm(zip(self.agents['testing'], self.weights['testing']), total = len(self.agents['testing']), desc = 'Evaluation on testing agents (iteration {})'.format(iteration)):
            loss, accuracy = agent.evaluate(self.model, device = device)
            metrics['testing']['loss'] = metrics['testing']['loss'] + weight * loss
            metrics['testing']['accuracy'] = metrics['testing']['accuracy'] + accuracy

        metrics['training']['loss'] /= sum(self.weights['training'])
        metrics['testing']['loss'] /= sum(self.weights['testing'])

        metrics['training']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['training'] ])
        metrics['testing']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['testing'] ])

        return metrics

    def average(self, updates: list[torch.nn.Module], weights: list[float]):
        '''
        Averages clients' `updates` weighted by aggregation `weights` into the shared model.

        Parameters
        ----------
        updates: list[torch.nn.Module]
            Locally updated clients' models
        weights: list[float]
            Aggregation weights (one for each client)
        '''

        total = sum(weights)
        # Weighted average of every state-dict entry, written back into the
        # shared model in one load_state_dict call.
        self.model.load_state_dict(OrderedDict([
            (
                name,
                torch.stack([ weight * update.state_dict()[name] for update, weight in zip(updates, weights) ]).sum(dim = 0) / total
            )
            for name in self.model.state_dict().keys()
        ]))
15class Agent: 16 ''' 17 An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset. 18 ''' 19 20 def __init__(self, subset: utils.FederatedSubset): 21 ''' 22 Initializes the agent with a local `subset` of data samples and labels. 23 24 Parameters 25 ---------- 26 subset: utils.FederatedSubset 27 Subset of data samples and labels 28 ''' 29 30 self.subset = subset 31 32 def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float): 33 ''' 34 Performs an optimization step on `model` using minibatch (`x`, `y`). 35 36 Parameters 37 ---------- 38 model: torch.nn.Module 39 Model that is optimized locally 40 x: torch.Tensor 41 Data samples in the minibatch 42 y: torch.Tensor 43 Data labels in the minibatch 44 optimizer: optim.Optimizer 45 Gradient-based optimizer 46 max_gradient_norm: float 47 Value used to clip the norm of the stochastic gradient 48 ''' 49 50 prediction = model(x) 51 52 loss = nn.functional.cross_entropy(prediction, y) 53 loss.backward() 54 55 torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True) 56 57 optimizer.step() 58 optimizer.zero_grad() 59 60 def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 61 ''' 62 Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch). 63 64 Parameters 65 ---------- 66 model: torch.nn.Module 67 Model that is locally optimized 68 n_steps: int 69 Number of local SGD steps, i.e. 
number of minibatches 70 step_size: float 71 Step size or learning rate 72 l2_penalty: float 73 Weight of L2 (Tikhonov) regularization term 74 max_gradient_norm: float 75 Value used to clip the norm of the stochastic gradient 76 device: torch.device 77 Accelerator to run the code 78 ''' 79 80 loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps) 81 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 82 83 model.train() 84 85 for x, y in loader: 86 self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 87 88 return model 89 90 def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 91 ''' 92 Runs `n_epochs` stochastic gradient descent epochs on the local dataset. 93 94 Parameters 95 ---------- 96 model: torch.nn.Module 97 Model that is locally optimized 98 n_epochs: int 99 Number of local epochs to pass over the entire local dataset 100 step_size: float 101 Step size or learning rate 102 l2_penalty: float 103 Weight of L2 (Tikhonov) regularization term 104 max_gradient_norm: float 105 Value used to clip the norm of the stochastic gradient 106 device: torch.device 107 Accelerator to run the code 108 109 Note 110 ---- 111 Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD. 
112 ''' 113 114 loader = data.DataLoader(self.subset, batch_size = batch_size) 115 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 116 117 model.train() 118 119 for _ in range(n_epochs): 120 for x, y in loader: 121 self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 122 123 return model 124 125 def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]: 126 ''' 127 Evaluate the `model` by computing the average sample loss and accuracy. 128 129 Parameters 130 ---------- 131 model: torch.nn.Module 132 Model that is locally optimized 133 device: torch.device 134 Accelerator to run the code 135 136 Returns 137 ------- 138 tuple[float, float] 139 Tuple of average sample loss and accuracy on the local dataset 140 ''' 141 142 loader = data.DataLoader(self.subset, batch_size = len(self.subset)) 143 x, y = next(iter(loader)) 144 x = x.to(device) 145 y = y.to(device) 146 147 model.eval() 148 149 with torch.no_grad(): 150 prediction = model(x) 151 loss = nn.functional.cross_entropy(prediction, y) 152 accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y) 153 return loss.item(), accuracy.item()
An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset.
20 def __init__(self, subset: utils.FederatedSubset): 21 ''' 22 Initializes the agent with a local `subset` of data samples and labels. 23 24 Parameters 25 ---------- 26 subset: utils.FederatedSubset 27 Subset of data samples and labels 28 ''' 29 30 self.subset = subset
Initializes the agent with a local subset of data samples and labels.
Parameters
- subset (utils.FederatedSubset): Subset of data samples and labels
32 def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float): 33 ''' 34 Performs an optimization step on `model` using minibatch (`x`, `y`). 35 36 Parameters 37 ---------- 38 model: torch.nn.Module 39 Model that is optimized locally 40 x: torch.Tensor 41 Data samples in the minibatch 42 y: torch.Tensor 43 Data labels in the minibatch 44 optimizer: optim.Optimizer 45 Gradient-based optimizer 46 max_gradient_norm: float 47 Value used to clip the norm of the stochastic gradient 48 ''' 49 50 prediction = model(x) 51 52 loss = nn.functional.cross_entropy(prediction, y) 53 loss.backward() 54 55 torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True) 56 57 optimizer.step() 58 optimizer.zero_grad()
Performs an optimization step on model using minibatch (x, y).
Parameters
- model (torch.nn.Module): Model that is optimized locally
- x (torch.Tensor): Data samples in the minibatch
- y (torch.Tensor): Data labels in the minibatch
- optimizer (optim.Optimizer): Gradient-based optimizer
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
60 def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 61 ''' 62 Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch). 63 64 Parameters 65 ---------- 66 model: torch.nn.Module 67 Model that is locally optimized 68 n_steps: int 69 Number of local SGD steps, i.e. number of minibatches 70 step_size: float 71 Step size or learning rate 72 l2_penalty: float 73 Weight of L2 (Tikhonov) regularization term 74 max_gradient_norm: float 75 Value used to clip the norm of the stochastic gradient 76 device: torch.device 77 Accelerator to run the code 78 ''' 79 80 loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps) 81 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 82 83 model.train() 84 85 for x, y in loader: 86 self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 87 88 return model
Runs n_steps stochastic gradient descent steps on the local dataset (one step for each minibatch).
Parameters
- model (torch.nn.Module): Model that is locally optimized
- n_steps (int): Number of local SGD steps, i.e. number of minibatches
- step_size (float): Step size or learning rate
- l2_penalty (float): Weight of L2 (Tikhonov) regularization term
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
- device (torch.device): Accelerator to run the code
90 def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 91 ''' 92 Runs `n_epochs` stochastic gradient descent epochs on the local dataset. 93 94 Parameters 95 ---------- 96 model: torch.nn.Module 97 Model that is locally optimized 98 n_epochs: int 99 Number of local epochs to pass over the entire local dataset 100 step_size: float 101 Step size or learning rate 102 l2_penalty: float 103 Weight of L2 (Tikhonov) regularization term 104 max_gradient_norm: float 105 Value used to clip the norm of the stochastic gradient 106 device: torch.device 107 Accelerator to run the code 108 109 Note 110 ---- 111 Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD. 112 ''' 113 114 loader = data.DataLoader(self.subset, batch_size = batch_size) 115 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 116 117 model.train() 118 119 for _ in range(n_epochs): 120 for x, y in loader: 121 self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 122 123 return model
Runs n_epochs stochastic gradient descent epochs on the local dataset.
Parameters
- model (torch.nn.Module): Model that is locally optimized
- n_epochs (int): Number of local epochs to pass over the entire local dataset
- step_size (float): Step size or learning rate
- l2_penalty (float): Weight of L2 (Tikhonov) regularization term
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
- device (torch.device): Accelerator to run the code
Note
Differently from optimize(...), each epoch corresponds to passing over the entire dataset using SGD.
125 def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]: 126 ''' 127 Evaluate the `model` by computing the average sample loss and accuracy. 128 129 Parameters 130 ---------- 131 model: torch.nn.Module 132 Model that is locally optimized 133 device: torch.device 134 Accelerator to run the code 135 136 Returns 137 ------- 138 tuple[float, float] 139 Tuple of average sample loss and accuracy on the local dataset 140 ''' 141 142 loader = data.DataLoader(self.subset, batch_size = len(self.subset)) 143 x, y = next(iter(loader)) 144 x = x.to(device) 145 y = y.to(device) 146 147 model.eval() 148 149 with torch.no_grad(): 150 prediction = model(x) 151 loss = nn.functional.cross_entropy(prediction, y) 152 accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y) 153 return loss.item(), accuracy.item()
Evaluate the model by computing the average sample loss and accuracy.
Parameters
- model (torch.nn.Module): Model that is locally optimized
- device (torch.device): Accelerator to run the code
Returns
- tuple[float, float]: Tuple of the average sample loss and the number of correctly classified samples on the local dataset (the count is normalized into an accuracy by the coordinator)
156class Coordinator: 157 ''' 158 This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients). 159 160 Note 161 ---- 162 The agents locally update their models using the FedAvg optimization scheme. 163 ''' 164 165 def __init__( 166 self, 167 model: torch.nn.Module, 168 datasets: dict[str, list[utils.FederatedSubset]], 169 scheme: WeightingScheme = None, 170 logger: Logger = Logger.default() 171 ): 172 ''' 173 Constructs the centralized coordinator, i.e. server, in the federated learning simulation. 174 175 Parameters 176 ---------- 177 model: torch.nn.Module 178 Initial shared model 179 datasets: dict[str, list[utils.FederatedSubset]] 180 Training clients' subsets ('training') and testing clients' subsets ('testing') 181 scheme: WeightingScheme 182 Aggregation scheme to weight local updates from clients 183 logger: Logger 184 Logger instance to save progress during the simulation 185 ''' 186 187 self.datasets = datasets 188 self.model = model 189 self.agents = { 190 group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items() 191 } 192 self.weights = scheme.weights() 193 self.logger = logger 194 195 def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 196 ''' 197 Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients. 
198 199 Parameters 200 ---------- 201 n_iterations: int 202 Number of global rounds 203 n_steps: int 204 Number of local SGD steps used for optimization on clients 205 n_epochs: int 206 Number of local epochs used for optimization on clients (mutually excludes n_steps) 207 batch_size: int 208 Number of samples in one SGD minibatch 209 step_size: float 210 Learning rate 211 step_size_diminishing: bool 212 This enables diminishing the step size linearly in time 213 l2_penalty: float 214 Weight of the L2 (Tikhonov) regularization used to penalize local models 215 max_gradient_norm: float 216 Value used to clip the norm of the stochastic gradient during local optimization 217 device: torch.device 218 Accelerator to run the code 219 evaluate: bool 220 Flag that enables evaluation of the update global model on training and testing clients 221 222 Note 223 ---- 224 Runs `n_iterations` times function `iterate(...)`. 225 ''' 226 227 assert n_steps is not None or n_epochs is not None 228 229 self.model = self.model.to(device) 230 self.model.compile() 231 232 for iteration in range(n_iterations): 233 step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1) 234 metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True) 235 236 self.logger.log({ 237 'step': iteration, 238 'loss.training': metrics['training']['loss'], 239 'loss.testing': metrics['testing']['loss'], 240 'accuracy.training': metrics['training']['accuracy'], 241 'accuracy.testing': metrics['testing']['accuracy'], 242 }) 243 244 print(iteration, metrics) 245 246 def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]: 247 ''' 248 Runs a single optimization round with FedAvg algorithm on all training clients. 
249 250 Parameters 251 ---------- 252 iteration: int 253 Current global round 254 n_steps: int 255 Number of local SGD steps used for optimization on clients 256 n_epochs: int 257 Number of local epochs used for optimization on clients (mutually excludes n_steps) 258 batch_size: int 259 Number of samples in one SGD minibatch 260 step_size: float 261 Learning rate 262 l2_penalty: float 263 Weight of the L2 (Tikhonov) regularization used to penalize local models 264 max_gradient_norm: float 265 Value used to clip the norm of the stochastic gradient during local optimization 266 device: torch.device 267 Accelerator to run the code 268 evaluate: bool 269 Flag that enables evaluation of the update global model on training and testing clients 270 271 Returns 272 ------- 273 dict[str, float] 274 Dictionary of current round's metrics 275 ''' 276 277 indices = list(range(0, len(self.agents['training']))) 278 k = len(self.agents['training']) 279 280 random.shuffle(indices) 281 282 indices = indices[:k] 283 participants = [ self.agents['training'][i] for i in indices ] 284 weights = [ self.weights['training'][i] for i in indices ] 285 286 initial_model = deepcopy(self.model) 287 updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ] 288 289 if n_steps is not None: 290 for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 291 updates[i] = participant.optimize(deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device) 292 else: 293 for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 294 updates[i] = participant.multioptimize(deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = 
max_gradient_norm, device = device) 295 296 self.average(updates, weights = weights) 297 298 if not evaluate: 299 return {} 300 301 return self.evaluate(iteration, device) 302 303 def evaluate(self, iteration: int, device: torch.device): 304 ''' 305 Computes average sample accuracy and loss of the global model on training and testing clients during current `iteration`. 306 307 Parameters 308 ---------- 309 iteration: int 310 Current global round of the simulation 311 device: torch.device 312 Accelerator 313 ''' 314 315 metrics = { 'training': { 'loss': 0, 'accuracy': 0 }, 'testing': { 'loss': 0, 'accuracy': 0 } } 316 317 for agent, weight in tqdm(zip(self.agents['training'], self.weights['training']), total = len(self.agents['training']), desc = 'Evaluation on training agents (iteration {})'.format(iteration)): 318 loss, accuracy = agent.evaluate(self.model, device = device) 319 metrics['training']['loss'] = metrics['training']['loss'] + weight * loss 320 metrics['training']['accuracy'] = metrics['training']['accuracy'] + accuracy 321 322 for agent, weight in tqdm(zip(self.agents['testing'], self.weights['testing']), total = len(self.agents['testing']), desc = 'Evaluation on testing agents (iteration {})'.format(iteration)): 323 loss, accuracy = agent.evaluate(self.model, device = device) 324 metrics['testing']['loss'] = metrics['testing']['loss'] + weight * loss 325 metrics['testing']['accuracy'] = metrics['testing']['accuracy'] + accuracy 326 327 metrics['training']['loss'] /= sum(self.weights['training']) 328 metrics['testing']['loss'] /= sum(self.weights['testing']) 329 330 metrics['training']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['training'] ]) 331 metrics['testing']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['testing'] ]) 332 333 return metrics 334 335 def average(self, updates: list[torch.nn.Module], weights: list[float]): 336 ''' 337 Averages clients' `updates` weighted by aggreation `weights` into the shared 
model. 338 339 Parameters 340 ---------- 341 updates: list[torch.nn.Module] 342 Locally updated clients' models 343 weights: list[float] 344 Aggregation weights (one for each client) 345 ''' 346 347 total = sum(weights) 348 self.model.load_state_dict(OrderedDict([ 349 ( 350 name, 351 torch.stack([ weight * update.state_dict()[name] for update, weight in zip(updates, weights) ]).sum(dim = 0) / total 352 ) 353 for name in self.model.state_dict().keys() 354 ]))
This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).
Note
The agents locally update their models using the FedAvg optimization scheme.
165 def __init__( 166 self, 167 model: torch.nn.Module, 168 datasets: dict[str, list[utils.FederatedSubset]], 169 scheme: WeightingScheme = None, 170 logger: Logger = Logger.default() 171 ): 172 ''' 173 Constructs the centralized coordinator, i.e. server, in the federated learning simulation. 174 175 Parameters 176 ---------- 177 model: torch.nn.Module 178 Initial shared model 179 datasets: dict[str, list[utils.FederatedSubset]] 180 Training clients' subsets ('training') and testing clients' subsets ('testing') 181 scheme: WeightingScheme 182 Aggregation scheme to weight local updates from clients 183 logger: Logger 184 Logger instance to save progress during the simulation 185 ''' 186 187 self.datasets = datasets 188 self.model = model 189 self.agents = { 190 group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items() 191 } 192 self.weights = scheme.weights() 193 self.logger = logger
Constructs the centralized coordinator, i.e. server, in the federated learning simulation.
Parameters
- model (torch.nn.Module): Initial shared model
- datasets (dict[str, list[utils.FederatedSubset]]): Training clients' subsets ('training') and testing clients' subsets ('testing')
- scheme (WeightingScheme): Aggregation scheme to weight local updates from clients
- logger (Logger): Logger instance to save progress during the simulation
195 def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 196 ''' 197 Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients. 198 199 Parameters 200 ---------- 201 n_iterations: int 202 Number of global rounds 203 n_steps: int 204 Number of local SGD steps used for optimization on clients 205 n_epochs: int 206 Number of local epochs used for optimization on clients (mutually excludes n_steps) 207 batch_size: int 208 Number of samples in one SGD minibatch 209 step_size: float 210 Learning rate 211 step_size_diminishing: bool 212 This enables diminishing the step size linearly in time 213 l2_penalty: float 214 Weight of the L2 (Tikhonov) regularization used to penalize local models 215 max_gradient_norm: float 216 Value used to clip the norm of the stochastic gradient during local optimization 217 device: torch.device 218 Accelerator to run the code 219 evaluate: bool 220 Flag that enables evaluation of the update global model on training and testing clients 221 222 Note 223 ---- 224 Runs `n_iterations` times function `iterate(...)`. 
225 ''' 226 227 assert n_steps is not None or n_epochs is not None 228 229 self.model = self.model.to(device) 230 self.model.compile() 231 232 for iteration in range(n_iterations): 233 step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1) 234 metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True) 235 236 self.logger.log({ 237 'step': iteration, 238 'loss.training': metrics['training']['loss'], 239 'loss.testing': metrics['testing']['loss'], 240 'accuracy.training': metrics['training']['accuracy'], 241 'accuracy.testing': metrics['testing']['accuracy'], 242 }) 243 244 print(iteration, metrics)
Runs n_iterations optimization (with algorithm FedAvg) and evaluation rounds on training clients.
Parameters
- n_iterations (int): Number of global rounds
- n_steps (int): Number of local SGD steps used for optimization on clients
- n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
- batch_size (int): Number of samples in one SGD minibatch
- step_size (float): Learning rate
- step_size_diminishing (bool): This enables diminishing the step size linearly in time
- l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
- device (torch.device): Accelerator to run the code
- evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients
Note
Runs n_iterations times function iterate(...).
246 def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]: 247 ''' 248 Runs a single optimization round with FedAvg algorithm on all training clients. 249 250 Parameters 251 ---------- 252 iteration: int 253 Current global round 254 n_steps: int 255 Number of local SGD steps used for optimization on clients 256 n_epochs: int 257 Number of local epochs used for optimization on clients (mutually excludes n_steps) 258 batch_size: int 259 Number of samples in one SGD minibatch 260 step_size: float 261 Learning rate 262 l2_penalty: float 263 Weight of the L2 (Tikhonov) regularization used to penalize local models 264 max_gradient_norm: float 265 Value used to clip the norm of the stochastic gradient during local optimization 266 device: torch.device 267 Accelerator to run the code 268 evaluate: bool 269 Flag that enables evaluation of the update global model on training and testing clients 270 271 Returns 272 ------- 273 dict[str, float] 274 Dictionary of current round's metrics 275 ''' 276 277 indices = list(range(0, len(self.agents['training']))) 278 k = len(self.agents['training']) 279 280 random.shuffle(indices) 281 282 indices = indices[:k] 283 participants = [ self.agents['training'][i] for i in indices ] 284 weights = [ self.weights['training'][i] for i in indices ] 285 286 initial_model = deepcopy(self.model) 287 updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ] 288 289 if n_steps is not None: 290 for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 291 updates[i] = participant.optimize(deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device) 292 else: 293 for i, participant in 
tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 294 updates[i] = participant.multioptimize(deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device) 295 296 self.average(updates, weights = weights) 297 298 if not evaluate: 299 return {} 300 301 return self.evaluate(iteration, device)
Runs a single optimization round with FedAvg algorithm on all training clients.
Parameters
- iteration (int): Current global round
- n_steps (int): Number of local SGD steps used for optimization on clients
- n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
- batch_size (int): Number of samples in one SGD minibatch
- step_size (float): Learning rate
- l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
- device (torch.device): Accelerator to run the code
- evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients
Returns
- dict[str, float]: Dictionary of current round's metrics
303 def evaluate(self, iteration: int, device: torch.device): 304 ''' 305 Computes average sample accuracy and loss of the global model on training and testing clients during current `iteration`. 306 307 Parameters 308 ---------- 309 iteration: int 310 Current global round of the simulation 311 device: torch.device 312 Accelerator 313 ''' 314 315 metrics = { 'training': { 'loss': 0, 'accuracy': 0 }, 'testing': { 'loss': 0, 'accuracy': 0 } } 316 317 for agent, weight in tqdm(zip(self.agents['training'], self.weights['training']), total = len(self.agents['training']), desc = 'Evaluation on training agents (iteration {})'.format(iteration)): 318 loss, accuracy = agent.evaluate(self.model, device = device) 319 metrics['training']['loss'] = metrics['training']['loss'] + weight * loss 320 metrics['training']['accuracy'] = metrics['training']['accuracy'] + accuracy 321 322 for agent, weight in tqdm(zip(self.agents['testing'], self.weights['testing']), total = len(self.agents['testing']), desc = 'Evaluation on testing agents (iteration {})'.format(iteration)): 323 loss, accuracy = agent.evaluate(self.model, device = device) 324 metrics['testing']['loss'] = metrics['testing']['loss'] + weight * loss 325 metrics['testing']['accuracy'] = metrics['testing']['accuracy'] + accuracy 326 327 metrics['training']['loss'] /= sum(self.weights['training']) 328 metrics['testing']['loss'] /= sum(self.weights['testing']) 329 330 metrics['training']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['training'] ]) 331 metrics['testing']['accuracy'] /= sum([ len(agent.subset) for agent in self.agents['testing'] ]) 332 333 return metrics
Computes average sample accuracy and loss of the global model on training and testing clients during current iteration.
Parameters
- iteration (int): Current global round of the simulation
- device (torch.device): Accelerator
335 def average(self, updates: list[torch.nn.Module], weights: list[float]): 336 ''' 337 Averages clients' `updates` weighted by aggreation `weights` into the shared model. 338 339 Parameters 340 ---------- 341 updates: list[torch.nn.Module] 342 Locally updated clients' models 343 weights: list[float] 344 Aggregation weights (one for each client) 345 ''' 346 347 total = sum(weights) 348 self.model.load_state_dict(OrderedDict([ 349 ( 350 name, 351 torch.stack([ weight * update.state_dict()[name] for update, weight in zip(updates, weights) ]).sum(dim = 0) / total 352 ) 353 for name in self.model.state_dict().keys() 354 ]))
Averages clients' updates weighted by aggregation weights into the shared model.
Parameters
- updates (list[torch.nn.Module]): Locally updated clients' models
- weights (list[float]): Aggregation weights (one for each client)