
I'm trying to solve the three-sum problem on LeetCode, and I believe I've come up with some O(n^2) submissions, but I keep getting a "Time Limit Exceeded" error.

For example, this solution using itertools.combinations:

from itertools import combinations

class Solution:
    def threeSum(self, nums):
        results = [triplet for triplet in combinations(nums, 3) if sum(triplet) == 0]
        return [list(result) for result in set([tuple(sorted(res)) for res in results])]

This results in a "Time Limit Exceeded" error.
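To put the O(n^3) cost of the first solution in numbers: `combinations(nums, 3)` yields every 3-element subset, i.e. C(n, 3) candidates, whereas a pair-based loop visits C(n, 2). A small sketch for n = 3,000, the size of the all-zeros input LeetCode runs (assumes Python 3.8+ for `math.comb`):

```python
from math import comb  # available since Python 3.8

n = 3000  # size of the all-zeros test case run by LeetCode
triplets = comb(n, 3)  # candidates visited by combinations(nums, 3)
pairs = comb(n, 2)     # candidates visited by combinations(range(n), 2)

print(f"triplets: {triplets:,}")     # triplets: 4,495,501,000
print(f"pairs:    {pairs:,}")        # pairs:    4,498,500
print(f"ratio:    {triplets // pairs:,}")  # ratio:    999
```

Even before any per-candidate work, the triplet version does roughly (n - 2) / 3 times as many iterations.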

Similarly, this solution,

from itertools import combinations

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        _map = self.get_map(nums)

        results = set()
        for i, j in combinations(range(len(nums)), 2):
            target = -nums[i] - nums[j]
            if target in _map and _map[target] and _map[target] - set([i, j]):
                results.add(tuple(sorted([target, nums[i], nums[j]])))
        return [list(result) for result in results]

    @staticmethod
    def get_map(nums):
        _map = {}
        for index, num in enumerate(nums):
            if num in _map:
                _map[num].add(index)
            else:
                _map[num] = set([index])
        return _map 

yields a "Time Limit Exceeded" error for an input consisting of a long array of zeros.
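A likely contributor for the all-zeros case: `_map[0]` then holds all n indices, so `_map[target] - set([i, j])` builds a new set of about n elements for each of the roughly n²/2 index pairs, which makes the loop effectively O(n³). A sketch of a constant-time replacement, assuming only the number of indices per value matters: a third index exists exactly when the count of `target` exceeds the number of times `nums[i]` and `nums[j]` already equal `target`. The names `has_third_index` and `three_sum_pairs` are hypothetical.

```python
from collections import Counter
from itertools import combinations


def has_third_index(counts, target, a, b):
    """Hypothetical helper: True if `target` occurs at an index other than
    the two indices holding `a` and `b`. Same truth value as
    _map[target] - set([i, j]), but O(1) instead of building a new set."""
    used = (a == target) + (b == target)  # bools add as 0/1
    return counts[target] > used          # Counter returns 0 for missing keys


def three_sum_pairs(nums):
    # Same index-pair loop as the second solution, with the set
    # difference replaced by the count check above.
    counts = Counter(nums)
    results = set()
    for i, j in combinations(range(len(nums)), 2):
        target = -nums[i] - nums[j]
        if has_third_index(counts, target, nums[i], nums[j]):
            results.add(tuple(sorted((target, nums[i], nums[j]))))
    return [list(r) for r in results]
```

This removes the extra linear factor, but it still visits every index pair and still pays the per-pair `tuple(sorted(...))` overhead.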

This question has been asked before (Optimizing solution to Three Sum), but I'm looking for suggestions pertaining to these solutions specifically. Any idea what is making the solutions 'too slow' for LeetCode?

Update

It occurred to me that evaluating _map[target] - set([i, j]) (that is, checking whether the target value has an index other than the current indices i and j) could be expensive, so I should first check whether the corresponding pair of numbers has already been seen. So I tried this:

from itertools import combinations

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        _map = self.get_map(nums)

        results = set()
        seen = set()
        for i, j in combinations(range(len(nums)), 2):
            target = -nums[i] - nums[j]
            pair = tuple(sorted([nums[i], nums[j]]))
            if target in _map and pair not in seen and _map[target] - set([i, j]):
                seen.add(pair)
                results.add(tuple(sorted([target, nums[i], nums[j]])))
        return [list(result) for result in results]

    @staticmethod
    def get_map(nums):
        _map = {}
        for index, num in enumerate(nums):
            if num in _map:
                _map[num].add(index)
            else:
                _map[num] = set([index])
        return _map

However, this fails on another test case with large input numbers.

Cœur
Kurt Peek
  • `combinations(nums, 3)` has O(n^3) elements. Consider going through `combinations(nums, 2)` and looking up the result in a hashed data structure. – Alex Hall Jun 17 '18 at 22:27
  • Good point: my first solution is O(n^3). My second solution, however, does use `combinations(nums, 2)`, and I believe it is O(n^2), but it fails on a different test case, namely one with repeated zeros. In that case it seems a further optimization is needed, but what? – Kurt Peek Jun 18 '18 at 05:57
  • `_map[target] - set([i, j])` creates a new set, which can potentially be of linear size. I'm not sure whether it's possible to construct a bad test case for that, though. My suspicion is that your code has too much Pythonic overhead, like `tuple(sorted(...))`; commenting that out gives a 2x speedup for me locally. – yeputons Jun 18 '18 at 15:19
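On the `tuple(sorted(...))` overhead: one common variant (a different technique from the code in this thread) walks the sorted distinct values so that a ≤ b ≤ c always holds. Each triplet is then produced exactly once and already in order, so neither per-candidate sorting nor a result set is needed. A sketch, with `three_sum_ordered` as a hypothetical name, not checked against LeetCode's judge:

```python
from collections import Counter


def three_sum_ordered(nums):
    """Sketch: enumerate each value triplet a <= b <= c exactly once, so no
    per-triplet sorting or set-based deduplication is needed."""
    counts = Counter(nums)
    values = sorted(counts)  # distinct values in ascending order (sorted once)
    result = []
    for i, a in enumerate(values):
        for j in range(i, len(values)):
            b = values[j]
            c = -a - b
            if c < b:
                # values ascend, so c only shrinks as b grows:
                # no later b can satisfy c >= b for this a
                break
            if c not in counts:
                continue
            # the triplet may not use a value more often than it occurs
            needed = Counter((a, b, c))
            if all(counts[v] >= k for v, k in needed.items()):
                result.append([a, b, c])
    return result
```

The early `break` also prunes the inner loop, since once c drops below b no later b for the same a can work.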

1 Answer


This has worked for me; it uses a few optimizations for inputs with many repeated elements. We store the number of times each element appears and then iterate only over the distinct elements. The rest is similar to what you have already done.
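The key step is the multiplicity check: a candidate value triple is valid only if the input contains each value at least as many times as the triple uses it. A standalone sketch of that check (`can_form` is a hypothetical name):

```python
from collections import Counter


def can_form(counts, triple):
    """True if `triple` uses each value no more often than it appears in
    the input; same test as not any(d[n] > counts[n] for n in d)
    with d = Counter(triple)."""
    needed = Counter(triple)
    return not any(needed[n] > counts[n] for n in needed)


counts = Counter([-1, 0, 1, 2, -1, -4])
print(can_form(counts, (-1, -1, 2)))  # True: -1 appears twice in the input
print(can_form(counts, (0, 0, 0)))    # False: only one 0 in the input
print(can_form(counts, (2, 2, -4)))   # False: only one 2 in the input
```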

from collections import Counter
import itertools

class Solution:
    def threeSum(self, nums):
        counts = Counter(nums)       # how many times each value appears
        numSet = list(set(nums))     # each distinct value once
        result = set()

        for idx, num1 in enumerate(numSet):
            # pair num1 with itself and every later distinct value
            for num2 in itertools.islice(numSet, idx, None):
                num3 = -(num1 + num2)
                if num3 in counts:
                    # the triplet may not use a value more often than it occurs
                    d = Counter([num1, num2, num3])
                    if not any(d[n] > counts[n] for n in d):
                        result.add(tuple(sorted([num1, num2, num3])))

        return [list(triplet) for triplet in result]
juvian
  • Nice! For an input consisting of 3,000 zeros (as run by LeetCode), this reduced the runtime to 0.00s from ~0.8s. I've refactored your solution a bit. – Kurt Peek Oct 27 '18 at 18:25
  • @KurtPeek copying a slice in the second loop is too slow and causes a time-limit error, so I changed that. The other refactor is nice. – juvian Oct 27 '18 at 19:24
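On the slicing point: `numSet[idx:]` allocates a new list on every outer iteration, while `itertools.islice(numSet, idx, None)` avoids the copy but, per the `itertools` documentation, still steps past the first `idx` items before yielding anything. A sketch of an inner loop that indexes the list directly and does neither (`distinct_pairs` is a hypothetical helper):

```python
def distinct_pairs(values):
    """Yield (values[i], values[j]) for all i <= j by indexing directly,
    so no slice is copied and no prefix is skipped."""
    n = len(values)
    for i in range(n):
        a = values[i]
        for j in range(i, n):
            yield a, values[j]


print(list(distinct_pairs([1, 2, 3])))
# [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
```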