
With the following code, I count the restricted integer partitions of sum (each number may occur only once per partition) with k numbers in each partition, where each number is at least 1 and at most m. The code generates so many cache entries that it quickly runs out of memory.

Example:

sum := 15, k := 4, m := 10; the expected result is 6

It has the following restricted integer partitions:

1,2,3,9; 1,2,4,8; 1,2,5,7; 1,3,4,7; 1,3,5,6; 2,3,4,6

public class Key{
  private final int sum;
  private final short k1;
  private final short start;
  private final short end;

  public Key(int sum, short k1, short start, short end){
    this.sum = sum;
    this.k1 = k1;
    this.start = start;
    this.end = end;
  }
  // + hashcode and equals
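  // One way to implement the elided methods (a sketch; any consistent
  // equals/hashCode over all four fields will do):
  @Override
  public boolean equals(Object o){
    if(!(o instanceof Key)) return false;
    Key other = (Key) o;
    return sum == other.sum && k1 == other.k1
        && start == other.start && end == other.end;
  }
  @Override
  public int hashCode(){
    return java.util.Objects.hash(sum, k1, start, end);
  }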
}

public BigInteger calcRestrictedIntegerPartitions(int sum,short k,short m){
  return calcRestrictedIntegerPartitionsHelper(sum,(short)0,k,(short)1,m,new HashMap<>());
}

private BigInteger calcRestrictedIntegerPartitionsHelper(int sum, short k1, short k, short start, short end, Map<Key,BigInteger> cache){
  if(sum < 0){
    return BigInteger.ZERO;
  }
  if(k1 == k){
    if(sum ==0){
      return BigInteger.ONE;
    }
    return BigInteger.ZERO;
  }
  if(end*(k-k1) < sum){
    return BigInteger.ZERO;
  }

  final Key key = new Key(sum,(short)(k-k1),start,end);

  BigInteger fetched = cache.get(key);

  if(fetched == null){
    BigInteger tmp = BigInteger.ZERO;

    for(short i=start; i <= end;i++){
      tmp = tmp.add(calcRestrictedIntegerPartitionsHelper(sum-i,(short)(k1+1),k,(short)(i+1),end,cache));
    }

    cache.put(key, tmp);
    return tmp;
  }

  return fetched;
}

Is there a formula to avoid/reduce caching? Or how can I count restricted integer partitions with k and m?

TTho Einthausend

3 Answers


Your key contains 4 parts, so the key space can grow up to the product of the maximum values of those parts. It is possible to reduce the key to 3 parts by looping backwards and using zero as a natural lower limit.

This Python example uses the built-in lru_cache functionality; the number of distinct keys is at most N*K*M:

import functools

@functools.lru_cache(250000)
def diff_partition(N, K, M):
    '''Counts integer partitions of N into K distinct parts, each <= M (0 allowed)'''
    if K == 0:
        if N == 0:
            return 1
        return 0
    res = 0
    for i in range(min(N, M), -1, -1):
        res += diff_partition(N - i, K - 1, i - 1)
    return res

def diffparts(Sum, K, M):   #diminish problem size allowing zero part
    return diff_partition(Sum - K, K, M-1)

print(diffparts(500, 25, 200))

>>>147151784574
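Since the question is in Java, the same recursion ports over directly. Here is a rough sketch (the class, field and method names are mine, and a HashMap keyed on the (n, k, m) triple stands in for lru_cache):

import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RestrictedPartitions {

    // memo keyed on (n, k, m) only
    private final Map<List<Integer>, BigInteger> cache = new HashMap<>();

    // Partitions of sum into k distinct parts in [1, m]: shift every part down
    // by 1 so that 0 becomes a legal part, exactly like diffparts above.
    public BigInteger diffParts(int sum, int k, int m) {
        return diffPartition(sum - k, k, m - 1);
    }

    // Partitions of n into k distinct parts in [0, m], largest part chosen first.
    private BigInteger diffPartition(int n, int k, int m) {
        if (k == 0) {
            return n == 0 ? BigInteger.ONE : BigInteger.ZERO;
        }
        List<Integer> key = List.of(n, k, m);
        BigInteger cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        BigInteger res = BigInteger.ZERO;
        // backward loop from min(n, m) down to 0, mirroring the Python version
        for (int i = Math.min(n, m); i >= 0; i--) {
            res = res.add(diffPartition(n - i, k - 1, i - 1));
        }
        cache.put(key, res);
        return res;
    }
}

Called as new RestrictedPartitions().diffParts(500, 25, 200), it should return the same 147151784574 as the Python version.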
MBo
  • @TreffnonX It's interesting. Is it possible to use some parameter dependence? – MBo Oct 12 '20 at 10:49
  • I was actually wrong. My answer is flawed. It happens to work for the example, but ultimately, three key-values are needed. (Sorry) – TreffnonX Oct 12 '20 at 11:26

Your problem can be transposed so that you only need 3 components in your cache key, and you need a lot less runtime to boot. Fewer distinct keys mean better caching (a smarter person than me may still find a cheaper solution).

Let's view the partitions as sets. The elements of each set shall be ordered (ascending). You have already done this implicitly when you stated the expected results for sum := 15, k := 4, m := 10 as [1, 2, 3, 9]; [1, 2, 4, 8] ....

The restrictions you defined for the partitions are:

  • exactly k elements per set
  • max m as element
  • distinct values
  • non-zero positive integers

The distinctness restriction is actually a bit bothersome, so we will lift it. For that, we need to transform the problem a bit. Because the elements of your set are ascending (and distinct), we know that the minimum values of the elements form an ascending sequence (if we ignore that the total must be sum), so the minima are [1, 2, 3, ...]. If m were, for example, less than k, the number of possible partitions would always be zero. Likewise, if the sum of [1, 2, 3, ..., k] is greater than sum, you also have zero results. We exclude these edge cases at the beginning to make sure the transformation is legal.

Let us look at a geometric representation of a 'legal partition' and how we want to transform it. We have k columns and m rows, and sum squares are filled in blue (either light or dark blue).

[Figure: a legal partition drawn on a k-by-m grid of squares, before (left) and after (right) the transformation]

The red and dark blue squares are irrelevant: we already know that the dark blue squares must always be filled and that the red ones must always be empty. Therefore we can exclude them from our calculation and assume their respective states as we go along. The resulting box is represented on the right side. Every column was 'shifted down' by its position, and the red and dark blue areas are cut off. We now have a smaller overall box, a column can now be empty, and neighboring columns may now contain the same number of blue squares.

Algorithmically the transformation now works like this: from every element of a legal partition, we subtract its position (starting at 1). So for [1, 2, 4, 8] we get [0, 0, 1, 4]. Furthermore, we have to adapt our bounds (sum and m) accordingly:

// from the sum, we subtract the sum of [1, 2, 3, ... k], which is (k * (k + 1) / 2)
sum_2 = sum - (k * (k + 1) / 2)

// from m we subtract the maximum position (which is k)
m_2 = m - k

Now we have transposed our partitioning problem into another partitioning problem, one that does not have the restriction of distinct elements! Also, this partition can contain element 0, which our original could not. (We keep the internal ascending order).
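For the question's example (sum = 15, k = 4, m = 10) this gives:

// 15 - (4 * 5 / 2)
sum_2 = 5

// 10 - 4
m_2 = 6

and, for instance, [1, 2, 4, 8] becomes [0, 0, 1, 4], which indeed sums to 5 with no element greater than 6.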

Now we need to refine the recursion a bit. If we know the elements are ascending, not necessarily distinct and always less than or equal to m_2, then we have bounded the possible elements to a range. Example:

[0, 1, 3, n1, n2]
=> 3 <= n1 <= m_2
=> 3 <= n2 <= m_2

[Figure: the open elements n1 and n2 bounded between 3 and m_2]

Because we know that n1 and n2 in the example are 3 or greater, we can instead reduce both of them by 3 when recursing and reduce sum_2 by 2 * 3 (2 is the number of 'open' elements, 3 is the value of the last 'fixed' element). This way, what we pass into the recursion no longer has a lower bound, only an upper bound, which is what we had before (m).

Because of this, we can drop one component of your cache key: start. When solving this reduced problem, we are left with only 3: sum, m and k.

The following implementation works to this effect:

@Test
public void test() {
    calcNumRIPdistinctElementsSpecificKmaxM(600, (short) 25, (short) 200);
}

public BigInteger calcNumRIPdistinctElementsSpecificKmaxM(int sum, short k, short m) {
    // If the biggest allowed number in a partition is less than the number of parts, then
    // they cannot all be distinct, therefore we have zero results.
    if (m < k) {
        return BigInteger.ZERO;
    }
    
    // If the sum of minimum element-values for k is less than the expected sum, then
    // we also have no results.
    final int v = ((k * ((int) k + 1)) / 2);
    if (sum < v) {
        return BigInteger.ZERO;
    }
    
    // We normalize the problem by lifting the distinction restriction.
    final Cache cache = new Cache();
    final int sumNorm = sum - v;
    final short mNorm = (short) (m - k);
    
    BigInteger result = calcNumRIPspecificKmaxM(sumNorm, k, mNorm, cache);

    System.out.println("Calculation (n=" + sum + ", k=" + k + ", m=" + m + ")");
    System.out.println("p = " + result);
    System.out.println("entries = " + cache.getNumEntries());
    System.out.println("c-rate = " + cache.getCacheRate());
    
    return result;
}

public BigInteger calcNumRIPspecificKmaxM(int sum, short k, short m, Cache cache) {
    
    // We can improve cache use by standing the k*m-rectangle upright (k being the 'bottom').
    if (k > m) {
        final short c = k;
        k = m;
        m = c;
    }
    
    // If the result is trivial, we just calculate it. This is true for k < 3
    if (k < 3) {
        if (k == 0) {
            return sum == 0 ? BigInteger.ONE : BigInteger.ZERO;
            
        } else if (k == 1) {
            return sum <= m ? BigInteger.ONE : BigInteger.ZERO;
            
        } else {
            final int upper = Math.min(sum, m);
            final int lower = sum - upper;
            
            if (upper < lower) {
                return BigInteger.ZERO;
            }
            
            final int difference = upper - lower;
            final int numSubParts = difference / 2 + 1;
            return BigInteger.valueOf(numSubParts);
        }
    }
    
    // If k * m / 2 < sum, we can 'invert' the sub problem to reduce the number of keys further.
    sum = Math.min(sum, k * m - sum);
    
    // If the sum is less than m and maybe even k, we can reduce the box. This improves the cache size even further.
    if (sum < m) {
        m = (short) sum;
        
        if (sum < k) {
            k = (short) sum;

            if (k < 3) {
                return calcNumRIPspecificKmaxM(sum, k, m, cache);
            }
        }
    }
    
    // If the result is non-trivial, we check the cache or delegate.
    final Triple<Short, Short, Integer> key = Triple.of(k, m, sum);
    final BigInteger cachedResult = cache.lookUp(key);
    if (cachedResult != null) {
        return cachedResult;
    }
    
    BigInteger current = BigInteger.ZERO;
    
    // i = m is reached in case the result is an ascending stair e.g. [1, 2, 3, 4]
    for (int i = 0; i <= m; ++i) {
        final int currentSum = sum - (i * k);
        if (currentSum < 0) {
            break;
        }
        
        short currentK = (short) (k - 1);
        short currentM = (short) (m - i);
        
        current = current.add(calcNumRIPspecificKmaxM(currentSum, currentK, currentM, cache));
    }
    
    // We cache this new result and return it.
    cache.enter(key, current);
    return current;
}

public static class Cache {
    private final HashMap<Triple<Short, Short, Integer>, BigInteger> map = new HashMap<>(1024);
    private long numLookUps = 0;
    private long numReuse = 0;
    
    public BigInteger lookUp(Triple<Short, Short, Integer> key) {
        ++numLookUps;
        
        BigInteger value = map.get(key);
        if (value != null) {
            ++numReuse;
        }
        
        return value;
    }
    
    public void enter(Triple<Short, Short, Integer> key, BigInteger value) {
        map.put(key, value);
    }
    
    public double getCacheRate() {
        return (double) numReuse / map.size();
    }
    
    public int getNumEntries() {
        return map.size();
    }
    
    public long numLookUps() {
        return numLookUps;
    }
    
    public long getNumReuse() {
        return numReuse;
    }
}

Note: I used Apache Commons' Triple class as the key here to spare myself an explicit key class; this is not a runtime optimization, it just saves code.
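On Java 16+, a record would also do the job without any dependency, since it generates equals and hashCode automatically (the name CacheKey is mine):

record CacheKey(short k, short m, int sum) { }

The cache map would then be a HashMap<CacheKey, BigInteger>.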

Edit: Besides a fix for a problem found by @MBo (thank you), I added a few shortcuts that reach the same result. The algorithm now performs even better, and the cache (reuse) rate is higher. Maybe this will satisfy your requirements?

The optimizations explained (they are only applicable after the above-mentioned transposition of the problem; a short trace on the question's example follows the list):

  • If k > m, we can 'flip' the rectangle upright and still get the same number of legal partitions. This maps some 'lying' configurations onto 'upright' configurations and reduces the overall number of distinct keys.

[Figure: flipping a 'lying' k-by-m rectangle upright]

  • If the number of filled squares in the rectangle is larger than the number of empty spaces, we can count the empty spaces as squares instead, which maps another bunch of keys together.

[Figure: swapping the roles of filled squares and empty spaces]

  • If sum < k and/or sum < m, we can reduce k and/or m to sum and still get the same number of partitions. (This is the most impactful optimization, as it often skips multiple redundant interim steps and frequently reaches m = k = sum.)

[Figure: shrinking the box when sum is smaller than m and/or k]
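As a quick illustration of how these shortcuts collapse keys, here is a hand trace of the code above on the question's example (sum = 15, k = 4, m = 10, i.e. sum = 5, k = 4, m = 6 after the transposition):

// k (4) <= m (6)              -> no flip
// sum = min(5, 4 * 6 - 5) = 5 -> no inversion
// sum (5) < m (6)             -> m shrinks to 5; sum (5) >= k (4), so k stays 4
// => the top-level cache key is (k = 4, m = 5, sum = 5)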

TreffnonX
  • Very nice. Works 3-5 times faster than my version due to smarter problem size reduction. – MBo Oct 12 '20 at 11:51
  • Thanks, though your version is *a lot* more elegant! And it can be done without first transposing the entire problem... – TreffnonX Oct 12 '20 at 11:52
  • Oops... hmm.. guess I need to make that one out. – TreffnonX Oct 12 '20 at 12:02
  • Found it. The loop needs to go up to `m` after all. In 'staircase' scenarios, the solution is actually in `i = m` in the loop of the inner function. Should yield the right result now. Thank you for double-checking! – TreffnonX Oct 12 '20 at 12:07
  • What additional optimizations could be made by a smarter person? – TTho Einthausend Oct 12 '20 at 13:15
  • @TThoEinthausend If I knew that, I'd make the optimization myself :D - I suspect that a mathematician might know a specific function to reduce the recursion, because the general partition problem can be expressed as a specific function. It is only the restrictions that make them hard to handle. If someone found a way to lift the 'm' and 'k' restrictions as well, the function might become trivial altogether. – TreffnonX Oct 12 '20 at 13:25
  • I found this; in the update of the question there is a formula, maybe it can be used to create an even better solution, but I have no idea how: https://math.stackexchange.com/questions/646705/counting-integer-partitions-of-n-into-exactly-k-distinct-parts-size-at-most-m PS: your solution works fine but is still too slow / memory-hungry for medium values. – TTho Einthausend Oct 12 '20 at 18:03
  • The Q/A you have found actually confirms the approach I suggest, as it uses the same concepts (mathematically, though, rather than programmatically). They too lift the distinctness restriction and even use the same recursion (they call it a 'recurrence relation'). That said, there may be minor optimizations, like my rectangle inversion, but they will always be incremental and won't improve the general complexity. The only way I can imagine to beat this complexity is with an entirely different approach. What exactly is the goal you are chasing? How efficient should the algorithm become? – TreffnonX Oct 13 '20 at 06:05
  • Theoretically, you could bound memory by limiting the cache size, but then the algorithm would recalculate some results, which would impact runtime. That way, however, your algorithm would definitely terminate with a result; the only question is when :) – TreffnonX Oct 13 '20 at 06:11
  • I improved the performance even further using some geometric insights about the problem. I tried to explain them in the code. Altogether the algorithm runs about 3x faster for the example test values, and memory consumption is reduced accordingly, while yielding the same result. – TreffnonX Oct 13 '20 at 07:26

An alternative would be to use a constraint solver and configure it to enumerate all solutions (e.g. via MiniZinc's --all-solutions flag) and count them. Here is a model in MiniZinc:

include "globals.mzn";

int: sum = 15;
int: k = 4;
int: m = 10;

array[1..k] of var 1..m: numbers;

constraint sum(numbers) = sum;

constraint alldifferent(numbers);

constraint increasing(numbers);

solve satisfy;
Magnus Åhlander