I'm learning data structures and algorithms, and here is a question that I'm stuck with.

I have to improve the performance of a recursive function by storing already computed values in memory.

The problem is that the non-improved version seems to be faster than the improved one.

Can someone help me out?

Syracuse numbers are a sequence of positive integers defined by the following rules:

syra(1) ≡ 1

syra(n) ≡ n + syra(n/2), if n mod 2 == 0

syra(n) ≡ n + syra((n*3)+1), otherwise
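To make the definition concrete, here is a small sketch. The class and method names are hypothetical and do not come from the course. It reads `+` as "prepend n to the rest of the sequence" and counts the elements, which is what the `syraLength` methods in the question compute:

```java
/**
 * Hypothetical helper (not from the course): counts the elements of the
 * Syracuse sequence starting at n, reading "+" in the definition as
 * "prepend n to the rest of the sequence".
 */
public class SyraDefinition {

    static int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        int length = 1;                              // syra(1) contributes one element
        while (n != 1) {
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;    // even: halve; odd: triple plus one
            length++;                                // each step adds one more element
        }
        return length;
    }

    public static void main(String[] args) {
        // sequence for 6: 6, 3, 10, 5, 16, 8, 4, 2, 1
        System.out.println(syraLength(6)); // prints 9
    }
}
```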

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    int counter = 0;
    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }

        if (n < 500 && map.containsKey(n)) {
            counter += map.get(n);
            return map.get(n);
        } else if (n == 1) {
            counter++;
            return 1;
        } else if (n % 2 == 0) {
            counter++;
            return syraLength(n / 2);
        } else {
            counter++;
            return syraLength(n * 3 + 1);
        }
    }

    Map<Integer, Integer> map = new HashMap<Integer, Integer>();

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }    
        for (int i = 1; i <= n; i++) {
            syraLength(i);
            if (i < 500 && !map.containsKey(i)) {
                map.put(i, counter);
            }
        }    
        return counter;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(5000000));
    }
}
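There is also a type problem in the code above: `map` is declared as `Map<Integer, Integer>`, but `syraLength` takes a `long`. Because of this, `map.containsKey(n)` boxes `n` to a `Long`, and a `Long` is never equal to an `Integer`, so the lookup misses even when the value was stored. A minimal sketch of that behavior follows. The class and method names are hypothetical:

```java
import java.util.HashMap;
import java.util.Map;

public class BoxedKeyMismatch {

    // Same key/value types as the question's field: Map<Integer, Integer>
    static final Map<Integer, Integer> MAP = new HashMap<>();

    static {
        MAP.put(5, 6); // int 5 is autoboxed to an Integer key
    }

    static boolean hitWithLong(long n) {
        // n is autoboxed to Long; Long.equals(Integer) is false, so this misses
        return MAP.containsKey(n);
    }

    static boolean hitWithInt(long n) {
        // casting first makes the key an Integer, so this finds the entry
        return MAP.containsKey((int) n);
    }

    public static void main(String[] args) {
        System.out.println(hitWithLong(5)); // prints false
        System.out.println(hitWithInt(5));  // prints true
    }
}
```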

Here is the normal (non-memoized) version that I wrote:

public class SyraLengths {

    int total = 1;

    public int syraLength(long n) {
        if (n < 1)
            throw new IllegalArgumentException();
        if (n == 1) {
            int temp = total;
            total = 1;
            return temp;
        } else if (n % 2 == 0) {
            total++;
            return syraLength(n / 2);
        } else {
            total++;
            return syraLength(n * 3 + 1);
        }
    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i);
        }

        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengths().lengths(5000000));
    }
}

EDIT

This version is still slower than the non-enhanced version.

import java.util.HashMap;
import java.util.Map;

public class SyraLengthsEfficient {

    private Map<Long, Long> map = new HashMap<Long, Long>();

    public long syraLength(long n, long count) {

        if (n < 1)
            throw new IllegalArgumentException();

        if (!map.containsKey(n)) {
            if (n == 1) {
                count++;
                map.put(n, count);
            } else if (n % 2 == 0) {
                count++;
                map.put(n, count + syraLength(n / 2, 0));
            } else {
                count++;
                map.put(n, count + syraLength(3 * n + 1, 0));
            }
        }

        return map.get(n);

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            // long temp = syraLength(i, 0);
            // System.out.println(i + " : " + temp);
            total += syraLength(i, 0);

        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraLengthsEfficient().lengths(50000000));
    }
}

FINAL SOLUTION (marked as correct by the school's automatic marking system)

public class SyraLengthsEfficient {

    private int[] values = new int[10 * 1024 * 1024];

    public int syraLength(long n, int count) {

        if (n <= values.length && values[(int) (n - 1)] != 0) {
            return count + values[(int) (n - 1)];
        } else if (n == 1) {
            count++;
            values[(int) (n - 1)] = 1;
            return count;
        } else if (n % 2 == 0) {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n / 2, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n / 2, 0);
            }
        } else {
            count++;
            if (n <= values.length) {
                values[(int) (n - 1)] = count + syraLength(n * 3 + 1, 0);
                return values[(int) (n - 1)];
            } else {
                return count + syraLength(n * 3 + 1, 0);
            }
        }

    }

    public int lengths(int n) {
        if (n < 1) {
            throw new IllegalArgumentException();
        }
        int total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i, 0);
        }
        return total;
    }

    public static void main(String[] args) {
        SyraLengthsEfficient s = new SyraLengthsEfficient();
        System.out.println(s.lengths(50000000));
    }

}

Bill the Lizard
Timeless
  • For questions on improving currently working code, you might consider http://codereview.stackexchange.com – James Montagne Jun 04 '12 at 15:55
  • The first thought before I load the code into my IDE - you might consider using [Trove](http://trove.starlight-systems.com/) `TIntIntHashMap` (HashMap of int primitives) as it would be substantially faster than its JDK counterpart which uses the `Integer` wrapper. – Petr Janeček Jun 04 '12 at 16:05
  • Just to clarify my understanding: the first snippet is your answer and the second snippet was part of the question? And you are saying the second snippet is faster? – Hari Menon Jun 04 '12 at 16:07
  • @Raze2dust the second one is practical x in my notes, and the first one is practical x+1. The notes ask us to improve the performance of the recursion, with a hint to store the results in memory to reduce recursive calls. The problem is that the non-improved version seems faster than the enhanced version. I thought something might be wrong, but I can't find it. – Timeless Jun 04 '12 at 16:10
  • @Slanec thanks, but it is a DSA course. I think the purpose is not to use a library to improve collection performance but to understand recursion. I'm wondering if there is a fault in my code. – Timeless Jun 04 '12 at 16:12
  • is your task to identify syra(n) or to sum the numbers syra(1) + syra(2) + ... syra(n)? – jeff Jun 04 '12 at 16:28
  • @jeff not to sum the numbers, but to count them – Timeless Jun 04 '12 at 16:44
  • This post should have a homework tag. – Judge Mental Jun 04 '12 at 19:01
  • @null what's the purpose of the `counter` and `total` variables? You're adding all the syraLengths to them! If what you intended was to count the number of times each function gets called, then a simple increment of the counter when entering the method would be enough, and *do not* add the result of calling `syraLength()` to them. – Óscar López Jun 04 '12 at 22:22
  • @JudgeMental thanks for mentioning the homework tag. – Timeless Jun 05 '12 at 03:45

3 Answers

Forget about the answers that say your code is inefficient because it uses a Map; that's not why it's slow. The reason is that you're limiting the cache of calculated numbers to n < 500. Once you remove that restriction, things start to work pretty fast. Here's a proof of concept for you to fill in the details:

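import java.util.HashMap;
import java.util.Map;

// class name is illustrative; the answer only showed the field and method
public class SyraMemoProof {

    private Map<Long, Long> map = new HashMap<Long, Long>();

    public long syraLength(long n) {

        if (!map.containsKey(n)) {
            if (n == 1)
                map.put(n, 1L);
            else if (n % 2 == 0)
                map.put(n, n + syraLength(n / 2));
            else
                map.put(n, n + syraLength(3 * n + 1));
        }

        return map.get(n);

    }
}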

If you want to read more about what's happening in the program and why it's so fast, take a look at this Wikipedia article about memoization.

Also, I think you're misusing the counter variable, you increment it (++) when a value is calculated the first time, but you accumulate over it (+=) when a value is found in the map. That doesn't seem right to me, and I doubt that it gives the expected result.
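Below is one way the details could be filled in for counting rather than summing. It assumes the goal is the sequence length, as in the question. `SyraMemo` is a hypothetical name. The sketch uses `Long` keys so that lookups with a `long` argument actually hit. The total is accumulated only from return values, with no shared counter:

```java
import java.util.HashMap;
import java.util.Map;

public class SyraMemo {

    // Unbounded cache of already computed lengths; Long keys match the long argument
    private final Map<Long, Integer> cache = new HashMap<>();

    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        if (n == 1) {
            return 1;
        }
        Integer known = cache.get(n);
        if (known != null) {
            return known;                          // memoized: no further recursion
        }
        long next = (n % 2 == 0) ? n / 2 : 3 * n + 1;
        int length = 1 + syraLength(next);         // this element plus the rest
        cache.put(n, length);
        return length;
    }

    // Sum of the lengths for 1..n, built only from return values
    public long lengths(int n) {
        long total = 0;
        for (int i = 1; i <= n; i++) {
            total += syraLength(i);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(new SyraMemo().lengths(10)); // prints 77
    }
}
```

The `long` total is deliberate. For large n, the sum of lengths can exceed the range of an `int`.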

Óscar López
  • Thanks for pointing out the misuse of recursion; it is hard to understand. I have improved my version of the code: the purpose is not to sum all the values but to count them. But it is still slow, even though I have removed the size restriction on the HashMap. – Timeless Jun 05 '12 at 04:05
  • Thanks Oscar, I finally got the answer. I did not use a HashMap, but an array to store the values. Your answer inspired me. – Timeless Jun 06 '12 at 03:29
  • @null I'm glad you could solve it! If this answer was of any use to you, please don't forget to accept it – Óscar López Jun 06 '12 at 03:32

Don't use a map. Store the temporary result in a variable (it's called an accumulator) and iterate in a loop until n = 1. After each iteration your accumulator grows by n, and in each iteration n either becomes 3n + 1 or is halved. Hope that helps you solve your homework.
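Here is a minimal sketch of the loop this answer describes. It assumes `+` in the definition means arithmetic addition. For the question's length count, replace `acc += n` with `acc++`. The name `syraSum` is hypothetical:

```java
public class SyraLoop {

    // Iterative version with an accumulator: returns n + next + ... + 1
    static long syraSum(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        long acc = 0;
        while (n != 1) {
            acc += n;                              // accumulator grows by n
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;  // halve, or triple and add one
        }
        return acc + 1;                            // syra(1) = 1
    }

    public static void main(String[] args) {
        // 6 + 3 + 10 + 5 + 16 + 8 + 4 + 2 + 1
        System.out.println(syraSum(6)); // prints 55
    }
}
```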

piotrek

Of course it doesn't do as well: you're adding a lot of overhead in the map.put and map.get calls (hashing, bucket creation, etc.). Plus, you are autoboxing, which creates a lot of objects. My guess is that the overhead of the map far outweighs the benefit.

Try using two arrays instead: one to hold the values, and one to hold flags that tell you whether each value has been set.

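// note: typical JVMs (e.g. HotSpot) cannot allocate arrays of Integer.MAX_VALUE elements
// ("Requested array size exceeds VM limit"), so a smaller bound is needed in practice
int [] syr = new int[Integer.MAX_VALUE];
boolean [] syrcomputed = new boolean[Integer.MAX_VALUE];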

and use those instead of the map:

if (syrcomputed[n]) {
   return syr[n];
}
else {
    syrcomputed[n] = true;
    syr[n] = ....;
}
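Here is a self-contained sketch of the two-array approach. It assumes length semantics and uses a hypothetical `LIMIT` bound in place of `Integer.MAX_VALUE`, since arrays that large cannot be allocated on typical JVMs. Values at or above the bound are still computed but not cached:

```java
public class SyraArrays {

    // Hypothetical bound: only starting values below LIMIT are cached
    static final int LIMIT = 1 << 20;

    private final int[] syr = new int[LIMIT];              // cached lengths
    private final boolean[] syrcomputed = new boolean[LIMIT]; // "is syr[n] set?" flags

    public int syraLength(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive");
        }
        boolean cacheable = n < LIMIT;
        if (cacheable && syrcomputed[(int) n]) {
            return syr[(int) n];                   // cache hit: no recursion
        }
        int length = (n == 1) ? 1
                : 1 + syraLength(n % 2 == 0 ? n / 2 : 3 * n + 1);
        if (cacheable) {
            syrcomputed[(int) n] = true;
            syr[(int) n] = length;
        }
        return length;
    }

    public static void main(String[] args) {
        SyraArrays s = new SyraArrays();
        System.out.println(s.syraLength(27)); // prints 112
    }
}
```

Because every valid length is at least 1, a zero in `syr` could double as the "not computed" flag, which is what the asker's final solution does. The separate `boolean` array just follows this answer's suggestion.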

Also, I think you might run into overflow with larger numbers. Once the value exceeds MAX_INT/3, you would definitely see this whenever it's not divisible by 2, since 3n + 1 no longer fits in an int.

As such, you should probably use long types for all your calculations as well.
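A quick illustration of that overflow: 1,000,000,001 is odd and larger than `Integer.MAX_VALUE / 3` (715,827,882). Computing 3n + 1 in `int` arithmetic wraps around silently, while `long` holds the correct value. The class and method names are hypothetical:

```java
public class OverflowDemo {

    static int nextInt(int n) {
        return 3 * n + 1;       // int arithmetic silently wraps around
    }

    static long nextLong(long n) {
        return 3 * n + 1;       // long has room for this value
    }

    public static void main(String[] args) {
        System.out.println(nextInt(1_000_000_001));   // prints -1294967292
        System.out.println(nextLong(1_000_000_001L)); // prints 3000000004
    }
}
```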

PS: if your purpose is truly to understand recursion, you shouldn't store the values in an instance variable, but should pass them down as an accumulator:

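public int syr(int n) {
  // arrays sized n + 1 instead of Integer.MAX_VALUE, which typical JVMs cannot allocate
  return syr(n, new int[n + 1], new boolean[n + 1]);
}

// Assumes syr(n) is the sequence length; only values below syr.length are cached.
private int syr(long n, int[] syr, boolean[] syrcomputed) {
   boolean cacheable = n < syr.length;
   if (cacheable && syrcomputed[(int) n]) {
     return syr[(int) n];
   }
   // recursive computation, passing the memo arrays down as accumulators
   int s = (n == 1) ? 1
         : 1 + syr(n % 2 == 0 ? n / 2 : 3 * n + 1, syr, syrcomputed);
   if (cacheable) {
     syrcomputed[(int) n] = true;
     syr[(int) n] = s;
   }
   return s;
}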

In some functional languages (Scheme, Erlang, etc.) this kind of call actually gets optimized as a tail call, which avoids growing the stack. Even though the HotSpot JVM doesn't do this (at least to my knowledge), it's still an important concept.
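Here is a small illustration of the tail-call idea. It assumes length semantics and uses hypothetical method names. In `lengthTail` the recursive call is the last action, and the length so far travels in the accumulator `acc`. `lengthLoop` is the loop that tail-call elimination would effectively turn it into. As far as I know, HotSpot does not perform that transformation, so in Java the recursive form still uses one stack frame per step:

```java
public class SyraTail {

    // Tail-recursive: nothing is left to do after the recursive call returns
    static int lengthTail(long n, int acc) {
        if (n == 1) {
            return acc + 1;                        // count the final element
        }
        return lengthTail(n % 2 == 0 ? n / 2 : 3 * n + 1, acc + 1);
    }

    // The same computation written as the loop a tail-call optimizer would produce
    static int lengthLoop(long n) {
        int acc = 0;
        while (n != 1) {
            n = (n % 2 == 0) ? n / 2 : 3 * n + 1;
            acc++;
        }
        return acc + 1;
    }

    public static void main(String[] args) {
        System.out.println(lengthTail(6, 0)); // prints 9
        System.out.println(lengthLoop(6));    // prints 9
    }
}
```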

Matt