
I am trying to implement my own version of Python's str.join; e.g. ''.join(['aa','bbb','cccc']) returns 'aabbbcccc'. I know that concatenating strings with the join method runs in linear time (in the number of characters of the result), and I want to know how to achieve that, because using the '+' operator in a for loop results in quadratic time, e.g.:

res = ''
for word in ['aa','bbb','cccc']:
    res = res + word

Because strings are immutable, each iteration builds a new string and copies everything accumulated so far into it, which results in quadratic running time. I want to know how to do this in linear time, or how ''.join works exactly.
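To see where the quadratic bound comes from, here is a rough cost model (an illustrative assumption, not a measurement): if every `res = res + word` builds a brand-new string, it copies all `len(res) + len(word)` characters, whereas writing each word once into a preallocated buffer copies each character exactly once. The helper names `naive_copy_cost` and `buffer_copy_cost` are hypothetical. Note that CPython can sometimes avoid the copy by resizing `res` in place when nothing else references it, an optimization PEP 8 explicitly says not to rely on.

```python
def naive_copy_cost(words: "list[str]") -> int:
    """Characters copied if every `res = res + word` creates a fresh string."""
    copied = 0
    current_len = 0
    for word in words:
        current_len += len(word)
        # the new string holds the old contents plus `word`, all copied over
        copied += current_len
    return copied


def buffer_copy_cost(words: "list[str]") -> int:
    """Characters copied if each word is written once into a preallocated buffer."""
    return sum(len(word) for word in words)


print(naive_copy_cost(['aa', 'bbb', 'cccc']))   # 16  (2 + 5 + 9)
print(buffer_copy_cost(['aa', 'bbb', 'cccc']))  # 9

words = ['ab'] * 1000
print(naive_copy_cost(words))   # 1001000, grows with n**2
print(buffer_copy_cost(words))  # 2000, grows with n
```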

I could not find a linear-time algorithm or the implementation of str.join(iterable) anywhere. Any help is much appreciated.

Ngo Cuong
    Maybe you want to check the built-in str.join code. Potentially here's how - https://stackoverflow.com/questions/8608587/finding-the-source-code-for-built-in-python-functions – FRizal Aug 03 '21 at 06:27
  • You can try `+=` – Aug 03 '21 at 06:27
    I imagine the C code implementation pre-calculates the result length and does a single buffer assignment for the result. Why not look at the source code rather than speculate? – DisappointedByUnaccountableMod Aug 03 '21 at 06:28
  • See https://stackoverflow.com/a/37782238/550094 and https://stackoverflow.com/questions/32462194/python-understanding-iterators-and-join-better – Thierry Lathuille Aug 03 '21 at 06:46

1 Answer


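Joining str as actual str is a red herring, and it is not what Python itself does. CPython's str.join first sums the lengths of all pieces (and finds the widest character, which determines the storage size), allocates the result string once, and then copies every piece directly into that fresh, not-yet-shared buffer. Pure Python code cannot mutate a str, so the closest analogue is to convert the arguments to bytes, pre-allocate a mutable bytearray, and fill it in; this also removes the need to know string internals.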

This directly corresponds to:

  1. a wrapper to encode/decode str arguments to/from bytes
  2. summing the lengths of all elements and separators
  3. allocating a mutable bytearray to construct the result
  4. copying each element/separator directly into the result
# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
    joined_bytes = bytes_join(
        sep.encode(),
        [elem.encode() for elem in elements],
    )
    return joined_bytes.decode()

# actual joining at bytes level
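def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
    # joining nothing gives an empty result (and avoids a negative length below)
    if not elements:
        return b""
    # create a mutable buffer that is long enough to hold the result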
    total_length = sum(len(elem) for elem in elements)
    total_length += (len(elements) - 1) * len(sep)
    result = bytearray(total_length)
    # copy all characters from the inputs to the result
    insert_idx = 0
    for elem in elements:
        result[insert_idx:insert_idx+len(elem)] = elem
        insert_idx += len(elem)
        # add a separator after every element except the last one
        if insert_idx < total_length:
            result[insert_idx:insert_idx+len(sep)] = sep
            insert_idx += len(sep)
    return bytes(result)

print(str_join(" ", ["Hello", "World!"]))

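Notably, although iterating over the elements and copying each element's bytes are effectively two nested loops, they run over different things: the outer loop over the elements, the inner copy over each element's bytes. Each character is still touched only a constant number of times (encoding, copying into the buffer, converting back and decoding), so the total work is linear in the length of the result.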
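The functions above take a list, but str.join accepts any iterable, and the algorithm needs the pieces twice (once to sum lengths, once to copy), so a one-shot iterator such as a generator has to be materialized first; CPython does the same by turning the argument into a sequence. A minimal sketch of that variant follows; `join_iterable` is a hypothetical name, and its TypeError message is only modeled on the one str.join raises. Encoding each piece separately is safe because, for UTF-8, `(a + b).encode() == a.encode() + b.encode()`. One caveat shared with the answer's version: strings containing lone surrogates such as `'\ud800'` cannot be encoded to UTF-8, so this raises UnicodeEncodeError where str.join would not.

```python
from typing import Iterable


def join_iterable(sep: str, items: Iterable[str]) -> str:
    """Two-pass join over any iterable using a preallocated bytearray."""
    # The pieces are needed twice, so a generator must be materialized.
    pieces = list(items)
    for i, piece in enumerate(pieces):
        if not isinstance(piece, str):
            raise TypeError(
                f"sequence item {i}: expected str instance, "
                f"{type(piece).__name__} found"
            )
    if not pieces:
        return ""
    # UTF-8 encodings of the pieces concatenate to the encoding of the result.
    sep_b = sep.encode()
    encoded = [piece.encode() for piece in pieces]
    # Pass 1: total size of all pieces plus the separators between them.
    total = sum(map(len, encoded)) + (len(encoded) - 1) * len(sep_b)
    buf = bytearray(total)
    # Pass 2: copy every byte exactly once into its final position.
    pos = 0
    for i, chunk in enumerate(encoded):
        if i:  # a separator goes before every piece except the first
            buf[pos:pos + len(sep_b)] = sep_b
            pos += len(sep_b)
        buf[pos:pos + len(chunk)] = chunk
        pos += len(chunk)
    return buf.decode()


print(join_iterable(", ", (w.upper() for w in ['aa', 'bbb', 'cccc'])))  # AA, BBB, CCCC
print(join_iterable("→", ["naïve", "café"]))  # naïve→café
```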

MisterMiyagi
  • Thank you for your time and your answer! I think I understand it now. However, do you know why, when I ran your str_join function and the string concatenation loop, the loop was much faster? I timed both with Python's time module on my computer: the loop version takes 0.49 sec, whereas str_join takes 2.37 sec. – Ngo Cuong Aug 03 '21 at 09:44
    The ``+``/``+=`` operation for ``str`` is implemented in pure C and thus very performant. Even though *in theory* its complexity is worse, *in practice* it has a much better constant factor. – MisterMiyagi Aug 03 '21 at 10:05
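The comparison from the comments can be reproduced with a small timeit harness like the one below. It is a sketch under arbitrary assumptions (input size, repetition count, the hypothetical helper names `concat_loop` and `preallocated_join`); absolute numbers depend on the machine and Python version. Besides the C-level constant factor, CPython can often resize the accumulating string in place for the `+` loop, whereas the pure-Python bytearray version pays interpreter overhead for every slice assignment, so the built-in ''.join is usually fastest of the three.

```python
import timeit


def concat_loop(words: "list[str]") -> str:
    # the question's approach: repeated `+`
    res = ""
    for word in words:
        res = res + word
    return res


def preallocated_join(words: "list[str]") -> str:
    # a condensed form of the answer's bytearray approach, with an empty separator
    encoded = [w.encode() for w in words]
    buf = bytearray(sum(map(len, encoded)))
    pos = 0
    for chunk in encoded:
        buf[pos:pos + len(chunk)] = chunk
        pos += len(chunk)
    return buf.decode()


words = ["aa", "bbb", "cccc"] * 10_000

# sanity check: all three approaches build the same string
assert concat_loop(words) == preallocated_join(words) == "".join(words)

for name, func in [
    ("+ loop", concat_loop),
    ("bytearray", preallocated_join),
    ("str.join", "".join),
]:
    seconds = timeit.timeit(lambda: func(words), number=20)
    print(f"{name:>10}: {seconds:.3f} s")
```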