fedbox.optimization.fedavg

  1from ..datasets import utils
  2from .utils import WeightingScheme, Logger
  3
  4from collections import OrderedDict
  5from copy import deepcopy
  6import random
  7import torch
  8import torch.nn as nn
  9import torch.optim as optim
 10import torch.utils.data as data
 11from tqdm import tqdm
 12
 13
 14class Agent:
 15    '''
 16    An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset.
 17    '''
 18
 19    def __init__(self, subset: utils.FederatedSubset):
 20        '''
 21        Initializes the agent with a local `subset` of data samples and labels.
 22
 23        Parameters
 24        ----------
 25        subset: utils.FederatedSubset
 26            Subset of data samples and labels
 27        '''
 28        
 29        self.subset = subset
 30
 31    def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
 32        '''
 33        Performs an optimization step on `model` using minibatch (`x`, `y`).
 34
 35        Parameters
 36        ----------
 37        model: torch.nn.Module
 38            Model that is optimized locally
 39        x: torch.Tensor
 40            Data samples in the minibatch
 41        y: torch.Tensor
 42            Data labels in the minibatch
 43        optimizer: optim.Optimizer
 44            Gradient-based optimizer
 45        max_gradient_norm: float
 46            Value used to clip the norm of the stochastic gradient
 47        '''
 48        
 49        prediction = model(x)
 50        
 51        loss = nn.functional.cross_entropy(prediction, y)
 52        loss.backward()
 53
 54        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
 55
 56        optimizer.step()
 57        optimizer.zero_grad()
 58
 59    def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 60        '''
 61        Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch).
 62
 63        Parameters
 64        ----------
 65        model: torch.nn.Module
 66            Model that is locally optimized
 67        n_steps: int
 68            Number of local SGD steps, i.e. number of minibatches
 69        step_size: float
 70            Step size or learning rate
 71        l2_penalty: float
 72            Weight of L2 (Tikhonov) regularization term
 73        max_gradient_norm: float
 74            Value used to clip the norm of the stochastic gradient
 75        device: torch.device
 76            Accelerator to run the code
 77        '''
 78        
 79        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
 80        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
 81
 82        model.train()
 83
 84        for x, y in loader:
 85            self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
 86
 87        return model
 88
 89    def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 90        '''
 91        Runs `n_epochs` stochastic gradient descent epochs on the local dataset.
 92
 93        Parameters
 94        ----------
 95        model: torch.nn.Module
 96            Model that is locally optimized
 97        n_epochs: int
 98            Number of local epochs to pass over the entire local dataset
 99        step_size: float
100            Step size or learning rate
101        l2_penalty: float
102            Weight of L2 (Tikhonov) regularization term
103        max_gradient_norm: float
104            Value used to clip the norm of the stochastic gradient
105        device: torch.device
106            Accelerator to run the code
107
108        Note
109        ----
110        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
111        '''
112
113        loader = data.DataLoader(self.subset, batch_size = batch_size)
114        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
115
116        model.train()
117
118        for _ in range(n_epochs):
119            for x, y in loader:
120                self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
121
122        return model
123
124    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
125        '''
126        Evaluate the `model` by computing the average sample loss and accuracy. 
127
128        Parameters
129        ----------
130        model: torch.nn.Module
131            Model that is locally optimized
132        device: torch.device
133            Accelerator to run the code
134
135        Returns
136        -------
137        tuple[float, float]
138            Tuple of average sample loss and accuracy on the local dataset
139        '''
140        
141        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
142        x, y = next(iter(loader))
143        x = x.to(device)
144        y = y.to(device)
145        
146        model.eval()
147
148        with torch.no_grad():
149            prediction = model(x)
150            loss = nn.functional.cross_entropy(prediction, y)
151            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
152            return loss.item(), accuracy.item()
153
154
class Coordinator:
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).
    
    Note
    ----
    The agents locally update their models using the FedAvg optimization scheme.
    '''

    def __init__(
        self,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = None
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients; when omitted,
            clients are weighted by their local dataset sizes (standard FedAvg)
        logger: Logger
            Logger instance to save progress during the simulation (defaults to `Logger.default()`)
        '''
        
        self.datasets = datasets
        self.model = model
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items() 
        }
        # Previously a `scheme = None` default crashed with AttributeError; fall
        # back to weighting each client by its local dataset size, the canonical
        # FedAvg aggregation weighting.
        if scheme is not None:
            self.weights = scheme.weights()
        else:
            self.weights = {
                group: [ float(len(subset)) for subset in dataset ] for group, dataset in datasets.items()
            }
        # Resolve the logger default lazily instead of evaluating `Logger.default()`
        # once at class-definition time (and sharing that instance across objects).
        self.logger = logger if logger is not None else Logger.default()

    def run(self, n_iterations: int, n_steps: int = None, n_epochs: int = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`.
        '''
        
        assert n_steps is not None or n_epochs is not None, 'Provide either n_steps or n_epochs'
        
        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # Linearly diminishing schedule: step_size / (t + 1) at round t.
            step_size_updated = step_size / (iteration + 1) if step_size_diminishing else step_size
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)
            
            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
        '''
        Runs a single optimization round with FedAvg algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the updated global model on training and testing clients

        Returns
        -------
        dict[str, float]
            Dictionary of current round's metrics (empty when `evaluate` is False)
        '''

        # Client-subsampling hook: with k equal to the number of training agents,
        # every client participates in every round.
        k = len(self.agents['training'])
        indices = list(range(len(self.agents['training'])))
        
        random.shuffle(indices)
        
        indices = indices[:k]
        participants = [ self.agents['training'][i] for i in indices ]

        initial_model = deepcopy(self.model)
        # Non-participants implicitly "send back" the unchanged global model.
        # Aliasing the same `initial_model` object is safe: it is only read during
        # averaging, never optimized in place.
        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

        for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
            local_model = deepcopy(initial_model)
            if n_steps is not None:
                updates[i] = participant.optimize(local_model, n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
            else:
                updates[i] = participant.multioptimize(local_model, n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        # `updates` is ordered by the agents' canonical indices, so the aggregation
        # weights must be in the SAME canonical order. (Previously the weights were
        # gathered in the *shuffled* order of `indices`, pairing weight j with the
        # wrong client's update whenever weights are non-uniform.)
        self.average(updates, weights = list(self.weights['training']))

        if not evaluate:
            return {}
        
        return self.evaluate(iteration, device)
        
    def evaluate(self, iteration: int, device: torch.device):
        '''
        Computes average sample accuracy and loss of the global model on training and testing clients during current `iteration`.

        Parameters
        ----------
        iteration: int
            Current global round of the simulation
        device: torch.device
            Accelerator

        Returns
        -------
        dict
            Nested dictionary with 'loss' and 'accuracy' for both the 'training' and 'testing' groups
        '''
        
        metrics = { 'training': { 'loss': 0, 'accuracy': 0 }, 'testing': { 'loss': 0, 'accuracy': 0 } }

        # Losses are accumulated with aggregation weights (normalized below by the
        # weight total); accuracies are accumulated as raw counts of correct
        # predictions (Agent.evaluate returns a count) and normalized by the total
        # number of samples in the group.
        for agent, weight in tqdm(zip(self.agents['training'], self.weights['training']), total = len(self.agents['training']), desc = 'Evaluation on training agents (iteration {})'.format(iteration)):
            loss, accuracy = agent.evaluate(self.model, device = device)
            metrics['training']['loss'] = metrics['training']['loss'] + weight * loss
            metrics['training']['accuracy'] = metrics['training']['accuracy'] + accuracy

        for agent, weight in tqdm(zip(self.agents['testing'], self.weights['testing']), total = len(self.agents['testing']), desc = 'Evaluation on testing agents (iteration {})'.format(iteration)):
            loss, accuracy = agent.evaluate(self.model, device = device)
            metrics['testing']['loss'] = metrics['testing']['loss'] + weight * loss
            metrics['testing']['accuracy'] = metrics['testing']['accuracy'] + accuracy

        metrics['training']['loss'] /= sum(self.weights['training'])
        metrics['testing']['loss'] /= sum(self.weights['testing'])
        
        metrics['training']['accuracy'] /= sum(len(agent.subset) for agent in self.agents['training'])
        metrics['testing']['accuracy'] /= sum(len(agent.subset) for agent in self.agents['testing'])

        return metrics

    def average(self, updates: list[torch.nn.Module], weights: list[float]):
        '''
        Averages clients' `updates` weighted by aggregation `weights` into the shared model.

        Parameters
        ----------
        updates: list[torch.nn.Module]
            Locally updated clients' models, in the same order as `weights`
        weights: list[float]
            Aggregation weights (one for each client)
        '''

        total = sum(weights)
        # Materialize each client's state dict ONCE instead of rebuilding it for
        # every parameter name (state_dict() walks all parameters per call).
        states = [ update.state_dict() for update in updates ]

        # NOTE(review): non-float buffers (e.g. BatchNorm's num_batches_tracked)
        # are averaged like floats and cast back on load, matching the original
        # behavior — confirm this is intended for the models in use.
        self.model.load_state_dict(OrderedDict([
            (
                name, 
                torch.stack([ weight * state[name] for state, weight in zip(states, weights) ]).sum(dim = 0) / total
            )
            for name in self.model.state_dict().keys()
        ]))
class Agent:
 15class Agent:
 16    '''
 17    An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset.
 18    '''
 19
 20    def __init__(self, subset: utils.FederatedSubset):
 21        '''
 22        Initializes the agent with a local `subset` of data samples and labels.
 23
 24        Parameters
 25        ----------
 26        subset: utils.FederatedSubset
 27            Subset of data samples and labels
 28        '''
 29        
 30        self.subset = subset
 31
 32    def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
 33        '''
 34        Performs an optimization step on `model` using minibatch (`x`, `y`).
 35
 36        Parameters
 37        ----------
 38        model: torch.nn.Module
 39            Model that is optimized locally
 40        x: torch.Tensor
 41            Data samples in the minibatch
 42        y: torch.Tensor
 43            Data labels in the minibatch
 44        optimizer: optim.Optimizer
 45            Gradient-based optimizer
 46        max_gradient_norm: float
 47            Value used to clip the norm of the stochastic gradient
 48        '''
 49        
 50        prediction = model(x)
 51        
 52        loss = nn.functional.cross_entropy(prediction, y)
 53        loss.backward()
 54
 55        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
 56
 57        optimizer.step()
 58        optimizer.zero_grad()
 59
 60    def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 61        '''
 62        Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch).
 63
 64        Parameters
 65        ----------
 66        model: torch.nn.Module
 67            Model that is locally optimized
 68        n_steps: int
 69            Number of local SGD steps, i.e. number of minibatches
 70        step_size: float
 71            Step size or learning rate
 72        l2_penalty: float
 73            Weight of L2 (Tikhonov) regularization term
 74        max_gradient_norm: float
 75            Value used to clip the norm of the stochastic gradient
 76        device: torch.device
 77            Accelerator to run the code
 78        '''
 79        
 80        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
 81        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
 82
 83        model.train()
 84
 85        for x, y in loader:
 86            self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
 87
 88        return model
 89
 90    def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 91        '''
 92        Runs `n_epochs` stochastic gradient descent epochs on the local dataset.
 93
 94        Parameters
 95        ----------
 96        model: torch.nn.Module
 97            Model that is locally optimized
 98        n_epochs: int
 99            Number of local epochs to pass over the entire local dataset
100        step_size: float
101            Step size or learning rate
102        l2_penalty: float
103            Weight of L2 (Tikhonov) regularization term
104        max_gradient_norm: float
105            Value used to clip the norm of the stochastic gradient
106        device: torch.device
107            Accelerator to run the code
108
109        Note
110        ----
111        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
112        '''
113
114        loader = data.DataLoader(self.subset, batch_size = batch_size)
115        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
116
117        model.train()
118
119        for _ in range(n_epochs):
120            for x, y in loader:
121                self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
122
123        return model
124
125    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
126        '''
127        Evaluate the `model` by computing the average sample loss and accuracy. 
128
129        Parameters
130        ----------
131        model: torch.nn.Module
132            Model that is locally optimized
133        device: torch.device
134            Accelerator to run the code
135
136        Returns
137        -------
138        tuple[float, float]
139            Tuple of average sample loss and accuracy on the local dataset
140        '''
141        
142        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
143        x, y = next(iter(loader))
144        x = x.to(device)
145        y = y.to(device)
146        
147        model.eval()
148
149        with torch.no_grad():
150            prediction = model(x)
151            loss = nn.functional.cross_entropy(prediction, y)
152            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
153            return loss.item(), accuracy.item()

An agent (client) uses the FedAvg scheme to optimize a shared model on its local subset.

Agent(subset: fedbox.datasets.utils.FederatedSubset)
20    def __init__(self, subset: utils.FederatedSubset):
21        '''
22        Initializes the agent with a local `subset` of data samples and labels.
23
24        Parameters
25        ----------
26        subset: utils.FederatedSubset
27            Subset of data samples and labels
28        '''
29        
30        self.subset = subset

Initializes the agent with a local subset of data samples and labels.

Parameters
  • subset (utils.FederatedSubset): Subset of data samples and labels
subset
def step( self, model: torch.nn.modules.module.Module, x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, max_gradient_norm: float):
32    def step(self, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
33        '''
34        Performs an optimization step on `model` using minibatch (`x`, `y`).
35
36        Parameters
37        ----------
38        model: torch.nn.Module
39            Model that is optimized locally
40        x: torch.Tensor
41            Data samples in the minibatch
42        y: torch.Tensor
43            Data labels in the minibatch
44        optimizer: optim.Optimizer
45            Gradient-based optimizer
46        max_gradient_norm: float
47            Value used to clip the norm of the stochastic gradient
48        '''
49        
50        prediction = model(x)
51        
52        loss = nn.functional.cross_entropy(prediction, y)
53        loss.backward()
54
55        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
56
57        optimizer.step()
58        optimizer.zero_grad()

Performs an optimization step on model using minibatch (x, y).

Parameters
  • model (torch.nn.Module): Model that is optimized locally
  • x (torch.Tensor): Data samples in the minibatch
  • y (torch.Tensor): Data labels in the minibatch
  • optimizer (optim.Optimizer): Gradient-based optimizer
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
def optimize( self, model: torch.nn.modules.module.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
60    def optimize(self, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
61        '''
62        Runs `n_steps` stochastic gradient descent steps on the local dataset (one step for each minibatch).
63
64        Parameters
65        ----------
66        model: torch.nn.Module
67            Model that is locally optimized
68        n_steps: int
69            Number of local SGD steps, i.e. number of minibatches
70        step_size: float
71            Step size or learning rate
72        l2_penalty: float
73            Weight of L2 (Tikhonov) regularization term
74        max_gradient_norm: float
75            Value used to clip the norm of the stochastic gradient
76        device: torch.device
77            Accelerator to run the code
78        '''
79        
80        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
81        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
82
83        model.train()
84
85        for x, y in loader:
86            self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
87
88        return model

Runs n_steps stochastic gradient descent steps on the local dataset (one step for each minibatch).

Parameters
  • model (torch.nn.Module): Model that is locally optimized
  • n_steps (int): Number of local SGD steps, i.e. number of minibatches
  • step_size (float): Step size or learning rate
  • l2_penalty (float): Weight of L2 (Tikhonov) regularization term
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
  • device (torch.device): Accelerator to run the code
def multioptimize( self, model: torch.nn.modules.module.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 90    def multioptimize(self, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 91        '''
 92        Runs `n_epochs` stochastic gradient descent epochs on the local dataset.
 93
 94        Parameters
 95        ----------
 96        model: torch.nn.Module
 97            Model that is locally optimized
 98        n_epochs: int
 99            Number of local epochs to pass over the entire local dataset
100        step_size: float
101            Step size or learning rate
102        l2_penalty: float
103            Weight of L2 (Tikhonov) regularization term
104        max_gradient_norm: float
105            Value used to clip the norm of the stochastic gradient
106        device: torch.device
107            Accelerator to run the code
108
109        Note
110        ----
111        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
112        '''
113
114        loader = data.DataLoader(self.subset, batch_size = batch_size)
115        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
116
117        model.train()
118
119        for _ in range(n_epochs):
120            for x, y in loader:
121                self.step(model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
122
123        return model

Runs n_epochs stochastic gradient descent epochs on the local dataset.

Parameters
  • model (torch.nn.Module): Model that is locally optimized
  • n_epochs (int): Number of local epochs to pass over the entire local dataset
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Step size or learning rate
  • l2_penalty (float): Weight of L2 (Tikhonov) regularization term
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
  • device (torch.device): Accelerator to run the code
Note

Differently from optimize(...), each epoch corresponds to passing over the entire dataset using SGD.

def evaluate( self, model: torch.nn.modules.module.Module, device: torch.device) -> tuple[float, float]:
125    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
126        '''
127        Evaluate the `model` by computing the average sample loss and accuracy. 
128
129        Parameters
130        ----------
131        model: torch.nn.Module
132            Model that is locally optimized
133        device: torch.device
134            Accelerator to run the code
135
136        Returns
137        -------
138        tuple[float, float]
139            Tuple of average sample loss and accuracy on the local dataset
140        '''
141        
142        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
143        x, y = next(iter(loader))
144        x = x.to(device)
145        y = y.to(device)
146        
147        model.eval()
148
149        with torch.no_grad():
150            prediction = model(x)
151            loss = nn.functional.cross_entropy(prediction, y)
152            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
153            return loss.item(), accuracy.item()

Evaluate the model by computing the average sample loss and accuracy.

Parameters
  • model (torch.nn.Module): Model that is locally optimized
  • device (torch.device): Accelerator to run the code
Returns
  • tuple[float, float]: Tuple of average sample loss and accuracy on the local dataset
class Coordinator:
class Coordinator:
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

    Note
    ----
    The agents locally update their models using the FedAvg optimization scheme.
    '''

    def __init__(
        self,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = None
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients (required)
        logger: Logger
            Logger instance to save progress during the simulation; defaults to `Logger.default()`

        Raises
        ------
        ValueError
            If no aggregation `scheme` is provided
        '''

        # Fail fast with an explicit message instead of an AttributeError on `scheme.weights()`.
        if scheme is None:
            raise ValueError('an aggregation WeightingScheme is required to compute client weights')

        self.datasets = datasets
        self.model = model
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
        }
        self.weights = scheme.weights()
        # Resolved here (not as a default argument) so each coordinator gets a fresh
        # default logger rather than one instance shared at import time.
        self.logger = logger if logger is not None else Logger.default()

    def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`, always evaluating the updated
        global model on training and testing clients after every round.
        '''

        assert n_steps is not None or n_epochs is not None, 'either n_steps or n_epochs must be provided'

        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # Diminishing schedule: eta_t = eta / (t + 1) when enabled.
            step_size_updated = step_size / (iteration + 1) if step_size_diminishing else step_size
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
        '''
        Runs a single optimization round with FedAvg algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the updated global model on training and testing clients

        Returns
        -------
        dict[str, float]
            Dictionary of current round's metrics (empty when `evaluate` is False)
        '''

        n_agents = len(self.agents['training'])

        # NOTE(review): every training agent currently participates in each round; the
        # shuffle below only randomizes the optimization order and is kept as a hook
        # for future partial participation.
        indices = list(range(n_agents))
        random.shuffle(indices)

        initial_model = deepcopy(self.model)
        # Slot i holds agent i's update; non-optimized slots would contribute the
        # unchanged global model.
        updates: list[nn.Module] = [ initial_model for _ in range(n_agents) ]

        for i in tqdm(indices, total = len(indices), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
            participant = self.agents['training'][i]
            if n_steps is not None:
                updates[i] = participant.optimize(deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
            else:
                updates[i] = participant.multioptimize(deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        # Bug fix: aggregation weights must be aligned with `updates`, which is indexed
        # by the ORIGINAL agent order; the previous code paired position-ordered updates
        # with shuffle-ordered weights, attaching weights to the wrong clients.
        self.average(updates, weights = list(self.weights['training']))

        return self.evaluate(iteration, device) if evaluate else {}

    def evaluate(self, iteration: int, device: torch.device):
        '''
        Computes average sample loss and accuracy of the global model on training and testing clients during current `iteration`.

        Parameters
        ----------
        iteration: int
            Current global round of the simulation
        device: torch.device
            Accelerator

        Returns
        -------
        dict[str, dict[str, float]]
            Per-group ('training' and 'testing') average sample loss and accuracy
        '''

        return {
            group: self._evaluate_group(group, iteration, device)
            for group in ('training', 'testing')
        }

    def _evaluate_group(self, group: str, iteration: int, device: torch.device) -> dict[str, float]:
        # Aggregates weighted per-agent losses and correct-prediction counts for one
        # client group, then normalizes both into group-level averages.
        total_loss = 0.0
        total_correct = 0.0

        for agent, weight in tqdm(zip(self.agents[group], self.weights[group]), total = len(self.agents[group]), desc = 'Evaluation on {} agents (iteration {})'.format(group, iteration)):
            loss, correct = agent.evaluate(self.model, device = device)
            total_loss += weight * loss
            # Agent.evaluate returns a correct-prediction COUNT, so it is normalized
            # below by the group's total number of samples, not by the weights.
            total_correct += correct

        return {
            'loss': total_loss / sum(self.weights[group]),
            'accuracy': total_correct / sum(len(agent.subset) for agent in self.agents[group]),
        }

    def average(self, updates: list[torch.nn.Module], weights: list[float]):
        '''
        Averages clients' `updates` weighted by aggregation `weights` into the shared model.

        Parameters
        ----------
        updates: list[torch.nn.Module]
            Locally updated clients' models
        weights: list[float]
            Aggregation weights (one for each client, aligned with `updates`)
        '''

        total = sum(weights)
        # Snapshot every client state dict once, instead of re-materializing it for
        # each parameter name inside the comprehension below.
        states = [ update.state_dict() for update in updates ]

        # NOTE(review): assumes all state entries are floating-point tensors; integer
        # buffers (e.g. BatchNorm's num_batches_tracked) would need special-casing.
        self.model.load_state_dict({
            name: torch.stack([ weight * state[name] for state, weight in zip(states, weights) ]).sum(dim = 0) / total
            for name in self.model.state_dict().keys()
        })

This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

Note

The agents locally update their models using the FedAvg optimization scheme.

Coordinator( model: torch.nn.modules.module.Module, datasets: dict[str, list[fedbox.datasets.utils.FederatedSubset]], scheme: fedbox.optimization.utils.WeightingScheme = None, logger: fedbox.optimization.utils.Logger = <fedbox.optimization.utils.Logger object>)
def __init__(
    self,
    model: torch.nn.Module,
    datasets: dict[str, list[utils.FederatedSubset]],
    scheme: WeightingScheme = None,
    logger: Logger = None
):
    '''
    Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

    Parameters
    ----------
    model: torch.nn.Module
        Initial shared model
    datasets: dict[str, list[utils.FederatedSubset]]
        Training clients' subsets ('training') and testing clients' subsets ('testing')
    scheme: WeightingScheme
        Aggregation scheme to weight local updates from clients (required)
    logger: Logger
        Logger instance to save progress during the simulation; defaults to `Logger.default()`

    Raises
    ------
    ValueError
        If no aggregation `scheme` is provided
    '''

    # Fail fast with an explicit message instead of an AttributeError on `scheme.weights()`.
    if scheme is None:
        raise ValueError('an aggregation WeightingScheme is required to compute client weights')

    self.datasets = datasets
    self.model = model
    self.agents = {
        group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
    }
    self.weights = scheme.weights()
    # Resolved here (not as a default argument) so each coordinator gets a fresh
    # default logger rather than one instance shared at import time.
    self.logger = logger if logger is not None else Logger.default()

Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

Parameters
  • model (torch.nn.Module): Initial shared model
  • datasets (dict[str, list[utils.FederatedSubset]]): Training clients' subsets ('training') and testing clients' subsets ('testing')
  • scheme (WeightingScheme): Aggregation scheme to weight local updates from clients
  • logger (Logger): Logger instance to save progress during the simulation
datasets
model
agents
weights
logger
def run( self, n_iterations: int, n_steps: int = None, n_epochs=None, batch_size: int = 32, step_size: float = 0.001, step_size_diminishing: bool = False, l2_penalty: float = 0.0001, max_gradient_norm: float = 1.0, device: torch.device = device(type='cpu')):
def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    '''
    Runs `n_iterations` optimization (with algorithm FedAvg) and evaluation rounds on training clients.

    Parameters
    ----------
    n_iterations: int
        Number of global rounds
    n_steps: int
        Number of local SGD steps used for optimization on clients
    n_epochs: int
        Number of local epochs used for optimization on clients (mutually excludes n_steps)
    batch_size: int
        Number of samples in one SGD minibatch
    step_size: float
        Learning rate
    step_size_diminishing: bool
        This enables diminishing the step size linearly in time
    l2_penalty: float
        Weight of the L2 (Tikhonov) regularization used to penalize local models
    max_gradient_norm: float
        Value used to clip the norm of the stochastic gradient during local optimization
    device: torch.device
        Accelerator to run the code

    Note
    ----
    Runs `n_iterations` times function `iterate(...)`, always evaluating the updated
    global model on training and testing clients after every round.
    '''

    assert n_steps is not None or n_epochs is not None, 'either n_steps or n_epochs must be provided'

    self.model = self.model.to(device)
    self.model.compile()

    for iteration in range(n_iterations):
        # Diminishing schedule: eta_t = eta / (t + 1) when enabled.
        step_size_updated = step_size / (iteration + 1) if step_size_diminishing else step_size
        metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

        self.logger.log({
            'step': iteration,
            'loss.training': metrics['training']['loss'],
            'loss.testing': metrics['testing']['loss'],
            'accuracy.training': metrics['training']['accuracy'],
            'accuracy.testing': metrics['testing']['accuracy'],
        })

        print(iteration, metrics)

Runs n_iterations optimization (with algorithm FedAvg) and evaluation rounds on training clients.

Parameters
  • n_iterations (int): Number of global rounds
  • n_steps (int): Number of local SGD steps used for optimization on clients
  • n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Learning rate
  • step_size_diminishing (bool): This enables diminishing the step size linearly in time
  • l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
  • device (torch.device): Accelerator to run the code
  • evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients (note: run(...) itself exposes no such parameter and always evaluates)
Note

Runs n_iterations times function iterate(...).

def iterate( self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
    '''
    Runs a single optimization round with FedAvg algorithm on all training clients.

    Parameters
    ----------
    iteration: int
        Current global round
    n_steps: int
        Number of local SGD steps used for optimization on clients
    n_epochs: int
        Number of local epochs used for optimization on clients (mutually excludes n_steps)
    batch_size: int
        Number of samples in one SGD minibatch
    step_size: float
        Learning rate
    l2_penalty: float
        Weight of the L2 (Tikhonov) regularization used to penalize local models
    max_gradient_norm: float
        Value used to clip the norm of the stochastic gradient during local optimization
    device: torch.device
        Accelerator to run the code
    evaluate: bool
        Flag that enables evaluation of the updated global model on training and testing clients

    Returns
    -------
    dict[str, float]
        Dictionary of current round's metrics (empty when `evaluate` is False)
    '''

    n_agents = len(self.agents['training'])

    # NOTE(review): every training agent currently participates in each round; the
    # shuffle below only randomizes the optimization order and is kept as a hook
    # for future partial participation.
    indices = list(range(n_agents))
    random.shuffle(indices)

    initial_model = deepcopy(self.model)
    # Slot i holds agent i's update; non-optimized slots would contribute the
    # unchanged global model.
    updates: list[nn.Module] = [ initial_model for _ in range(n_agents) ]

    for i in tqdm(indices, total = len(indices), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
        participant = self.agents['training'][i]
        if n_steps is not None:
            updates[i] = participant.optimize(deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
        else:
            updates[i] = participant.multioptimize(deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

    # Bug fix: aggregation weights must be aligned with `updates`, which is indexed
    # by the ORIGINAL agent order; the previous code paired position-ordered updates
    # with shuffle-ordered weights, attaching weights to the wrong clients.
    self.average(updates, weights = list(self.weights['training']))

    return self.evaluate(iteration, device) if evaluate else {}

Runs a single optimization round with FedAvg algorithm on all training clients.

Parameters
  • iteration (int): Current global round
  • n_steps (int): Number of local SGD steps used for optimization on clients
  • n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Learning rate
  • l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
  • device (torch.device): Accelerator to run the code
  • evaluate (bool): Flag that enables evaluation of the updated global model on training and testing clients
Returns
  • dict[str, float]: Dictionary of current round's metrics
def evaluate(self, iteration: int, device: torch.device):
def evaluate(self, iteration: int, device: torch.device):
    '''
    Computes average sample loss and accuracy of the global model on training and testing clients during current `iteration`.

    Parameters
    ----------
    iteration: int
        Current global round of the simulation
    device: torch.device
        Accelerator

    Returns
    -------
    dict[str, dict[str, float]]
        Per-group ('training' and 'testing') average sample loss and accuracy
    '''

    return {
        group: self._evaluate_group(group, iteration, device)
        for group in ('training', 'testing')
    }

def _evaluate_group(self, group: str, iteration: int, device: torch.device) -> dict[str, float]:
    # Aggregates weighted per-agent losses and correct-prediction counts for one
    # client group, then normalizes both into group-level averages.
    total_loss = 0.0
    total_correct = 0.0

    for agent, weight in tqdm(zip(self.agents[group], self.weights[group]), total = len(self.agents[group]), desc = 'Evaluation on {} agents (iteration {})'.format(group, iteration)):
        loss, correct = agent.evaluate(self.model, device = device)
        total_loss += weight * loss
        # Agent.evaluate returns a correct-prediction COUNT, so it is normalized
        # below by the group's total number of samples, not by the weights.
        total_correct += correct

    return {
        'loss': total_loss / sum(self.weights[group]),
        'accuracy': total_correct / sum(len(agent.subset) for agent in self.agents[group]),
    }

Computes average sample accuracy and loss of the global model on training and testing clients during current iteration.

Parameters
  • iteration (int): Current global round of the simulation
  • device (torch.device): Accelerator
def average( self, updates: list[torch.nn.modules.module.Module], weights: list[float]):
def average(self, updates: list[torch.nn.Module], weights: list[float]):
    '''
    Averages clients' `updates` weighted by aggregation `weights` into the shared model.

    Parameters
    ----------
    updates: list[torch.nn.Module]
        Locally updated clients' models
    weights: list[float]
        Aggregation weights (one for each client, aligned with `updates`)
    '''

    total = sum(weights)
    # Snapshot every client state dict once, instead of re-materializing it for
    # each parameter name inside the comprehension below.
    states = [ update.state_dict() for update in updates ]

    # NOTE(review): assumes all state entries are floating-point tensors; integer
    # buffers (e.g. BatchNorm's num_batches_tracked) would need special-casing.
    self.model.load_state_dict({
        name: torch.stack([ weight * state[name] for state, weight in zip(states, weights) ]).sum(dim = 0) / total
        for name in self.model.state_dict().keys()
    })

Averages clients' updates weighted by aggregation weights into the shared model.

Parameters
  • updates (list[torch.nn.Module]): Locally updated clients' models
  • weights (list[float]): Aggregation weights (one for each client)