Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.security.SecureRandom;
- import java.time.LocalTime;
- import java.util.Random;
- import java.util.function.BinaryOperator;
- import java.util.logging.Logger;
- import static java.lang.StrictMath.abs;
- import static java.lang.StrictMath.pow;
- import static java.time.LocalTime.now;
/**
 * Minimizes a function of two variables {@code (x, y)} by steepest (gradient) descent.
 *
 * <p>The caller supplies the two partial derivatives of the objective. Note: to drive some
 * function {@code f} towards zero, it is up to you to square {@code f} and provide both partial
 * derivatives of {@code f^2} (with respect to {@code x} and with respect to {@code y}).
 *
 * <p>Each iteration moves the current point against the gradient, scaled by the learning rate.
 * Iteration stops once both coordinates move by less than the configured maximum error in a
 * single step. It also stops if the point becomes non-finite (divergence) or if an iteration cap
 * is reached.
 *
 * <p>Instances are mutable: {@link #train()} updates the current point in place. They are not
 * thread-safe.
 */
@SuppressWarnings("unused")
public final class SteepestDecent {
    /** Default convergence threshold on the per-coordinate step size. */
    private static final double MAX_ERROR = 0.0001;
    /** Default multiplier applied to the gradient on each step. */
    private static final double LEARNING_RATE = 0.01;
    /** Hard cap on iterations so a non-converging run cannot loop forever. */
    private static final int MAX_ITERATIONS = 1_000_000;
    private static final Logger log = Logger.getLogger(SteepestDecent.class.getName());
    private static final Random RANDOM = new SecureRandom();
    /** Random starting coordinates are drawn from the integers {@code [0, MAX_RAND_VAR)}. */
    private static final int MAX_RAND_VAR = 5;

    private final BinaryOperator<Double> partDerivativeX, partDerivativeY;
    private final double learningRate;
    private final double maxError;
    private double x, y;

    /**
     * Creates an optimizer that starts at a random integer point in {@code [0, 5) x [0, 5)}.
     * It uses the default learning rate and maximum error.
     *
     * @param partialDerivativeX partial derivative of the objective with respect to {@code x}
     * @param partialDerivativeY partial derivative of the objective with respect to {@code y}
     * @throws NullPointerException if either derivative is {@code null}
     */
    @SuppressWarnings("MethodParameterNamingConvention")
    public SteepestDecent(final BinaryOperator<Double> partialDerivativeX,
                          final BinaryOperator<Double> partialDerivativeY) {
        this(partialDerivativeX, partialDerivativeY,
                RANDOM.nextInt(MAX_RAND_VAR), RANDOM.nextInt(MAX_RAND_VAR));
    }

    /**
     * Creates an optimizer that starts at {@code (x, y)}.
     * It uses the default learning rate and maximum error.
     *
     * @param partialDerivativeX partial derivative of the objective with respect to {@code x}
     * @param partialDerivativeY partial derivative of the objective with respect to {@code y}
     * @param x                  starting x coordinate
     * @param y                  starting y coordinate
     * @throws NullPointerException if either derivative is {@code null}
     */
    @SuppressWarnings({"MethodParameterNamingConvention", "WeakerAccess"})
    public SteepestDecent(final BinaryOperator<Double> partialDerivativeX,
                          final BinaryOperator<Double> partialDerivativeY,
                          final double x, final double y) {
        this(partialDerivativeX, partialDerivativeY, x, y, LEARNING_RATE, MAX_ERROR);
    }

    /**
     * Creates an optimizer that starts at {@code (x, y)} with explicit tuning parameters.
     *
     * @param partialDerivativeX partial derivative of the objective with respect to {@code x}
     * @param partialDerivativeY partial derivative of the objective with respect to {@code y}
     * @param x                  starting x coordinate
     * @param y                  starting y coordinate
     * @param learningRate       gradient multiplier per step; must be positive and finite
     * @param maxError           iteration stops once both coordinates move by less than this in
     *                           one step; must be positive
     * @throws NullPointerException     if either derivative is {@code null}
     * @throws IllegalArgumentException if {@code learningRate} or {@code maxError} is invalid
     */
    @SuppressWarnings({"MethodParameterNamingConvention", "WeakerAccess"})
    public SteepestDecent(final BinaryOperator<Double> partialDerivativeX,
                          final BinaryOperator<Double> partialDerivativeY,
                          final double x, final double y,
                          final double learningRate, final double maxError) {
        // The negated comparisons also reject NaN.
        if (!(learningRate > 0) || Double.isInfinite(learningRate)) {
            throw new IllegalArgumentException(
                    "learningRate must be positive and finite: " + learningRate);
        }
        if (!(maxError > 0)) {
            throw new IllegalArgumentException("maxError must be positive: " + maxError);
        }
        partDerivativeX = java.util.Objects.requireNonNull(partialDerivativeX, "partialDerivativeX");
        partDerivativeY = java.util.Objects.requireNonNull(partialDerivativeY, "partialDerivativeY");
        this.learningRate = learningRate;
        this.maxError = maxError;
        this.x = x;
        this.y = y;
    }

    /**
     * Runs gradient descent from the current point until it converges, diverges, or hits the
     * iteration cap. The final point is available from {@link #getX()} and {@link #getY()}.
     * Per-iteration details are logged at FINE. The summary is logged at INFO, and abnormal
     * termination is logged at WARNING.
     */
    public void train() {
        // nanoTime is monotonic and suited to measuring elapsed time; wall-clock
        // fields such as second-of-minute wrap around.
        final long startNanos = System.nanoTime();
        int iterations = 0;
        boolean converged = false;
        boolean diverged = false;
        // At least one step is always taken, like the do/while this replaces.
        while (!converged && !diverged && iterations < MAX_ITERATIONS) {
            converged = step();
            iterations++;
            diverged = !Double.isFinite(x) || !Double.isFinite(y);
        }
        if (diverged) {
            log.warning(String.format("Diverged after %d iterations: x = %f y = %f",
                    iterations, x, y));
        } else if (!converged) {
            log.warning(String.format("Stopped after %d iterations without converging.",
                    iterations));
        }
        final double elapsedSeconds = (System.nanoTime() - startNanos) / 1e9;
        log.info(String.format("Took %.3f seconds.", elapsedSeconds));
        log.info(String.format("Final x = %f y = %f", x, y));
    }

    /**
     * Performs one gradient-descent update of {@code (x, y)}.
     * Both partial derivatives are evaluated at the old point before either coordinate moves.
     *
     * @return {@code true} if both coordinates moved by less than {@code maxError}
     */
    private boolean step() {
        final double deltaX = partDerivativeX.apply(x, y);
        final double deltaY = partDerivativeY.apply(x, y);
        log.fine(() -> String.format("delta x = %f delta y = %f", deltaX, deltaY));
        final double stepX = deltaX * learningRate;
        final double stepY = deltaY * learningRate;
        x -= stepX;
        y -= stepY;
        log.fine(() -> String.format("x = %f y = %f", x, y));
        // The "error" is how far each coordinate moved in this step. It shrinks to zero as the
        // gradient vanishes, wherever the minimum lies.
        final double errX = abs(stepX);
        final double errY = abs(stepY);
        log.fine(() -> String.format("error x = %f error y = %f", errX, errY));
        return (errX < maxError) && (errY < maxError);
    }

    /**
     * Returns the current x coordinate.
     *
     * @return the current x coordinate (the minimizer estimate after {@link #train()})
     */
    public double getX() {
        return x;
    }

    /**
     * Returns the current y coordinate.
     *
     * @return the current y coordinate (the minimizer estimate after {@link #train()})
     */
    public double getY() {
        return y;
    }

    /**
     * Demo: minimizes {@code (x^2 - y)^2}, whose partial derivatives are given below.
     * The run starts at {@code (3, -5)}.
     *
     * @param args ignored
     */
    public static void main(final String... args) {
        final SteepestDecent algo = new SteepestDecent(
                (x, y) -> ((-4.0) * x * y) + (4.0 * pow(x, 3.0)),
                (x, y) -> (2 * y) - (2 * pow(x, 2.0)),
                3, -5);
        algo.train();
    }
}
Add Comment
Please, Sign In to add comment