
Let's say I have an enum

from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

I want to create a ColorDict class that works like a native Python dictionary but only accepts a Color enum member or its corresponding string value as a key.

d = ColorDict() # I want to implement a ColorDict class such that ...

d[Color.RED] = 123
d["RED"] = 456  # I want this to override the previous value
d[Color.RED]    # ==> 456
d["foo"] = 789  # I want this to produce a KeyError exception

What's the "pythonic way" of implementing this ColorDict class? Should I use inheritance (subclassing Python's native dict) or composition (keeping a dict as a member)?

Ying Xiong
  • Inheritance or composition is really up to you. If you use inheritance, you will have to override all the methods that accept inputs, such as `__setitem__` and `.update`, but that may be easy enough. – juanpa.arrivillaga Nov 28 '21 at 18:51
  • It's personal choice, but I generally prefer composition in these cases. It makes the interface a lot easier to understand by being explicit about what operations you want to expose, limiting the amount of work you need to do, especially if you don't care about implementing the entire dict interface. – flakes Nov 28 '21 at 18:54
  • @Mark The code snippet contains the behavior I hope to achieve, not what I currently observe. I updated the comment to be clearer. Sorry for the confusion. – Ying Xiong Nov 28 '21 at 18:54
  • An alternative is to inherit from [`collections.abc.MutableMapping`](https://docs.python.org/3.9/library/collections.abc.html#collections-abstract-base-classes) which would involve composition, but you would only have to implement a minimal amount of methods – juanpa.arrivillaga Nov 28 '21 at 18:54
  • Thanks @YingXiong, I realized I misread that right before you posted. – Mark Nov 28 '21 at 18:56
  • `KeyError` is a lookup error (something not found), but the assignment is not a lookup. I would consider the `ValueError` for a wrong key value, but to be honest, I'm not sure which one is more appropriate. – VPfB Nov 28 '21 at 19:35
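
The exception types discussed in the last comment can be checked directly. This is only a minimal sketch reusing the `Color` enum from the question; the variable `error_type` is introduced here for illustration:

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


# KeyError derives from LookupError, not from ValueError.
print(issubclass(KeyError, LookupError))  # True
print(issubclass(KeyError, ValueError))   # False

# Looking up an enum member by an unknown value raises ValueError.
try:
    Color("foo")
except ValueError as exc:
    error_type = type(exc).__name__

print(error_type)  # ValueError
```

So a `ColorDict` that should raise `KeyError` for a bad key has to translate the `ValueError` coming from `Color(...)` itself.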

2 Answers


A simple solution would be to slightly modify your Color object and then subclass dict to add a test for the key. I would do something like this:

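from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        # Accept either a Color member or its string value.
        if isinstance(color, cls):
            color = color.value
        # Member names and values are identical here, so checking names works.
        return color in cls.__members__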


class ColorDict(dict):
    
    def __setitem__(self, k, v):
        if Color.is_color(k):
            super().__setitem__(Color(k), v)
        else:
            raise KeyError(f"Color {k} is not valid")

    def __getitem__(self, k):
        if isinstance(k, str):
            k = Color(k.upper())
        return super().__getitem__(k)

d = ColorDict()

d[Color.RED] = 123
d["RED"] = 456
d[Color.RED]
d["foo"] = 789

In the Color class, I have added a class method, is_color, that returns True if a color is in the allowed list and False otherwise. In __getitem__, upper() converts a string key to upper case so that it can be compared with the predefined values.

Then I have subclassed the dict object to override the __setitem__ special method to include a test of the value passed, and an override of __getitem__ to convert any key passed as str into the correct Enum. Depending on the specifics of how you want to use the ColorDict class, you may need to override more functions. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
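
As a rough illustration of what "overriding more functions" might involve, here is a minimal sketch of a `dict` subclass that also routes `__init__`, `update`, and `setdefault` through the key check. The class name `StrictColorDict` and the helper `_to_color` are hypothetical and not part of the code above, and methods such as `get`, `pop`, and `in` are deliberately left uncovered:

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


def _to_color(key):
    """Hypothetical helper: map a Color or its string value to a Color."""
    try:
        return Color(key)
    except ValueError:
        raise KeyError(f"Color {key!r} is not valid") from None


class StrictColorDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)  # route initial data through the check

    def __setitem__(self, key, value):
        super().__setitem__(_to_color(key), value)

    def __getitem__(self, key):
        return super().__getitem__(_to_color(key))

    def update(self, *args, **kwargs):
        # dict.update does not call __setitem__, so loop explicitly.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        color = _to_color(key)
        if not super().__contains__(color):
            super().__setitem__(color, default)
        return super().__getitem__(color)


d = StrictColorDict({"RED": 1})
d.update(GREEN=2)
d.setdefault("BLUE", 3)
print(d[Color.RED], d["GREEN"], d[Color.BLUE])  # 1 2 3
```

Every additional entry point needs the same treatment, which is the main maintenance cost of the inheritance approach.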

defladamouse
  • Minor comment - it's probably a bad idea to name your filtering function (`test_color`) using the `test_` prefix, as a number of test frameworks might inadvertently pick that up as a test case. `is_color` would be an idiomatic filter name. – Nathaniel Ford Nov 28 '21 at 19:29
  • That is a very good point! I'll modify it. – defladamouse Nov 28 '21 at 19:29
  • Same as @VPfB's comment. We probably need to do something like `super().__setitem__(Color(k), v)`. – Ying Xiong Nov 28 '21 at 19:42
  • Thank you both, I've added the suggested `Color(k)`, which appears to solve the problem. But now `__getitem__` needs overriding as well to match. – defladamouse Nov 28 '21 at 19:45

One way is to use the abstract base class collections.abc.MutableMapping. This way, you only need to implement the abstract methods, and you can be sure that access always goes through your logic. You can do this with dict too, but overriding dict.__setitem__, for example, will not affect dict.update, dict.setdefault, etc., so you would have to override those by hand as well. Usually, it is easier to just use the abstract base class:

from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):

    def __init__(self): # could handle more ways of initializing  but for simplicity...
        self._data = {}

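    def __getitem__(self, item):
        color = self._handle_item(item)  # normalize the key before the lookup
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        color = self._handle_item(item)  # normalize the key before deleting
        del self._data[color]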

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color

Note, you can also add:

    def __repr__(self):
        return repr(self._data)

For easier debugging.

An example in the repl:

In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
   ...:
   ...: d[Color.RED] = 123
   ...: d["RED"] = 456  # I want this to override the previous value
   ...: d[Color.RED]    # ==> 456
Out[3]: 456

In [4]: d["foo"] = 789  # I want this to produce an KeyError exception
   ...:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789  # I want this to produce an KeyError exception

<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
     17
     18     def __setitem__(self, item, value):
---> 19         color = self._handle_item(item)
     20         self._data[color] = value
     21

<ipython-input-2-a0780e16594b> in _handle_item(self, item)
     34             color = Color(item)
     35         except ValueError:
---> 36             raise KeyError(item) from None
     37         return color
     38     def __repr__(self): return repr(self._data)

KeyError: 'foo'

In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}
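
Because `MutableMapping` builds `update`, `setdefault`, `pop`, and `in` on top of the abstract methods, they pick up the key handling without extra code. The following is a minimal sketch of that, assuming lookups and deletions also go through `_handle_item`; the `__init__` that accepts initial data is a hypothetical extension, not part of the class above:

```python
from collections.abc import MutableMapping
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class ColorDict(MutableMapping):
    def __init__(self, *args, **kwargs):
        self._data = {}
        # Mixin method from MutableMapping; it calls __setitem__ per key.
        self.update(*args, **kwargs)

    def __getitem__(self, item):
        return self._data[self._handle_item(item)]

    def __setitem__(self, item, value):
        self._data[self._handle_item(item)] = value

    def __delitem__(self, item):
        del self._data[self._handle_item(item)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    @staticmethod
    def _handle_item(item):
        try:
            return Color(item)
        except ValueError:
            raise KeyError(item) from None


d = ColorDict({"RED": 1}, GREEN=2)
d.update({Color.RED: 10})
d.setdefault("BLUE", 3)
print(d["RED"], d[Color.GREEN], d["BLUE"])  # 10 2 3
print("RED" in d, "foo" in d)               # True False
print(d.pop("GREEN"), len(d))               # 2 2
```

None of `update`, `setdefault`, `pop`, or `__contains__` is defined here; they are inherited mixin methods that delegate to the five methods implemented in the class.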
juanpa.arrivillaga