aj-lmu

SCAFFOLD optimizer

Feb 28th, 2023 (edited)
import tensorflow as tf
from tensorflow import keras
import numpy as np


class Scaffold(keras.optimizers.SGD):
    """SGD with the SCAFFOLD client-side control-variate correction.

    Note: this relies on the private _get_gradients hook of the Keras
    OptimizerV2 API; on TF >= 2.11 subclass tf.keras.optimizers.legacy.SGD
    instead.
    """

    def __init__(
        self,
        learning_rate=0.1,
        momentum=0.9,
        **kwargs,
    ):
        super().__init__(
            name="scaffold", learning_rate=learning_rate, momentum=momentum, **kwargs
        )

    def _get_gradients(self, tape, loss, var_list, grad_loss=None):
        grads = tape.gradient(loss, var_list, grad_loss)

        # Correct each gradient with c_diff = c - ci;
        # set_controls() must be called before the first step.
        grads = [
            grads_layer + c_diff_layer
            for grads_layer, c_diff_layer in zip(grads, self.c_diff)
        ]
        return list(zip(grads, var_list))

    def set_controls(self, weights, st=None):
        server_controls = st.server_controls if st else None
        local_controls = st.local_controls if st else None

        # c:  server controls
        # ci: client controls (local)
        # Both default to zeros shaped like the model weights.
        self.c = (
            tf.nest.map_structure(
                lambda array: tf.Variable(array, dtype=tf.float32), server_controls
            )
            if server_controls
            else [
                tf.Variable(tf.zeros(shape=layer.shape, dtype=tf.float32))
                for layer in weights
            ]
        )
        self.ci = (
            tf.nest.map_structure(
                lambda array: tf.Variable(array, dtype=tf.float32), local_controls
            )
            if local_controls
            else [
                tf.Variable(tf.zeros(shape=layer.shape, dtype=tf.float32))
                for layer in weights
            ]
        )

        # c_diff = c - ci, added to every gradient in _get_gradients().
        self.c_diff = [
            tf.Variable(tf.subtract(c_layer, ci_layer))
            for c_layer, ci_layer in zip(self.c, self.ci)
        ]

    def get_new_client_controls(self, global_weights, local_weights, option=1):
        # Model difference after local training: x - y_i (global - local).
        model_diff = [
            np.subtract(global_layer, local_layer)
            for global_layer, local_layer in zip(global_weights, local_weights)
        ]

        if option == 1:
            return model_diff
        else:
            # SCAFFOLD Option II update:
            # ci_new = ci - c + (x - y_i) / (K * lr),
            # with K the number of local steps taken (assumed >= 1 here).
            local_lr = float(self.lr)
            local_steps = int(self.iterations.value())

            scale = 1 / (local_steps * local_lr)
            ci_new = [
                np.add(
                    np.subtract(local_control, server_control),
                    np.multiply(scale, delta),
                )
                for local_control, server_control, delta in zip(
                    self.ci, self.c, model_diff
                )
            ]
            return ci_new

    def get_config(self):
        return super().get_config()
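
A minimal usage sketch for one federated client round, assuming the legacy (OptimizerV2) Keras API. The random x_train/y_train arrays and the single-Dense model are illustrative only; a real client would receive global_weights and the control state from the server.

import numpy as np
from tensorflow import keras

# Toy data and model, for illustration only.
x_train = np.random.rand(64, 20).astype("float32")
y_train = np.random.rand(64, 10).astype("float32")
model = keras.Sequential([keras.layers.Dense(10, input_shape=(20,))])

optimizer = Scaffold(learning_rate=0.1, momentum=0.9)

# Register the control variates before training; with st=None
# both c and ci start as zeros.
optimizer.set_controls(model.get_weights(), st=None)
model.compile(optimizer=optimizer, loss="mse")

global_weights = model.get_weights()   # weights as received from the server
model.fit(x_train, y_train, epochs=1)  # local steps use corrected gradients
local_weights = model.get_weights()

# New local controls to report back to the server
# (option=2 applies ci - c + (x - y_i) / (K * lr)).
ci_new = optimizer.get_new_client_controls(global_weights, local_weights, option=2)

In a full SCAFFOLD round the client would then send the model delta and ci_new back to the server, which aggregates them into the new global weights and server controls.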