fedbox.optimization.fedprox
from ..datasets import utils
from . import fedavg
from .utils import WeightingScheme, Logger

from copy import deepcopy
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm

class Agent(fedavg.Agent):
    '''
    An agent (client) uses the FedProx scheme to optimize a shared model on its local subset.
    '''

    def __init__(self, subset: utils.FederatedSubset):
        '''
        Initializes the agent with a local `subset` of data samples and labels.

        Parameters
        ----------
        subset: utils.FederatedSubset
            Subset of data samples and labels
        '''

        self.subset = subset

    def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
        '''
        Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`.

        Parameters
        ----------
        alpha: float
            Weight of the proximal regularization term in FedProx
        initial: torch.nn.Module
            Model received at the beginning of the round from the server
        model: torch.nn.Module
            Model that is optimized locally
        x: torch.Tensor
            Data samples in the minibatch
        y: torch.Tensor
            Data labels in the minibatch
        optimizer: optim.Optimizer
            Gradient-based optimizer
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        '''

        prediction = model(x)

        # FedProx objective: task loss plus the proximal term anchoring the local
        # model to the model received from the server at the start of the round.
        loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)

        optimizer.step()
        optimizer.zero_grad()

    def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
        '''
        Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch).

        Parameters
        ----------
        alpha: float
            Weight of the proximal regularization term in FedProx
        initial: torch.nn.Module
            Model received at the beginning of the round from the server
        model: torch.nn.Module
            Model that is locally optimized
        n_steps: int
            Number of local SGD steps, i.e. number of minibatches
        step_size: float
            Step size or learning rate
        l2_penalty: float
            Weight of L2 (Tikhonov) regularization term
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        device: torch.device
            Accelerator to run the code
        '''

        # Split the local dataset into roughly `n_steps` minibatches. The guard
        # avoids a zero batch size (DataLoader error) when n_steps exceeds the
        # number of local samples; in that case one step per sample is run.
        loader = data.DataLoader(self.subset, batch_size = max(1, len(self.subset) // n_steps))
        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)

        model.train()

        for x, y in loader:
            self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)

        return model

    def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
        '''
        Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset.

        Parameters
        ----------
        alpha: float
            Weight of the proximal regularization term in FedProx
        initial: torch.nn.Module
            Model received at the beginning of the round from the server
        model: torch.nn.Module
            Model that is locally optimized
        n_epochs: int
            Number of local epochs to pass over the entire local dataset
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Step size or learning rate
        l2_penalty: float
            Weight of L2 (Tikhonov) regularization term
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient
        device: torch.device
            Accelerator to run the code

        Note
        ----
        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
        '''

        loader = data.DataLoader(self.subset, batch_size = batch_size)
        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)

        model.train()

        for _ in range(n_epochs):
            for x, y in loader:
                self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)

        return model

    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
        '''
        Evaluate the `model` by computing the average sample loss and accuracy.

        Parameters
        ----------
        model: torch.nn.Module
            Model that is locally optimized
        device: torch.device
            Accelerator to run the code

        Returns
        -------
        tuple[float, float]
            Tuple of average sample loss and accuracy on the local dataset
        '''

        # Evaluate on the whole local dataset in a single batch.
        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
        x, y = next(iter(loader))
        x = x.to(device)
        y = y.to(device)

        model.eval()

        with torch.no_grad():
            prediction = model(x)
            loss = nn.functional.cross_entropy(prediction, y)
            # Fraction of correctly classified samples; previously the raw count
            # was returned, contradicting the documented average accuracy.
            accuracy = torch.mean((torch.argmax(prediction, dim = 1) == y).float())
            return loss.item(), accuracy.item()

    def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor:
        '''
        Computes the proximal term as the squared difference between `initial` model and locally optimized `model`.

        Parameters
        ----------
        initial: torch.nn.Module
            Model received at the beginning of the round from the server
        model: torch.nn.Module
            Model that is locally optimized

        Returns
        -------
        torch.Tensor
            Squared difference between `initial` model and locally optimized `model`
        '''

        # Summing the per-parameter squared differences directly keeps the result
        # attached to the autograd graph. The previous implementation wrapped the
        # per-parameter sums in torch.tensor([...], requires_grad = True), which
        # copies the values and detaches them from the graph, so the proximal term
        # contributed no gradient to the model parameters. Iterating over
        # named_parameters (rather than state_dict keys) also skips buffers, for
        # which model.get_parameter(...) would raise.
        return sum(
            torch.square(parameter - initial.get_parameter(name).detach()).sum()
            for name, parameter in model.named_parameters()
        )


class Coordinator(fedavg.Coordinator):
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

    Note
    ----
    The agents locally update their models using the FedProx optimization scheme.
    '''

    def __init__(
        self,
        alpha: float,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = Logger.default()
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        alpha: float
            Weight of the proximal term used by training agents while running the optimization algorithm FedProx
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients (required, despite the None default)
        logger: Logger
            Logger instance to save progress during the simulation
        '''

        assert alpha > 0
        # The weights are computed eagerly below, so a scheme must be provided;
        # fail with a clear message instead of an AttributeError on None.
        assert scheme is not None, 'A WeightingScheme instance is required'

        self.alpha = alpha
        self.datasets = datasets
        self.model = model
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
        }
        self.weights = scheme.weights()
        self.logger = logger

    def run(self, n_iterations: int, n_steps: int = None, n_epochs: int = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`.
        '''

        assert n_steps is not None or n_epochs is not None

        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # Diminishing schedule: step_size / (t + 1) at global round t.
            step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1)
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
        '''
        Runs a single optimization round with FedProx algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the update global model on training and testing clients

        Returns
        -------
        dict[str, float]
            Dictionary of current round's metrics
        '''

        indices = list(range(0, len(self.agents['training'])))
        # Currently all training agents participate in every round (k equals the
        # number of agents); the shuffle-and-slice keeps the code ready for
        # partial participation.
        k = len(self.agents['training'])

        random.shuffle(indices)

        indices = indices[:k]
        participants = [ self.agents['training'][i] for i in indices ]
        weights = [ self.weights['training'][i] for i in indices ]

        initial_model = deepcopy(self.model)
        # Non-participating agents contribute the unchanged initial model to the
        # aggregation; participants overwrite their slot below.
        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

        if n_steps is not None:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
        else:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        self.average(updates, weights = weights)

        if not evaluate:
            return {}

        return self.evaluate(iteration, device)
14class Agent(fedavg.Agent): 15 ''' 16 An agent (client) uses the FedProx scheme to optimize a shared model on its local subset. 17 ''' 18 19 def __init__(self, subset: utils.FederatedSubset): 20 ''' 21 Initializes the agent with a local `subset` of data samples and labels. 22 23 Parameters 24 ---------- 25 subset: utils.FederatedSubset 26 Subset of data samples and labels 27 ''' 28 29 self.subset = subset 30 31 def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float): 32 ''' 33 Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`. 34 35 Parameters 36 ---------- 37 alpha: float 38 Weight of the proximal regularization term in FedProx 39 initial: torch.nn.Module 40 Model received at the beginning of the round from the server 41 model: torch.nn.Module 42 Model that is optimized locally 43 x: torch.Tensor 44 Data samples in the minibatch 45 y: torch.Tensor 46 Data labels in the minibatch 47 optimizer: optim.Optimizer 48 Gradient-based optimizer 49 max_gradient_norm: float 50 Value used to clip the norm of the stochastic gradient 51 ''' 52 53 prediction = model(x) 54 55 loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model) 56 57 loss.backward() 58 59 torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True) 60 61 optimizer.step() 62 optimizer.zero_grad() 63 64 def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 65 ''' 66 Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch). 
67 68 Parameters 69 ---------- 70 alpha: float 71 Weight of the proximal regularization term in FedProx 72 initial: torch.nn.Module 73 Model received at the beginning of the round from the server 74 model: torch.nn.Module 75 Model that is locally optimized 76 n_steps: int 77 Number of local SGD steps, i.e. number of minibatches 78 step_size: float 79 Step size or learning rate 80 l2_penalty: float 81 Weight of L2 (Tikhonov) regularization term 82 max_gradient_norm: float 83 Value used to clip the norm of the stochastic gradient 84 device: torch.device 85 Accelerator to run the code 86 ''' 87 88 loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps) 89 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 90 91 model.train() 92 93 for x, y in loader: 94 self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 95 96 return model 97 98 def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 99 ''' 100 Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset. 
101 102 Parameters 103 ---------- 104 alpha: float 105 Weight of the proximal regularization term in FedProx 106 initial: torch.nn.Module 107 Model received at the beginning of the round from the server 108 model: torch.nn.Module 109 Model that is locally optimized 110 n_epochs: int 111 Number of local epochs to pass over the entire local dataset 112 step_size: float 113 Step size or learning rate 114 l2_penalty: float 115 Weight of L2 (Tikhonov) regularization term 116 max_gradient_norm: float 117 Value used to clip the norm of the stochastic gradient 118 device: torch.device 119 Accelerator to run the code 120 121 Note 122 ---- 123 Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD. 124 ''' 125 126 loader = data.DataLoader(self.subset, batch_size = batch_size) 127 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 128 129 model.train() 130 131 for _ in range(n_epochs): 132 for x, y in loader: 133 self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 134 135 return model 136 137 def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]: 138 ''' 139 Evaluate the `model` by computing the average sample loss and accuracy. 
140 141 Parameters 142 ---------- 143 model: torch.nn.Module 144 Model that is locally optimized 145 device: torch.device 146 Accelerator to run the code 147 148 Returns 149 ------- 150 tuple[float, float] 151 Tuple of average sample loss and accuracy on the local dataset 152 ''' 153 154 loader = data.DataLoader(self.subset, batch_size = len(self.subset)) 155 x, y = next(iter(loader)) 156 x = x.to(device) 157 y = y.to(device) 158 159 model.eval() 160 161 with torch.no_grad(): 162 prediction = model(x) 163 loss = nn.functional.cross_entropy(prediction, y) 164 accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y) 165 return loss.item(), accuracy.item() 166 167 def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor: 168 ''' 169 Computes the `alpha`-weighted proximal term as squared difference between `initial` model and locally optimized `model`. 170 171 Parameters 172 ---------- 173 initial: torch.nn.Module 174 Model received at the beginning of the round from the server 175 model: torch.nn.Module 176 Model that is locally optimized 177 178 Returns 179 ------- 180 torch.Tensor 181 Squared difference between `initial` model and locally optimized `model` 182 ''' 183 184 return torch.sum(torch.tensor([ 185 torch.square(model.get_parameter(name) - initial.state_dict()[name].detach()).sum() 186 for name in model.state_dict().keys() 187 ], requires_grad = True))
An agent (client) uses the FedProx scheme to optimize a shared model on its local subset.
19 def __init__(self, subset: utils.FederatedSubset): 20 ''' 21 Initializes the agent with a local `subset` of data samples and labels. 22 23 Parameters 24 ---------- 25 subset: utils.FederatedSubset 26 Subset of data samples and labels 27 ''' 28 29 self.subset = subset
Initializes the agent with a local subset of data samples and labels.
Parameters
- subset (utils.FederatedSubset): Subset of data samples and labels
31 def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float): 32 ''' 33 Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`. 34 35 Parameters 36 ---------- 37 alpha: float 38 Weight of the proximal regularization term in FedProx 39 initial: torch.nn.Module 40 Model received at the beginning of the round from the server 41 model: torch.nn.Module 42 Model that is optimized locally 43 x: torch.Tensor 44 Data samples in the minibatch 45 y: torch.Tensor 46 Data labels in the minibatch 47 optimizer: optim.Optimizer 48 Gradient-based optimizer 49 max_gradient_norm: float 50 Value used to clip the norm of the stochastic gradient 51 ''' 52 53 prediction = model(x) 54 55 loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model) 56 57 loss.backward() 58 59 torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True) 60 61 optimizer.step() 62 optimizer.zero_grad()
Performs an optimization step on model using minibatch (x, y) accounting for the proximal term weighted by alpha.
Parameters
- alpha (float): Weight of the proximal regularization term in FedProx
- initial (torch.nn.Module): Model received at the beginning of the round from the server
- model (torch.nn.Module): Model that is optimized locally
- x (torch.Tensor): Data samples in the minibatch
- y (torch.Tensor): Data labels in the minibatch
- optimizer (optim.Optimizer): Gradient-based optimizer
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
64 def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 65 ''' 66 Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch). 67 68 Parameters 69 ---------- 70 alpha: float 71 Weight of the proximal regularization term in FedProx 72 initial: torch.nn.Module 73 Model received at the beginning of the round from the server 74 model: torch.nn.Module 75 Model that is locally optimized 76 n_steps: int 77 Number of local SGD steps, i.e. number of minibatches 78 step_size: float 79 Step size or learning rate 80 l2_penalty: float 81 Weight of L2 (Tikhonov) regularization term 82 max_gradient_norm: float 83 Value used to clip the norm of the stochastic gradient 84 device: torch.device 85 Accelerator to run the code 86 ''' 87 88 loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps) 89 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 90 91 model.train() 92 93 for x, y in loader: 94 self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 95 96 return model
Runs n_steps stochastic gradient descent steps including the alpha-weighted proximal term on the local dataset (one step for each minibatch).
Parameters
- alpha (float): Weight of the proximal regularization term in FedProx
- initial (torch.nn.Module): Model received at the beginning of the round from the server
- model (torch.nn.Module): Model that is locally optimized
- n_steps (int): Number of local SGD steps, i.e. number of minibatches
- step_size (float): Step size or learning rate
- l2_penalty (float): Weight of L2 (Tikhonov) regularization term
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
- device (torch.device): Accelerator to run the code
98 def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device): 99 ''' 100 Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset. 101 102 Parameters 103 ---------- 104 alpha: float 105 Weight of the proximal regularization term in FedProx 106 initial: torch.nn.Module 107 Model received at the beginning of the round from the server 108 model: torch.nn.Module 109 Model that is locally optimized 110 n_epochs: int 111 Number of local epochs to pass over the entire local dataset 112 step_size: float 113 Step size or learning rate 114 l2_penalty: float 115 Weight of L2 (Tikhonov) regularization term 116 max_gradient_norm: float 117 Value used to clip the norm of the stochastic gradient 118 device: torch.device 119 Accelerator to run the code 120 121 Note 122 ---- 123 Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD. 124 ''' 125 126 loader = data.DataLoader(self.subset, batch_size = batch_size) 127 optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty) 128 129 model.train() 130 131 for _ in range(n_epochs): 132 for x, y in loader: 133 self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm) 134 135 return model
Runs n_epochs stochastic gradient descent epochs including the alpha-weighted proximal term on the local dataset.
Parameters
- alpha (float): Weight of the proximal regularization term in FedProx
- initial (torch.nn.Module): Model received at the beginning of the round from the server
- model (torch.nn.Module): Model that is locally optimized
- n_epochs (int): Number of local epochs to pass over the entire local dataset
- batch_size (int): Number of samples in one SGD minibatch
- step_size (float): Step size or learning rate
- l2_penalty (float): Weight of L2 (Tikhonov) regularization term
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
- device (torch.device): Accelerator to run the code
Note
Differently from optimize(...), each epoch corresponds to passing over the entire dataset using SGD.
137 def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]: 138 ''' 139 Evaluate the `model` by computing the average sample loss and accuracy. 140 141 Parameters 142 ---------- 143 model: torch.nn.Module 144 Model that is locally optimized 145 device: torch.device 146 Accelerator to run the code 147 148 Returns 149 ------- 150 tuple[float, float] 151 Tuple of average sample loss and accuracy on the local dataset 152 ''' 153 154 loader = data.DataLoader(self.subset, batch_size = len(self.subset)) 155 x, y = next(iter(loader)) 156 x = x.to(device) 157 y = y.to(device) 158 159 model.eval() 160 161 with torch.no_grad(): 162 prediction = model(x) 163 loss = nn.functional.cross_entropy(prediction, y) 164 accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y) 165 return loss.item(), accuracy.item()
Evaluate the model by computing the average sample loss and accuracy.
Parameters
- model (torch.nn.Module): Model that is locally optimized
- device (torch.device): Accelerator to run the code
Returns
- tuple[float, float]: Tuple of average sample loss and accuracy on the local dataset
167 def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor: 168 ''' 169 Computes the `alpha`-weighted proximal term as squared difference between `initial` model and locally optimized `model`. 170 171 Parameters 172 ---------- 173 initial: torch.nn.Module 174 Model received at the beginning of the round from the server 175 model: torch.nn.Module 176 Model that is locally optimized 177 178 Returns 179 ------- 180 torch.Tensor 181 Squared difference between `initial` model and locally optimized `model` 182 ''' 183 184 return torch.sum(torch.tensor([ 185 torch.square(model.get_parameter(name) - initial.state_dict()[name].detach()).sum() 186 for name in model.state_dict().keys() 187 ], requires_grad = True))
Computes the alpha-weighted proximal term as squared difference between initial model and locally optimized model.
Parameters
- initial (torch.nn.Module): Model received at the beginning of the round from the server
- model (torch.nn.Module): Model that is locally optimized
Returns
- torch.Tensor: Squared difference between `initial` model and locally optimized `model`
191class Coordinator(fedavg.Coordinator): 192 ''' 193 This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients). 194 195 Note 196 ---- 197 The agents locally update their models using the FedProx optimization scheme. 198 ''' 199 200 def __init__( 201 self, 202 alpha: float, 203 model: torch.nn.Module, 204 datasets: dict[str, list[utils.FederatedSubset]], 205 scheme: WeightingScheme = None, 206 logger: Logger = Logger.default() 207 ): 208 ''' 209 Constructs the centralized coordinator, i.e. server, in the federated learning simulation. 210 211 Parameters 212 ---------- 213 alpha: float 214 Weight of the proximal term used by training agents while running the optimization algorithm FedProx 215 model: torch.nn.Module 216 Initial shared model 217 datasets: dict[str, list[utils.FederatedSubset]] 218 Training clients' subsets ('training') and testing clients' subsets ('testing') 219 scheme: WeightingScheme 220 Aggregation scheme to weight local updates from clients 221 logger: Logger 222 Logger instance to save progress during the simulation 223 ''' 224 225 assert alpha > 0 226 227 self.alpha = alpha 228 self.datasets = datasets 229 self.model = model 230 self.agents = { 231 group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items() 232 } 233 self.weights = scheme.weights() 234 self.logger = logger 235 236 def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 237 ''' 238 Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients. 
239 240 Parameters 241 ---------- 242 n_iterations: int 243 Number of global rounds 244 n_steps: int 245 Number of local SGD steps used for optimization on clients 246 n_epochs: int 247 Number of local epochs used for optimization on clients (mutually excludes n_steps) 248 batch_size: int 249 Number of samples in one SGD minibatch 250 step_size: float 251 Learning rate 252 step_size_diminishing: bool 253 This enables diminishing the step size linearly in time 254 l2_penalty: float 255 Weight of the L2 (Tikhonov) regularization used to penalize local models 256 max_gradient_norm: float 257 Value used to clip the norm of the stochastic gradient during local optimization 258 device: torch.device 259 Accelerator to run the code 260 evaluate: bool 261 Flag that enables evaluation of the update global model on training and testing clients 262 263 Note 264 ---- 265 Runs `n_iterations` times function `iterate(...)`. 266 ''' 267 268 assert n_steps is not None or n_epochs is not None 269 270 self.model = self.model.to(device) 271 self.model.compile() 272 273 for iteration in range(n_iterations): 274 step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1) 275 metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True) 276 277 self.logger.log({ 278 'step': iteration, 279 'loss.training': metrics['training']['loss'], 280 'loss.testing': metrics['testing']['loss'], 281 'accuracy.training': metrics['training']['accuracy'], 282 'accuracy.testing': metrics['testing']['accuracy'], 283 }) 284 285 print(iteration, metrics) 286 287 def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]: 288 ''' 289 Runs a single optimization round with FedProx algorithm on all training clients. 
290 291 Parameters 292 ---------- 293 iteration: int 294 Current global round 295 n_steps: int 296 Number of local SGD steps used for optimization on clients 297 n_epochs: int 298 Number of local epochs used for optimization on clients (mutually excludes n_steps) 299 batch_size: int 300 Number of samples in one SGD minibatch 301 step_size: float 302 Learning rate 303 l2_penalty: float 304 Weight of the L2 (Tikhonov) regularization used to penalize local models 305 max_gradient_norm: float 306 Value used to clip the norm of the stochastic gradient during local optimization 307 device: torch.device 308 Accelerator to run the code 309 evaluate: bool 310 Flag that enables evaluation of the update global model on training and testing clients 311 312 Returns 313 ------- 314 dict[str, float] 315 Dictionary of current round's metrics 316 ''' 317 318 indices = list(range(0, len(self.agents['training']))) 319 k = len(self.agents['training']) 320 321 random.shuffle(indices) 322 323 indices = indices[:k] 324 participants = [ self.agents['training'][i] for i in indices ] 325 weights = [ self.weights['training'][i] for i in indices ] 326 327 initial_model = deepcopy(self.model) 328 updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ] 329 330 if n_steps is not None: 331 for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 332 updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device) 333 else: 334 for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)): 335 updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = 
step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device) 336 337 self.average(updates, weights = weights) 338 339 if not evaluate: 340 return {} 341 342 return self.evaluate(iteration, device)
This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).
Note
The agents locally update their models using the FedProx optimization scheme.
def __init__(
    self,
    alpha: float,
    model: torch.nn.Module,
    datasets: dict[str, list[utils.FederatedSubset]],
    scheme: WeightingScheme = None,
    logger: Logger = Logger.default()
):
    '''
    Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

    Parameters
    ----------
    alpha: float
        Weight of the proximal term used by training agents while running the optimization algorithm FedProx (must be positive)
    model: torch.nn.Module
        Initial shared model
    datasets: dict[str, list[utils.FederatedSubset]]
        Training clients' subsets ('training') and testing clients' subsets ('testing')
    scheme: WeightingScheme
        Aggregation scheme to weight local updates from clients; required in practice despite the None default
    logger: Logger
        Logger instance to save progress during the simulation

    Raises
    ------
    ValueError
        If `alpha` is not positive or `scheme` is None
    '''

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if alpha <= 0:
        raise ValueError('alpha must be positive, got {}'.format(alpha))

    if scheme is None:
        # The None default would otherwise crash below with an opaque
        # AttributeError on `scheme.weights()` — fail early with a clear message.
        raise ValueError('a WeightingScheme instance is required')

    # NOTE(review): `logger = Logger.default()` is evaluated once at definition
    # time, so every server constructed without an explicit logger shares the
    # same Logger instance — confirm this sharing is intended.
    self.alpha = alpha
    self.datasets = datasets
    self.model = model
    self.agents = {
        group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
    }
    self.weights = scheme.weights()
    self.logger = logger
Constructs the centralized coordinator, i.e. server, in the federated learning simulation.
Parameters
- alpha (float): Weight of the proximal term used by training agents while running the optimization algorithm FedProx
- model (torch.nn.Module): Initial shared model
- datasets (dict[str, list[utils.FederatedSubset]]): Training clients' subsets ('training') and testing clients' subsets ('testing')
- scheme (WeightingScheme): Aggregation scheme to weight local updates from clients
- logger (Logger): Logger instance to save progress during the simulation
236 def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 237 ''' 238 Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients. 239 240 Parameters 241 ---------- 242 n_iterations: int 243 Number of global rounds 244 n_steps: int 245 Number of local SGD steps used for optimization on clients 246 n_epochs: int 247 Number of local epochs used for optimization on clients (mutually excludes n_steps) 248 batch_size: int 249 Number of samples in one SGD minibatch 250 step_size: float 251 Learning rate 252 step_size_diminishing: bool 253 This enables diminishing the step size linearly in time 254 l2_penalty: float 255 Weight of the L2 (Tikhonov) regularization used to penalize local models 256 max_gradient_norm: float 257 Value used to clip the norm of the stochastic gradient during local optimization 258 device: torch.device 259 Accelerator to run the code 260 evaluate: bool 261 Flag that enables evaluation of the update global model on training and testing clients 262 263 Note 264 ---- 265 Runs `n_iterations` times function `iterate(...)`. 
266 ''' 267 268 assert n_steps is not None or n_epochs is not None 269 270 self.model = self.model.to(device) 271 self.model.compile() 272 273 for iteration in range(n_iterations): 274 step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1) 275 metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True) 276 277 self.logger.log({ 278 'step': iteration, 279 'loss.training': metrics['training']['loss'], 280 'loss.testing': metrics['testing']['loss'], 281 'accuracy.training': metrics['training']['accuracy'], 282 'accuracy.testing': metrics['testing']['accuracy'], 283 }) 284 285 print(iteration, metrics)
Runs n_iterations rounds of optimization (with the FedProx algorithm) and evaluation on training clients.
Parameters
- n_iterations (int): Number of global rounds
- n_steps (int): Number of local SGD steps used for optimization on clients
- n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
- batch_size (int): Number of samples in one SGD minibatch
- step_size (float): Learning rate
- step_size_diminishing (bool): This enables diminishing the step size linearly in time
- l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
- device (torch.device): Accelerator to run the code
- evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients
Note
Runs n_iterations times function iterate(...).
def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, dict[str, float]]:
    '''
    Runs a single optimization round with FedProx algorithm on all training clients.

    Parameters
    ----------
    iteration: int
        Current global round
    n_steps: int
        Number of local SGD steps used for optimization on clients
    n_epochs: int
        Number of local epochs used for optimization on clients (mutually excludes n_steps)
    batch_size: int
        Number of samples in one SGD minibatch
    step_size: float
        Learning rate
    l2_penalty: float
        Weight of the L2 (Tikhonov) regularization used to penalize local models
    max_gradient_norm: float
        Value used to clip the norm of the stochastic gradient during local optimization
    device: torch.device
        Accelerator to run the code
    evaluate: bool
        Flag that enables evaluation of the updated global model on training and testing clients

    Returns
    -------
    dict[str, dict[str, float]]
        Per-group ('training'/'testing') metric dictionaries for the current round
        (see how `run` indexes the result); empty when `evaluate` is False
    '''

    # NOTE(review): `k` equals the total number of training agents, so the shuffle +
    # slice below keeps every client — this looks like vestigial client-subsampling
    # code. The shuffle is retained so the `random` RNG stream is consumed exactly
    # as before (reproducibility of downstream randomness).
    indices = list(range(0, len(self.agents['training'])))
    k = len(self.agents['training'])

    random.shuffle(indices)

    indices = indices[:k]
    participants = [ self.agents['training'][i] for i in indices ]

    initial_model = deepcopy(self.model)
    # Placeholder references to the round-initial model; every slot is overwritten
    # below because all clients participate (the entries alias a single object).
    updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

    # Fix: weights must be aligned index-wise with `updates` (agent order), not with
    # the shuffled `indices` order — otherwise non-uniform weights are paired with
    # the wrong clients' updates during aggregation.
    weights = [ self.weights['training'][i] for i in range(len(updates)) ]

    if n_steps is not None:
        for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
            updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
    else:
        for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
            updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

    self.average(updates, weights = weights)

    if not evaluate:
        return {}

    return self.evaluate(iteration, device)
Runs a single optimization round with FedProx algorithm on all training clients.
Parameters
- iteration (int): Current global round
- n_steps (int): Number of local SGD steps used for optimization on clients
- n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
- batch_size (int): Number of samples in one SGD minibatch
- step_size (float): Learning rate
- l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
- max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
- device (torch.device): Accelerator to run the code
- evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients
Returns
- dict[str, float]: Dictionary of current round's metrics