Guest User

Untitled

a guest
May 24th, 2018
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.80 KB | None | 0 0
  1. import java.security.SecureRandom;
  2. import java.time.LocalTime;
  3. import java.util.Random;
  4. import java.util.function.BinaryOperator;
  5. import java.util.logging.Logger;
  6.  
  7. import static java.lang.StrictMath.abs;
  8. import static java.lang.StrictMath.pow;
  9. import static java.time.LocalTime.now;
  10.  
  11. /**
  12. * Note: it's up to you to square f and provide both partial derivatives for f^2 (for x and y).
  13. */
  14. @SuppressWarnings("unused")
  15. public final class SteepestDecent {
  16.  
  17. private static final double MAX_ERROR = 0.0001;
  18. private static final double LEARNING_RATE = 0.01;
  19. private static final Logger log = Logger.getAnonymousLogger();
  20. private static final Random RANDOM = new SecureRandom();
  21. private static final int MAX_RAND_VAR = 5;
  22. private final BinaryOperator<Double> partDerivativeX, partDerivativeY;
  23. private double x, y;
  24.  
  25. @SuppressWarnings({"ConstantConditions", "MethodParameterNamingConvention", "ConstantAssertCondition"})
  26. public SteepestDecent(final BinaryOperator<Double> partialDerivativeX, final BinaryOperator<Double> partialDerivativeY) {
  27. this(partialDerivativeX, partialDerivativeY, RANDOM.nextInt(MAX_RAND_VAR), RANDOM.nextInt(MAX_RAND_VAR));
  28. }
  29.  
  30. @SuppressWarnings({"ConstantConditions", "MethodParameterNamingConvention", "ConstantAssertCondition", "WeakerAccess"})
  31. public SteepestDecent(final BinaryOperator<Double> partialDerivativeX, final BinaryOperator<Double> partialDerivativeY, final double x, final double y) {
  32. assert LEARNING_RATE > 0;
  33. assert MAX_ERROR >= 0;
  34. partDerivativeX = partialDerivativeX;
  35. partDerivativeY = partialDerivativeY;
  36. this.x = x;
  37. this.y = y;
  38. }
  39.  
  40. @SuppressWarnings("MethodCallInLoopCondition")
  41. public void train() {
  42. final LocalTime start = now();
  43.  
  44. double errX, errY;
  45.  
  46. do {
  47. final double deltaX = partDerivativeX.apply(x, y);
  48. final double deltaY = partDerivativeY.apply(x, y);
  49. log.info(String.format("delta x = %f delta y = %f", deltaX, deltaY));
  50. x -= (deltaX * LEARNING_RATE);
  51. y -= (deltaY * LEARNING_RATE);
  52. log.info(String.format("x = %f y = %f", x, y));
  53. errX = abs(x - deltaX);
  54. errY = abs(y - deltaY);
  55. log.info(String.format("error x = %f error y = %f", errX, errY));
  56. } while ((errX >= MAX_ERROR) || (errY >= MAX_ERROR));
  57.  
  58. log.info(String.format("Took %d seconds.", now().minusSeconds(start.getSecond())
  59. .getSecond()));
  60. log.info(String.format("Final x = %f y = %f", x, y));
  61. }
  62.  
  63. public static void main(final String... args) {
  64. final SteepestDecent algo = new SteepestDecent((x, y) -> ((-4.0) * x * y) + (4.0 * pow(x, 3.0)), (x, y) -> (2 * y) - (2 * pow(x, 2.0)), 3, -5);
  65. algo.train();
  66. }
  67. }
Add Comment
Please, Sign In to add comment