import random
import numpy as np
import tensorflow as tf
from termcolor import colored
from tensorflow.python.ops import math_ops, state_ops, control_flow_ops
from . import TF_KERAS

WARN = colored('WARNING:', 'red')


def _apply_weight_decays(self, var, var_t):
    # Decoupled L1/L2 weight decay: subtract the (iteration-normalized)
    # penalty gradient directly from the parameter update, scaled by the
    # schedule multiplier `eta_t`, instead of adding penalties to the loss.
    l1, l2 = self.weight_decays[var.name]
    if l1 == 0 and l2 == 0:
        if self.init_verbose and not self._init_notified:
            print("Both penalties are 0 for %s, will skip" % var.name)
        return var_t

    norm = math_ops.cast(math_ops.sqrt(1 / self.total_iterations_wd),
                         'float32')
    l1_normalized = l1 * norm
    l2_normalized = l2 * norm

    if l1 != 0 and l2 != 0:
        decay = l1_normalized * math_ops.sign(var) + l2_normalized * var
    elif l1 != 0:
        decay = l1_normalized * math_ops.sign(var)
    else:
        decay = l2_normalized * var
    var_t = var_t - self.eta_t * decay

    if self.init_verbose and not self._init_notified:
        norm_print = (1 / self.total_iterations_wd) ** (1 / 2)
        l1n_print, l2n_print = l1 * norm_print, l2 * norm_print
        decays_str = "{}(L1), {}(L2)".format(l1n_print, l2n_print)
        print('{} weight decay set for {}'.format(decays_str, var.name))
    return var_t
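

# Editor's sketch (not part of the library): the same decoupled decay in
# plain NumPy for a single weight tensor. All arguments are illustrative
# stand-ins for the optimizer attributes used above.
def _demo_weight_decay(var, var_t, l1=0.0, l2=1e-4, eta_t=1.0,
                       total_iterations_wd=1000):
    norm = np.sqrt(1 / total_iterations_wd)  # iteration-normalized penalty
    decay = l1 * norm * np.sign(var) + l2 * norm * var
    return var_t - eta_t * decay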


def _compute_eta_t(self):
    # Cosine-annealing multiplier: decays from `eta_max` to `eta_min` over
    # `total_iterations` steps (as in SGDR, Loshchilov & Hutter).
    PI = 3.141592653589793
    t_frac = math_ops.cast(self.t_cur / (self.total_iterations - 1), 'float32')
    eta_t = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * \
        (1 + math_ops.cos(PI * t_frac))
    return eta_t
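

# Editor's sketch (not part of the library): the same schedule in NumPy with
# stand-in values; print or plot the result to inspect the annealing curve.
def _demo_cosine_annealing(total_iterations=100, eta_min=0.0, eta_max=1.0):
    t = np.arange(total_iterations)
    return eta_min + 0.5 * (eta_max - eta_min) * (
        1 + np.cos(np.pi * t / (total_iterations - 1)))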


def _apply_lr_multiplier(self, lr_t, var):
    # Scale `lr_t` for variables whose name contains a key of
    # `self.lr_multipliers` (substring match; first match wins).
    multiplier_name = [mult_name for mult_name in self.lr_multipliers
                       if mult_name in var.name]
    if multiplier_name:
        lr_mult = self.lr_multipliers[multiplier_name[0]]
    else:
        lr_mult = 1
    lr_t = lr_t * lr_mult

    if self.init_verbose and not self._init_notified:
        lr_print = self._init_lr * lr_mult
        if lr_mult != 1:
            print('{} init learning rate set for {} -- {}'.format(
                '%.e' % round(lr_print, 5), var.name, lr_t))
        else:
            print('No change in learning rate {} -- {}'.format(
                var.name, lr_print))
    return lr_t
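

# Editor's sketch (not part of the library): how multiplier keys match
# variable names by substring; 'dense_1' is a hypothetical layer name.
def _demo_lr_multiplier(lr=1e-3):
    lr_multipliers = {'dense_1': 0.5}
    var_name = 'dense_1/kernel:0'
    matches = [k for k in lr_multipliers if k in var_name]
    return lr * (lr_multipliers[matches[0]] if matches else 1)  # -> 5e-4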


def _update_t_cur_eta_t(self):  # keras
    self.updates.append(_update_t_cur(self))
    # Cosine annealing
    if self.use_cosine_annealing:
        # ensure eta_t is updated AFTER t_cur
        with tf.control_dependencies([self.updates[-1]]):
            self.updates.append(state_ops.assign(self.eta_t,
                                                 _compute_eta_t(self)))


def _update_t_cur_eta_t_v2(self, lr_t=None, var=None):  # tf.keras
    t_cur_update, eta_t_update = None, None  # in case not assigned

    # update `t_cur` if iterating last `(grad, var)`
    iteration_done = (self._updates_processed == (self._updates_per_iter - 1))
    if iteration_done:
        t_cur_update = _update_t_cur(self)
        self._updates_processed = 0  # reset
    else:
        self._updates_processed += 1

    # Cosine annealing
    if self.use_cosine_annealing and iteration_done:
        # ensure eta_t is updated AFTER t_cur
        with tf.control_dependencies([t_cur_update]):
            eta_t_update = state_ops.assign(self.eta_t, _compute_eta_t(self),
                                            use_locking=self._use_locking)
        self.lr_t = lr_t * self.eta_t  # for external tracking

    return iteration_done, t_cur_update, eta_t_update


def _update_t_cur(self):
    # Increment `t_cur`; with `autorestart`, wrap back to 0 at the end of a
    # cosine-annealing cycle so the schedule restarts.
    kw = {'use_locking': self._use_locking} if TF_KERAS else {}
    if self.autorestart:
        return control_flow_ops.cond(
            math_ops.equal(self.t_cur, self.total_iterations - 1),
            lambda: state_ops.assign(self.t_cur, 0, **kw),
            lambda: state_ops.assign_add(self.t_cur, 1, **kw),
        )
    return state_ops.assign_add(self.t_cur, 1, **kw)
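

# Editor's sketch (not part of the library): the autorestart wraparound in
# plain Python, for intuition only.
def _demo_autorestart_step(t_cur, total_iterations):
    return 0 if t_cur == total_iterations - 1 else t_cur + 1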


def _set_autorestart(self, autorestart, use_cosine_annealing):
    if autorestart is None:
        self.autorestart = bool(use_cosine_annealing)
    elif autorestart and not use_cosine_annealing:
        raise ValueError("`autorestart` can only be used with "
                         "`use_cosine_annealing`")
    else:
        self.autorestart = autorestart


def _check_args(self, total_iterations, use_cosine_annealing, weight_decays):
    if use_cosine_annealing and total_iterations > 1:
        print('Using cosine annealing learning rates')
    elif (use_cosine_annealing or weight_decays) and total_iterations <= 1:
        print(WARN, "'total_iterations'==%s, must be >1" % total_iterations
              + " to use cosine annealing and/or weight decays; "
              "proceeding without either")
        self.use_cosine_annealing = False
        self.autorestart = False
        self.weight_decays = {}


def _init_weight_decays(model, zero_penalties, weight_decays):
    if not zero_penalties:
        print(WARN, "loss-based weight penalties should be set to zero. "
              "(set `zero_penalties=True`)")
    if weight_decays is not None and model is not None:
        print(WARN, "`weight_decays` is set automatically when "
              "passing in `model`; will override the supplied `weight_decays`")
    if model is not None:
        weight_decays = get_weight_decays(model, zero_penalties)
    return weight_decays


def get_weight_decays(model, zero_penalties=False):
    # Collect (l1, l2) penalties from each layer's regularizers, keyed by
    # weight name; optionally zero the loss-based penalties in-place so the
    # optimizer applies them in decoupled form only.
    wd_dict = {}
    for layer in model.layers:
        layer_penalties = _get_layer_penalties(layer, zero_penalties)
        if layer_penalties:
            for weight_name, weight_penalty in layer_penalties:
                if not all(wp == 0 for wp in weight_penalty):
                    wd_dict[weight_name] = weight_penalty
    return wd_dict
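

# Editor's sketch (not part of the library): collecting decays from a toy
# tf.keras model; the layer, l2 value, and printed key are illustrative.
def _demo_get_weight_decays():
    from tensorflow.keras import Sequential, layers, regularizers
    model = Sequential([layers.Dense(
        4, input_shape=(8,), kernel_regularizer=regularizers.l2(1e-4))])
    return get_weight_decays(model, zero_penalties=True)
    # e.g. {'dense/kernel:0': (0.0, 1e-4)} -- exact key varies by version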


def _get_layer_penalties(layer, zero_penalties=False):
    # RNN layers (including wrapped ones, e.g. Bidirectional) hold their
    # regularizers on `.cell`; other wrappers (e.g. TimeDistributed) defer
    # to their inner `.layer`.
    if hasattr(layer, 'cell') or \
       (hasattr(layer, 'layer') and hasattr(layer.layer, 'cell')):
        return _rnn_penalties(layer, zero_penalties)
    elif hasattr(layer, 'layer') and not hasattr(layer.layer, 'cell'):
        layer = layer.layer

    penalties = []
    for weight_name in ['kernel', 'bias']:
        _lambda = getattr(layer, weight_name + '_regularizer', None)
        if _lambda is not None:
            l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
            penalties.append([getattr(layer, weight_name).name, l1l2])
    return penalties


def _rnn_penalties(layer, zero_penalties=False):
    penalties = []
    if hasattr(layer, 'backward_layer'):  # Bidirectional wrapper
        for layer in [layer.forward_layer, layer.backward_layer]:
            penalties += _cell_penalties(layer.cell, zero_penalties)
        return penalties
    else:
        return _cell_penalties(layer.cell, zero_penalties)


def _cell_penalties(rnn_cell, zero_penalties=False):
    cell = rnn_cell
    penalties = []  # cell weights ordered: kernel, recurrent, bias

    for weight_idx, weight_type in enumerate(['kernel', 'recurrent', 'bias']):
        _lambda = getattr(cell, weight_type + '_regularizer', None)
        if _lambda is not None:
            weight_name = cell.weights[weight_idx].name
            l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
            penalties.append([weight_name, l1l2])
    return penalties


def _get_and_maybe_zero_penalties(_lambda, zero_penalties):
    # Read the penalties *before* optionally zeroing them, so callers still
    # receive the original values for decoupled decay; the original code read
    # them after zeroing, which always returned (0., 0.).
    l1 = float(getattr(_lambda, 'l1', 0.))
    l2 = float(getattr(_lambda, 'l2', 0.))
    if zero_penalties:
        if hasattr(_lambda, 'l1'):
            _lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
        if hasattr(_lambda, 'l2'):
            _lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
    return (l1, l2)
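

# Editor's sketch (not part of the library): capturing then zeroing a
# regularizer's penalties; the l1/l2 values are illustrative, and attribute
# types vary across keras versions.
def _demo_zero_penalties():
    from tensorflow.keras import regularizers
    reg = regularizers.l1_l2(l1=1e-5, l2=1e-4)
    captured = _get_and_maybe_zero_penalties(reg, zero_penalties=True)
    return captured, (float(reg.l1), float(reg.l2))
    # -> roughly ((1e-05, 0.0001), (0.0, 0.0)), up to float32 rounding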


def fill_dict_in_order(_dict, values_list):
    # Assign values positionally onto existing keys; relies on dict insertion
    # order (guaranteed in Python 3.7+).
    for idx, key in enumerate(_dict.keys()):
        _dict[key] = values_list[idx]
    return _dict
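

# Editor's sketch (not part of the library): usage with stand-in values.
def _demo_fill_dict_in_order():
    return fill_dict_in_order({'a': None, 'b': None}, [1, 2])
    # -> {'a': 1, 'b': 2}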


def reset_seeds(reset_graph_with_backend=None, verbose=1):
    # Optionally clear the Keras session and TF graph, then reseed the
    # Python, NumPy, and TensorFlow RNGs for reproducibility.
    if reset_graph_with_backend is not None:
        K = reset_graph_with_backend
        K.clear_session()
        tf.compat.v1.reset_default_graph()
        if verbose:
            print("KERAS AND TENSORFLOW GRAPHS RESET")

    np.random.seed(1)
    random.seed(2)
    if tf.__version__[0] == '2':
        tf.random.set_seed(3)
    else:
        tf.set_random_seed(3)
    if verbose:
        print("RANDOM SEEDS RESET")
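

# Editor's sketch (not part of the library): typical call, assuming the
# tf.keras backend module.
def _demo_reset_seeds():
    import tensorflow.keras.backend as K
    reset_seeds(reset_graph_with_backend=K)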


def K_eval(x, backend):
    # Evaluate a tensor to a NumPy value, trying progressively older backend
    # APIs (get_value -> function -> eager eval -> eval) for compatibility
    # across keras / tf.keras versions.
    K = backend
    try:
        return K.get_value(K.to_dense(x))
    except Exception:
        try:
            eval_fn = K.function([], [x])
            return eval_fn([])[0]
        except Exception:
            try:
                return K.eager(K.eval)(x)
            except Exception:
                return K.eval(x)
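

# Editor's sketch (not part of the library): evaluating a constant tensor,
# assuming the tf.keras backend.
def _demo_K_eval():
    import tensorflow.keras.backend as K
    return K_eval(K.constant([1., 2., 3.]), K)  # -> float32 NumPy array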