fedbox.optimization.fedprox

  1from ..datasets import utils
  2from . import fedavg
  3from .utils import WeightingScheme, Logger
  4
  5from copy import deepcopy
  6import random
  7import torch
  8import torch.nn as nn
  9import torch.optim as optim
 10import torch.utils.data as data
 11from tqdm import tqdm
 12
 13class Agent(fedavg.Agent):
 14    '''
 15    An agent (client) uses the FedProx scheme to optimize a shared model on its local subset.
 16    '''
 17
 18    def __init__(self, subset: utils.FederatedSubset):
 19        '''
 20        Initializes the agent with a local `subset` of data samples and labels.
 21
 22        Parameters
 23        ----------
 24        subset: utils.FederatedSubset
 25            Subset of data samples and labels
 26        '''
 27        
 28        self.subset = subset
 29
 30    def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
 31        '''
 32        Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`.
 33
 34        Parameters
 35        ----------
 36        alpha: float
 37            Weight of the proximal regularization term in FedProx
 38        initial: torch.nn.Module
 39            Model received at the beginning of the round from the server
 40        model: torch.nn.Module
 41            Model that is optimized locally
 42        x: torch.Tensor
 43            Data samples in the minibatch
 44        y: torch.Tensor
 45            Data labels in the minibatch
 46        optimizer: optim.Optimizer
 47            Gradient-based optimizer
 48        max_gradient_norm: float
 49            Value used to clip the norm of the stochastic gradient
 50        '''
 51
 52        prediction = model(x)
 53        
 54        loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model)
 55
 56        loss.backward()
 57
 58        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
 59
 60        optimizer.step()
 61        optimizer.zero_grad()
 62
 63    def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 64        '''
 65        Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch).
 66
 67        Parameters
 68        ----------
 69        alpha: float
 70            Weight of the proximal regularization term in FedProx
 71        initial: torch.nn.Module
 72            Model received at the beginning of the round from the server
 73        model: torch.nn.Module
 74            Model that is locally optimized
 75        n_steps: int
 76            Number of local SGD steps, i.e. number of minibatches
 77        step_size: float
 78            Step size or learning rate
 79        l2_penalty: float
 80            Weight of L2 (Tikhonov) regularization term
 81        max_gradient_norm: float
 82            Value used to clip the norm of the stochastic gradient
 83        device: torch.device
 84            Accelerator to run the code
 85        '''
 86        
 87        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
 88        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
 89
 90        model.train()
 91
 92        for x, y in loader:
 93            self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
 94
 95        return model
 96
 97    def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 98        '''
 99        Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset.
100
101        Parameters
102        ----------
103        alpha: float
104            Weight of the proximal regularization term in FedProx
105        initial: torch.nn.Module
106            Model received at the beginning of the round from the server
107        model: torch.nn.Module
108            Model that is locally optimized
109        n_epochs: int
110            Number of local epochs to pass over the entire local dataset
111        step_size: float
112            Step size or learning rate
113        l2_penalty: float
114            Weight of L2 (Tikhonov) regularization term
115        max_gradient_norm: float
116            Value used to clip the norm of the stochastic gradient
117        device: torch.device
118            Accelerator to run the code
119
120        Note
121        ----
122        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
123        '''
124
125        loader = data.DataLoader(self.subset, batch_size = batch_size)
126        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
127
128        model.train()
129
130        for _ in range(n_epochs):
131            for x, y in loader:
132                self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
133
134        return model
135
136    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
137        '''
138        Evaluate the `model` by computing the average sample loss and accuracy.
139
140        Parameters
141        ----------
142        model: torch.nn.Module
143            Model that is locally optimized
144        device: torch.device
145            Accelerator to run the code
146
147        Returns
148        -------
149        tuple[float, float]
150            Tuple of average sample loss and accuracy on the local dataset
151        '''
152
153        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
154        x, y = next(iter(loader))
155        x = x.to(device)
156        y = y.to(device)
157        
158        model.eval()
159        
160        with torch.no_grad():
161            prediction = model(x)
162            loss = nn.functional.cross_entropy(prediction, y)
163            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
164            return loss.item(), accuracy.item()
165        
166    def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor:
167        '''
168        Computes the `alpha`-weighted proximal term as squared difference between `initial` model and locally optimized `model`.
169
170        Parameters
171        ----------
172        initial: torch.nn.Module
173            Model received at the beginning of the round from the server
174        model: torch.nn.Module
175            Model that is locally optimized
176
177        Returns
178        -------
179        torch.Tensor
180            Squared difference between `initial` model and locally optimized `model`
181        '''
182
183        return torch.sum(torch.tensor([
184            torch.square(model.get_parameter(name) - initial.state_dict()[name].detach()).sum()
185            for name in model.state_dict().keys()
186        ], requires_grad = True))
187
188
189
class Coordinator(fedavg.Coordinator):
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

    Note
    ----
    The agents locally update their models using the FedProx optimization scheme.
    '''

    def __init__(
        self,
        alpha: float,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = None
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        alpha: float
            Weight of the proximal term used by training agents while running the optimization algorithm FedProx
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients (required despite the None default)
        logger: Logger
            Logger instance to save progress during the simulation; defaults to `Logger.default()`
        '''

        assert alpha > 0
        # The scheme is mandatory for computing aggregation weights: the original
        # code crashed with AttributeError on scheme.weights() whenever the None
        # default was left in place.
        assert scheme is not None, 'a weighting scheme is required to aggregate client updates'

        self.alpha = alpha
        self.datasets = datasets
        self.model = model
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
        }
        self.weights = scheme.weights()
        # Resolve the logger default lazily so each coordinator gets its own
        # instance instead of one shared object created at class-definition time.
        self.logger = logger if logger is not None else Logger.default()

    def run(self, n_iterations: int, n_steps: int = None, n_epochs: int = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch (used only with n_epochs)
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`. At least one of `n_steps` and `n_epochs` must be provided; `n_steps` takes precedence when both are given.
        '''

        assert n_steps is not None or n_epochs is not None

        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # Diminishing schedule: step_size / (t + 1) at round t.
            step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1)
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, dict[str, float]]:
        '''
        Runs a single optimization round with FedProx algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the updated global model on training and testing clients

        Returns
        -------
        dict[str, dict[str, float]]
            Dictionary of the current round's metrics per client group, or an empty dictionary when `evaluate` is disabled
        '''

        # Full participation: k equals the number of training agents, so the
        # shuffle-and-slice below only randomizes the visiting order.
        indices = list(range(0, len(self.agents['training'])))
        k = len(self.agents['training'])

        random.shuffle(indices)

        indices = indices[:k]
        participants = [ self.agents['training'][i] for i in indices ]
        weights = [ self.weights['training'][i] for i in indices ]

        initial_model = deepcopy(self.model)
        # Placeholder slots alias the initial model; with full participation every
        # slot is overwritten below before aggregation.
        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

        # Each participant optimizes a fresh copy of the round's initial model;
        # `initial_model` itself stays untouched for the proximal term.
        if n_steps is not None:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
        else:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        self.average(updates, weights = weights)

        if not evaluate:
            return {}

        return self.evaluate(iteration, device)
class Agent(fedbox.optimization.fedavg.Agent):
 14class Agent(fedavg.Agent):
 15    '''
 16    An agent (client) uses the FedProx scheme to optimize a shared model on its local subset.
 17    '''
 18
 19    def __init__(self, subset: utils.FederatedSubset):
 20        '''
 21        Initializes the agent with a local `subset` of data samples and labels.
 22
 23        Parameters
 24        ----------
 25        subset: utils.FederatedSubset
 26            Subset of data samples and labels
 27        '''
 28        
 29        self.subset = subset
 30
 31    def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
 32        '''
 33        Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`.
 34
 35        Parameters
 36        ----------
 37        alpha: float
 38            Weight of the proximal regularization term in FedProx
 39        initial: torch.nn.Module
 40            Model received at the beginning of the round from the server
 41        model: torch.nn.Module
 42            Model that is optimized locally
 43        x: torch.Tensor
 44            Data samples in the minibatch
 45        y: torch.Tensor
 46            Data labels in the minibatch
 47        optimizer: optim.Optimizer
 48            Gradient-based optimizer
 49        max_gradient_norm: float
 50            Value used to clip the norm of the stochastic gradient
 51        '''
 52
 53        prediction = model(x)
 54        
 55        loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model)
 56
 57        loss.backward()
 58
 59        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
 60
 61        optimizer.step()
 62        optimizer.zero_grad()
 63
 64    def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 65        '''
 66        Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch).
 67
 68        Parameters
 69        ----------
 70        alpha: float
 71            Weight of the proximal regularization term in FedProx
 72        initial: torch.nn.Module
 73            Model received at the beginning of the round from the server
 74        model: torch.nn.Module
 75            Model that is locally optimized
 76        n_steps: int
 77            Number of local SGD steps, i.e. number of minibatches
 78        step_size: float
 79            Step size or learning rate
 80        l2_penalty: float
 81            Weight of L2 (Tikhonov) regularization term
 82        max_gradient_norm: float
 83            Value used to clip the norm of the stochastic gradient
 84        device: torch.device
 85            Accelerator to run the code
 86        '''
 87        
 88        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
 89        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
 90
 91        model.train()
 92
 93        for x, y in loader:
 94            self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
 95
 96        return model
 97
 98    def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 99        '''
100        Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset.
101
102        Parameters
103        ----------
104        alpha: float
105            Weight of the proximal regularization term in FedProx
106        initial: torch.nn.Module
107            Model received at the beginning of the round from the server
108        model: torch.nn.Module
109            Model that is locally optimized
110        n_epochs: int
111            Number of local epochs to pass over the entire local dataset
112        step_size: float
113            Step size or learning rate
114        l2_penalty: float
115            Weight of L2 (Tikhonov) regularization term
116        max_gradient_norm: float
117            Value used to clip the norm of the stochastic gradient
118        device: torch.device
119            Accelerator to run the code
120
121        Note
122        ----
123        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
124        '''
125
126        loader = data.DataLoader(self.subset, batch_size = batch_size)
127        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
128
129        model.train()
130
131        for _ in range(n_epochs):
132            for x, y in loader:
133                self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
134
135        return model
136
137    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
138        '''
139        Evaluate the `model` by computing the average sample loss and accuracy.
140
141        Parameters
142        ----------
143        model: torch.nn.Module
144            Model that is locally optimized
145        device: torch.device
146            Accelerator to run the code
147
148        Returns
149        -------
150        tuple[float, float]
151            Tuple of average sample loss and accuracy on the local dataset
152        '''
153
154        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
155        x, y = next(iter(loader))
156        x = x.to(device)
157        y = y.to(device)
158        
159        model.eval()
160        
161        with torch.no_grad():
162            prediction = model(x)
163            loss = nn.functional.cross_entropy(prediction, y)
164            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
165            return loss.item(), accuracy.item()
166        
167    def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor:
168        '''
169        Computes the `alpha`-weighted proximal term as squared difference between `initial` model and locally optimized `model`.
170
171        Parameters
172        ----------
173        initial: torch.nn.Module
174            Model received at the beginning of the round from the server
175        model: torch.nn.Module
176            Model that is locally optimized
177
178        Returns
179        -------
180        torch.Tensor
181            Squared difference between `initial` model and locally optimized `model`
182        '''
183
184        return torch.sum(torch.tensor([
185            torch.square(model.get_parameter(name) - initial.state_dict()[name].detach()).sum()
186            for name in model.state_dict().keys()
187        ], requires_grad = True))

An agent (client) uses the FedProx scheme to optimize a shared model on its local subset.

Agent(subset: fedbox.datasets.utils.FederatedSubset)
19    def __init__(self, subset: utils.FederatedSubset):
20        '''
21        Initializes the agent with a local `subset` of data samples and labels.
22
23        Parameters
24        ----------
25        subset: utils.FederatedSubset
26            Subset of data samples and labels
27        '''
28        
29        self.subset = subset

Initializes the agent with a local subset of data samples and labels.

Parameters
  • subset (utils.FederatedSubset): Subset of data samples and labels
subset
def step( self, alpha: float, initial: torch.nn.modules.module.Module, model: torch.nn.modules.module.Module, x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, max_gradient_norm: float):
31    def step(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, optimizer: optim.Optimizer, max_gradient_norm: float):
32        '''
33        Performs an optimization step on `model` using minibatch (`x`, `y`) accounting for the proximal term weighted by `alpha`.
34
35        Parameters
36        ----------
37        alpha: float
38            Weight of the proximal regularization term in FedProx
39        initial: torch.nn.Module
40            Model received at the beginning of the round from the server
41        model: torch.nn.Module
42            Model that is optimized locally
43        x: torch.Tensor
44            Data samples in the minibatch
45        y: torch.Tensor
46            Data labels in the minibatch
47        optimizer: optim.Optimizer
48            Gradient-based optimizer
49        max_gradient_norm: float
50            Value used to clip the norm of the stochastic gradient
51        '''
52
53        prediction = model(x)
54        
55        loss = nn.functional.cross_entropy(prediction, y) + alpha * self.proxterm(initial, model)
56
57        loss.backward()
58
59        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm, error_if_nonfinite = True)
60
61        optimizer.step()
62        optimizer.zero_grad()

Performs an optimization step on model using minibatch (x, y) accounting for the proximal term weighted by alpha.

Parameters
  • alpha (float): Weight of the proximal regularization term in FedProx
  • initial (torch.nn.Module): Model received at the beginning of the round from the server
  • model (torch.nn.Module): Model that is optimized locally
  • x (torch.Tensor): Data samples in the minibatch
  • y (torch.Tensor): Data labels in the minibatch
  • optimizer (optim.Optimizer): Gradient-based optimizer
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
def optimize( self, alpha: float, initial: torch.nn.modules.module.Module, model: torch.nn.modules.module.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
64    def optimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_steps: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
65        '''
66        Runs `n_steps` stochastic gradient descent steps including the `alpha`-weighted proximal term on the local dataset (one step for each minibatch).
67
68        Parameters
69        ----------
70        alpha: float
71            Weight of the proximal regularization term in FedProx
72        initial: torch.nn.Module
73            Model received at the beginning of the round from the server
74        model: torch.nn.Module
75            Model that is locally optimized
76        n_steps: int
77            Number of local SGD steps, i.e. number of minibatches
78        step_size: float
79            Step size or learning rate
80        l2_penalty: float
81            Weight of L2 (Tikhonov) regularization term
82        max_gradient_norm: float
83            Value used to clip the norm of the stochastic gradient
84        device: torch.device
85            Accelerator to run the code
86        '''
87        
88        loader = data.DataLoader(self.subset, batch_size = len(self.subset) // n_steps)
89        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
90
91        model.train()
92
93        for x, y in loader:
94            self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
95
96        return model

Runs n_steps stochastic gradient descent steps including the alpha-weighted proximal term on the local dataset (one step for each minibatch).

Parameters
  • alpha (float): Weight of the proximal regularization term in FedProx
  • initial (torch.nn.Module): Model received at the beginning of the round from the server
  • model (torch.nn.Module): Model that is locally optimized
  • n_steps (int): Number of local SGD steps, i.e. number of minibatches
  • step_size (float): Step size or learning rate
  • l2_penalty (float): Weight of L2 (Tikhonov) regularization term
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
  • device (torch.device): Accelerator to run the code
def multioptimize( self, alpha: float, initial: torch.nn.modules.module.Module, model: torch.nn.modules.module.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 98    def multioptimize(self, alpha: float, initial: torch.nn.Module, model: torch.nn.Module, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device):
 99        '''
100        Runs `n_epochs` stochastic gradient descent epochs including the `alpha`-weighted proximal term on the local dataset.
101
102        Parameters
103        ----------
104        alpha: float
105            Weight of the proximal regularization term in FedProx
106        initial: torch.nn.Module
107            Model received at the beginning of the round from the server
108        model: torch.nn.Module
109            Model that is locally optimized
110        n_epochs: int
111            Number of local epochs to pass over the entire local dataset
112        step_size: float
113            Step size or learning rate
114        l2_penalty: float
115            Weight of L2 (Tikhonov) regularization term
116        max_gradient_norm: float
117            Value used to clip the norm of the stochastic gradient
118        device: torch.device
119            Accelerator to run the code
120
121        Note
122        ----
123        Differently from `optimize(...)`, each epoch corresponds to passing over the entire dataset using SGD.
124        '''
125
126        loader = data.DataLoader(self.subset, batch_size = batch_size)
127        optimizer = optim.SGD(model.parameters(), lr = step_size, weight_decay = l2_penalty)
128
129        model.train()
130
131        for _ in range(n_epochs):
132            for x, y in loader:
133                self.step(alpha, initial, model, x.to(device), y.to(device), optimizer = optimizer, max_gradient_norm = max_gradient_norm)
134
135        return model

Runs n_epochs stochastic gradient descent epochs including the alpha-weighted proximal term on the local dataset.

Parameters
  • alpha (float): Weight of the proximal regularization term in FedProx
  • initial (torch.nn.Module): Model received at the beginning of the round from the server
  • model (torch.nn.Module): Model that is locally optimized
  • n_epochs (int): Number of local epochs to pass over the entire local dataset
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Step size or learning rate
  • l2_penalty (float): Weight of L2 (Tikhonov) regularization term
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient
  • device (torch.device): Accelerator to run the code
Note

Differently from optimize(...), each epoch corresponds to passing over the entire dataset using SGD.

def evaluate( self, model: torch.nn.modules.module.Module, device: torch.device) -> tuple[float, float]:
137    def evaluate(self, model: torch.nn.Module, device: torch.device) -> tuple[float, float]:
138        '''
139        Evaluate the `model` by computing the average sample loss and accuracy.
140
141        Parameters
142        ----------
143        model: torch.nn.Module
144            Model that is locally optimized
145        device: torch.device
146            Accelerator to run the code
147
148        Returns
149        -------
150        tuple[float, float]
151            Tuple of average sample loss and accuracy on the local dataset
152        '''
153
154        loader = data.DataLoader(self.subset, batch_size = len(self.subset))
155        x, y = next(iter(loader))
156        x = x.to(device)
157        y = y.to(device)
158        
159        model.eval()
160        
161        with torch.no_grad():
162            prediction = model(x)
163            loss = nn.functional.cross_entropy(prediction, y)
164            accuracy = torch.sum(torch.argmax(prediction, dim = 1) == y)
165            return loss.item(), accuracy.item()

Evaluate the model by computing the average sample loss and accuracy.

Parameters
  • model (torch.nn.Module): Model that is locally optimized
  • device (torch.device): Accelerator to run the code
Returns
  • tuple[float, float]: Tuple of average sample loss and accuracy on the local dataset
def proxterm( self, initial: torch.nn.modules.module.Module, model: torch.nn.modules.module.Module) -> torch.Tensor:
167    def proxterm(self, initial: torch.nn.Module, model: torch.nn.Module) -> torch.Tensor:
168        '''
169        Computes the `alpha`-weighted proximal term as squared difference between `initial` model and locally optimized `model`.
170
171        Parameters
172        ----------
173        initial: torch.nn.Module
174            Model received at the beginning of the round from the server
175        model: torch.nn.Module
176            Model that is locally optimized
177
178        Returns
179        -------
180        torch.Tensor
181            Squared difference between `initial` model and locally optimized `model`
182        '''
183
184        return torch.sum(torch.tensor([
185            torch.square(model.get_parameter(name) - initial.state_dict()[name].detach()).sum()
186            for name in model.state_dict().keys()
187        ], requires_grad = True))

Computes the alpha-weighted proximal term as squared difference between initial model and locally optimized model.

Parameters
  • initial (torch.nn.Module): Model received at the beginning of the round from the server
  • model (torch.nn.Module): Model that is locally optimized
Returns
  • torch.Tensor: Squared difference between initial model and locally optimized model
class Coordinator(fedbox.optimization.fedavg.Coordinator):
class Coordinator(fedavg.Coordinator):
    '''
    This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

    Note
    ----
    The agents locally update their models using the FedProx optimization scheme.
    '''

    def __init__(
        self,
        alpha: float,
        model: torch.nn.Module,
        datasets: dict[str, list[utils.FederatedSubset]],
        scheme: WeightingScheme = None,
        logger: Logger = Logger.default()
    ):
        '''
        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

        Parameters
        ----------
        alpha: float
            Weight of the proximal term used by training agents while running the optimization algorithm FedProx
        model: torch.nn.Module
            Initial shared model
        datasets: dict[str, list[utils.FederatedSubset]]
            Training clients' subsets ('training') and testing clients' subsets ('testing')
        scheme: WeightingScheme
            Aggregation scheme to weight local updates from clients; effectively required, the `None` default exists only for signature compatibility
        logger: Logger
            Logger instance to save progress during the simulation

        Raises
        ------
        ValueError
            If `alpha` is not strictly positive or no weighting `scheme` is provided
        '''

        # Explicit exceptions instead of `assert`: asserts are stripped under
        # `python -O`, and the unconditional `scheme.weights()` call below used
        # to crash with an opaque AttributeError whenever the default
        # `scheme = None` was left in place.
        if alpha <= 0:
            raise ValueError('alpha must be strictly positive')
        if scheme is None:
            raise ValueError('a weighting scheme is required to aggregate client updates')

        self.alpha = alpha
        self.datasets = datasets
        self.model = model
        # One FedProx agent per client subset, per group ('training'/'testing').
        self.agents = {
            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items()
        }
        self.weights = scheme.weights()
        self.logger = logger

    def run(self, n_iterations: int, n_steps: int = None, n_epochs: int = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        '''
        Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients.

        Parameters
        ----------
        n_iterations: int
            Number of global rounds
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        step_size_diminishing: bool
            This enables diminishing the step size linearly in time
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code

        Raises
        ------
        ValueError
            If neither `n_steps` nor `n_epochs` is provided

        Note
        ----
        Runs `n_iterations` times function `iterate(...)`, always with
        `evaluate = True`, so every round is evaluated and logged.
        '''

        if n_steps is None and n_epochs is None:
            raise ValueError('either n_steps or n_epochs must be provided')

        self.model = self.model.to(device)
        self.model.compile()

        for iteration in range(n_iterations):
            # 1/t schedule when step_size_diminishing is enabled.
            step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1)
            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)

            self.logger.log({
                'step': iteration,
                'loss.training': metrics['training']['loss'],
                'loss.testing': metrics['testing']['loss'],
                'accuracy.training': metrics['training']['accuracy'],
                'accuracy.testing': metrics['testing']['accuracy'],
            })

            print(iteration, metrics)

    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, dict[str, float]]:
        '''
        Runs a single optimization round with FedProx algorithm on all training clients.

        Parameters
        ----------
        iteration: int
            Current global round
        n_steps: int
            Number of local SGD steps used for optimization on clients
        n_epochs: int
            Number of local epochs used for optimization on clients (mutually excludes n_steps)
        batch_size: int
            Number of samples in one SGD minibatch
        step_size: float
            Learning rate
        l2_penalty: float
            Weight of the L2 (Tikhonov) regularization used to penalize local models
        max_gradient_norm: float
            Value used to clip the norm of the stochastic gradient during local optimization
        device: torch.device
            Accelerator to run the code
        evaluate: bool
            Flag that enables evaluation of the updated global model on training and testing clients

        Returns
        -------
        dict[str, dict[str, float]]
            Per-group ('training'/'testing') metrics of the current round,
            or an empty dictionary when `evaluate` is disabled
        '''

        # NOTE(review): with k equal to the number of training agents the
        # shuffle-then-slice selects every client, i.e. full participation;
        # this looks like a placeholder for partial client sampling (k < n).
        indices = list(range(0, len(self.agents['training'])))
        k = len(self.agents['training'])

        random.shuffle(indices)

        indices = indices[:k]
        participants = [ self.agents['training'][i] for i in indices ]
        weights = [ self.weights['training'][i] for i in indices ]

        initial_model = deepcopy(self.model)
        # Every slot initially aliases the same `initial_model` object; that is
        # safe because non-participants' entries are only read by `average`,
        # while participants' slots are overwritten below.
        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]

        if n_steps is not None:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
        else:
            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
                updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)

        # Weighted aggregation of the local updates into self.model (inherited
        # from fedavg.Coordinator).
        self.average(updates, weights = weights)

        if not evaluate:
            return {}

        return self.evaluate(iteration, device)

This class represents a centralized server coordinating the training of a shared model across multiple agents (i.e. clients).

Note

The agents locally update their models using the FedProx optimization scheme.

Coordinator( alpha: float, model: torch.nn.modules.module.Module, datasets: dict[str, list[fedbox.datasets.utils.FederatedSubset]], scheme: fedbox.optimization.utils.WeightingScheme = None, logger: fedbox.optimization.utils.Logger = Logger.default())
200    def __init__(
201        self,
202        alpha: float,
203        model: torch.nn.Module,
204        datasets: dict[str, list[utils.FederatedSubset]],
205        scheme: WeightingScheme = None,
206        logger: Logger = Logger.default()
207    ):
208        '''
209        Constructs the centralized coordinator, i.e. server, in the federated learning simulation.
210
211        Parameters
212        ----------
213        alpha: float
214            Weight of the proximal term used by training agents while running the optimization algorithm FedProx
215        model: torch.nn.Module
216            Initial shared model
217        datasets: dict[str, list[utils.FederatedSubset]]
218            Training clients' subsets ('training') and testing clients' subsets ('testing')
219        scheme: WeightingScheme
220            Aggregation scheme to weight local updates from clients
221        logger: Logger
222            Logger instance to save progress during the simulation
223        '''
224
225        assert alpha > 0
226
227        self.alpha = alpha
228        self.datasets = datasets
229        self.model = model
230        self.agents = {
231            group: [ Agent(subset) for subset in dataset ] for group, dataset in datasets.items() 
232        }
233        self.weights = scheme.weights()
234        self.logger = logger

Constructs the centralized coordinator, i.e. server, in the federated learning simulation.

Parameters
  • alpha (float): Weight of the proximal term used by training agents while running the optimization algorithm FedProx
  • model (torch.nn.Module): Initial shared model
  • datasets (dict[str, list[utils.FederatedSubset]]): Training clients' subsets ('training') and testing clients' subsets ('testing')
  • scheme (WeightingScheme): Aggregation scheme to weight local updates from clients
  • logger (Logger): Logger instance to save progress during the simulation
alpha
datasets
model
agents
weights
logger
def run( self, n_iterations: int, n_steps: int = None, n_epochs=None, batch_size: int = 32, step_size: float = 0.001, step_size_diminishing: bool = False, l2_penalty: float = 0.0001, max_gradient_norm: float = 1.0, device: torch.device = device(type='cpu')):
236    def run(self, n_iterations: int, n_steps: int = None, n_epochs = None, batch_size: int = 32, step_size: float = 1e-3, step_size_diminishing: bool = False, l2_penalty: float = 1e-4, max_gradient_norm: float = 1.0, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
237        '''
238        Runs `n_iterations` optimization (with algorithm FedProx) and evaluation rounds on training clients.
239
240        Parameters
241        ----------
242        n_iterations: int
243            Number of global rounds
244        n_steps: int
245            Number of local SGD steps used for optimization on clients
246        n_epochs: int
247            Number of local epochs used for optimization on clients (mutually excludes n_steps)
248        batch_size: int
249            Number of samples in one SGD minibatch
250        step_size: float
251            Learning rate
252        step_size_diminishing: bool
253            This enables diminishing the step size linearly in time
254        l2_penalty: float
255            Weight of the L2 (Tikhonov) regularization used to penalize local models
256        max_gradient_norm: float
257            Value used to clip the norm of the stochastic gradient during local optimization
258        device: torch.device
259            Accelerator to run the code
260        evaluate: bool
261            Flag that enables evaluation of the update global model on training and testing clients
262
263        Note
264        ----
265        Runs `n_iterations` times function `iterate(...)`.
266        '''
267
268        assert n_steps is not None or n_epochs is not None
269        
270        self.model = self.model.to(device)
271        self.model.compile()
272
273        for iteration in range(n_iterations):
274            step_size_updated = step_size if not step_size_diminishing else step_size / (iteration + 1)
275            metrics = self.iterate(iteration, n_steps, n_epochs, batch_size, step_size_updated, l2_penalty, max_gradient_norm, device, evaluate = True)
276            
277            self.logger.log({
278                'step': iteration,
279                'loss.training': metrics['training']['loss'],
280                'loss.testing': metrics['testing']['loss'],
281                'accuracy.training': metrics['training']['accuracy'],
282                'accuracy.testing': metrics['testing']['accuracy'],
283            })
284
285            print(iteration, metrics)

Runs n_iterations optimization (with algorithm FedProx) and evaluation rounds on training clients.

Parameters
  • n_iterations (int): Number of global rounds
  • n_steps (int): Number of local SGD steps used for optimization on clients
  • n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Learning rate
  • step_size_diminishing (bool): This enables diminishing the step size linearly in time
  • l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
  • device (torch.device): Accelerator to run the code
  • Note: run(...) does not itself accept an evaluate flag; it always invokes iterate(...) with evaluate = True, so the updated global model is evaluated on training and testing clients every round
Note

Runs n_iterations times function iterate(...).

def iterate( self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
287    def iterate(self, iteration: int, n_steps: int, n_epochs: int, batch_size: int, step_size: float, l2_penalty: float, max_gradient_norm: float, device: torch.device, evaluate: bool = False) -> dict[str, float]:
288        '''
289        Runs a single optimization round with FedProx algorithm on all training clients.
290
291        Parameters
292        ----------
293        iteration: int
294            Current global round
295        n_steps: int
296            Number of local SGD steps used for optimization on clients
297        n_epochs: int
298            Number of local epochs used for optimization on clients (mutually excludes n_steps)
299        batch_size: int
300            Number of samples in one SGD minibatch
301        step_size: float
302            Learning rate
303        l2_penalty: float
304            Weight of the L2 (Tikhonov) regularization used to penalize local models
305        max_gradient_norm: float
306            Value used to clip the norm of the stochastic gradient during local optimization
307        device: torch.device
308            Accelerator to run the code
309        evaluate: bool
310            Flag that enables evaluation of the update global model on training and testing clients
311
312        Returns
313        -------
314        dict[str, float]
315            Dictionary of current round's metrics
316        '''
317
318        indices = list(range(0, len(self.agents['training'])))
319        k = len(self.agents['training'])
320        
321        random.shuffle(indices)
322        
323        indices = indices[:k]
324        participants = [ self.agents['training'][i] for i in indices ]
325        weights = [ self.weights['training'][i] for i in indices ]
326
327        initial_model = deepcopy(self.model)
328        updates: list[nn.Module] = [ initial_model for _ in self.agents['training'] ]
329
330        if n_steps is not None:
331            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
332                updates[i] = participant.optimize(self.alpha, initial_model, deepcopy(initial_model), n_steps = n_steps, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
333        else:
334            for i, participant in tqdm(zip(indices, participants), total = len(participants), desc = 'Optimization on training agents (iteration {})'.format(iteration)):
335                updates[i] = participant.multioptimize(self.alpha, initial_model, deepcopy(initial_model), n_epochs = n_epochs, batch_size = batch_size, step_size = step_size, l2_penalty = l2_penalty, max_gradient_norm = max_gradient_norm, device = device)
336
337        self.average(updates, weights = weights)
338
339        if not evaluate:
340            return {}
341        
342        return self.evaluate(iteration, device)

Runs a single optimization round with FedProx algorithm on all training clients.

Parameters
  • iteration (int): Current global round
  • n_steps (int): Number of local SGD steps used for optimization on clients
  • n_epochs (int): Number of local epochs used for optimization on clients (mutually excludes n_steps)
  • batch_size (int): Number of samples in one SGD minibatch
  • step_size (float): Learning rate
  • l2_penalty (float): Weight of the L2 (Tikhonov) regularization used to penalize local models
  • max_gradient_norm (float): Value used to clip the norm of the stochastic gradient during local optimization
  • device (torch.device): Accelerator to run the code
  • evaluate (bool): Flag that enables evaluation of the update global model on training and testing clients
Returns
  • dict[str, float]: Dictionary of current round's metrics