Advertisement
Ayro

Karazuba (+number multiplication)

Oct 1st, 2012
859
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 4.63 KB | None | 0 0
  1. import java.math.BigInteger;
  2. import java.util.Arrays;
  3. import java.util.Random;
  4.  
  5. public class Karazuba {
  6.  
  7.     final static int THRESHOLD = 20;
  8.  
  9.     // should be of equal length
  10.     static long[] mult(long[] f, long[] g) {
  11.         if (f.length <= THRESHOLD)
  12.             return multStupid(f, g);
  13.         int d = (f.length + 1) >> 1;
  14.         long[] f1 = Arrays.copyOf(f, d);
  15.         long[] f2 = Arrays.copyOfRange(f, d, f.length);
  16.         long[] g1 = Arrays.copyOf(g, d);
  17.         long[] g2 = Arrays.copyOfRange(g, d, g.length);
  18.         long[] f1_g1 = mult(f1, g1);
  19.         long[] f2_g2 = mult(f2, g2);
  20.         long[] f1_g2_f2_g1 = sub_from(sub_from(mult(sum(f1, f2), sum(g1, g2)), f1_g1), f2_g2);
  21.         long[] ret = new long[f.length + g.length - 1];
  22.         for (int i = 0; i < f1_g1.length; i++)
  23.             ret[i] += f1_g1[i];
  24.         for (int i = 0, shift = d << 1; i < f2_g2.length; i++)
  25.             ret[i + shift] += f2_g2[i];
  26.         for (int i = 0, shift = d; i < f1_g2_f2_g1.length; i++)
  27.             ret[i + shift] += f1_g2_f2_g1[i];
  28.         return ret;
  29.     }
  30.  
  31.     static long[] sum(long[] f1, long[] f2) {
  32.         for (int i = 0; i < f2.length; i++)
  33.             f1[i] += f2[i];
  34.         return f1;
  35.     }
  36.  
  37.     static long[] sub_from(long[] f1, long[] f2) {
  38.         if (f2.length > f1.length)
  39.             throw new AssertionError();
  40.         for (int i = 0; i < f2.length; i++) {
  41.             f1[i] -= f2[i];
  42.         }
  43.         return f1;
  44.     }
  45.  
  46.     static long[] multStupid(long[] f, long[] g) {
  47.         long[] ret = new long[f.length + g.length - 1];
  48.         for (int i = 0; i < f.length; i++) {
  49.             for (int j = 0; j < g.length; j++) {
  50.                 ret[i + j] += f[i] * g[j];
  51.             }
  52.         }
  53.         return ret;
  54.     }
  55.  
  56.     // multiplying numbers
  57.     final static int D = 7;
  58.  
  59.     static char[] multNumbers(char[] a, char[] b) {
  60.         if (a.length != b.length)
  61.             throw new AssertionError();
  62.         final int[] pow10 = new int[D];
  63.         pow10[0] = 1;
  64.         for (int i = 1; i < D; i++)
  65.             pow10[i] = pow10[i - 1] * 10;
  66.         long[] f = new long[(a.length + D - 1) / D];
  67.         long[] g = new long[(b.length + D - 1) / D];
  68.         for (int i = 0; i < a.length; i++)
  69.             f[i / D] += (a[a.length - i - 1] - '0') * pow10[i % D];
  70.         for (int i = 0; i < b.length; i++)
  71.             g[i / D] += (b[b.length - i - 1] - '0') * pow10[i % D];
  72.         long[] fg = mult(f, g);
  73.         return normalize(fg);
  74.     }
  75.  
  76.     static char[] normalize(long[] fg) {
  77.         char[] ab = new char[fg.length * D + 20];
  78.         int pos = 0;
  79.         long x = 0;
  80.         for (int i = 0; i < fg.length; i++) {
  81.             x += fg[i];
  82.             for (int j = 0; j < D; j++) {
  83.                 ab[pos++] = (char) ((x % 10) + '0');
  84.                 x /= 10;
  85.             }
  86.         }
  87.         while (x > 0) {
  88.             ab[pos++] = (char) ((x % 10) + '0');
  89.             x /= 10;
  90.         }
  91.         int len;
  92.         for (len = ab.length; len > 1 && (ab[len - 1] == 0 || ab[len - 1] == '0'); len--)
  93.             ;
  94.         ab = Arrays.copyOf(ab, len);
  95.         for (int i = 0; 2 * i < ab.length; i++) {
  96.             char tmp = ab[i];
  97.             ab[i] = ab[ab.length - 1 - i];
  98.             ab[ab.length - 1 - i] = tmp;
  99.         }
  100.         return ab;
  101.     }
  102.  
  103.     // testing
  104.  
  105.     static boolean areEqual(long[] f1, long[] f2) {
  106.         if (f1.length != f2.length)
  107.             return false;
  108.         for (int i = 0; i < f1.length; i++) {
  109.             if (f1[i] != f2[i])
  110.                 return false;
  111.         }
  112.         return true;
  113.     }
  114.  
  115.     static Random rand = new Random();
  116.  
  117.     static void stressTest() {
  118.         for (int iter = 0;; iter++) {
  119.             singleTest(5000 + rand.nextInt(5000));
  120.             System.err.println("Test " + iter + " passed");
  121.         }
  122.     }
  123.  
  124.     static void singleTest(int n) {
  125.         long[] f = new long[n];
  126.         long[] g = new long[n];
  127.         for (int i = 0; i < n; i++) {
  128.             f[i] = rand.nextInt(10);
  129.             g[i] = rand.nextInt(10);
  130.         }
  131.         if (!areEqual(mult(f, g), multStupid(f, g)))
  132.             throw new AssertionError();
  133.     }
  134.  
  135.     static void maxTest() {
  136.         int N = 100000;
  137.         long[] f = new long[N];
  138.         for (int i = 0; i < N; i++)
  139.             f[i] = 1;
  140.         long start_time = System.currentTimeMillis();
  141.         mult(f, f);
  142.         long end_time = System.currentTimeMillis();
  143.         System.err.println("Time consumed: " + (end_time - start_time) + "ms");
  144.     }
  145.  
  146.     static void numbersMaxTest() {
  147.         int N = 500000;
  148.         char[] a = new char[N];
  149.         for (int i = 0; i < N; i++)
  150.             a[i] = '1';
  151.         long start_time = System.currentTimeMillis();
  152.         multNumbers(a, a);
  153.         long end_time = System.currentTimeMillis();
  154.         System.err.println("Time consumed: " + (end_time - start_time) + "ms");
  155.     }
  156.  
  157.     static void numbersSingleTest(int n) {
  158.         char[] a = new char[n];
  159.         char[] b = new char[n];
  160.         for (int i = 0; i < n; i++) {
  161.             a[i] = (char) (rand.nextInt(10) + '0');
  162.             b[i] = (char) (rand.nextInt(10) + '0');
  163.         }
  164.         String w = new String(multNumbers(a, b));
  165.         BigInteger ab = new BigInteger(new String(a)).multiply(new BigInteger(new String(b)));
  166.         if (!w.equals(ab.toString()))
  167.             throw new AssertionError();
  168.     }
  169.  
  170.     static void numbersStressTest() {
  171.         for (int iter = 0; iter < 100; iter++) {
  172.             numbersSingleTest(5000);
  173.             System.err.println("Test " + iter + " passed");
  174.         }
  175.     }
  176. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement