
I have an enum to represent distinct subsets of a dataset, and combinations of those subsets:

from enum import Flag, auto

class DataSubset(Flag):
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
    EXCLUDED = auto()

    TRAIN_TEST = TRAIN | TEST
    ALL_INCLUDED = TRAIN_TEST | VALIDATION
    ALL = ALL_INCLUDED | EXCLUDED

Is there a way to iterate through just the distinct flags, but not the named combinations? i.e.:

[DataSubset.TRAIN, DataSubset.TEST, DataSubset.VALIDATION, DataSubset.EXCLUDED]

The goal is to be able to do something like this:

def get_subsets(subset):
    return [sub for sub in DataSubset.distinct_flags if sub in subset]

and then:

>>> get_subsets(DataSubset.TRAIN)
[DataSubset.TRAIN]
>>> get_subsets(DataSubset.TRAIN_TEST)
[DataSubset.TRAIN, DataSubset.TEST]
>>> get_subsets(DataSubset.ALL)
[DataSubset.TRAIN, DataSubset.TEST, DataSubset.VALIDATION, DataSubset.EXCLUDED]
martineau
bwk

2 Answers


Kind of a silly solution, but you can use the Bit Twiddling Hacks test for whether an integer is a power of 2 to keep only single-bit flags. A flag that is an alias of an existing single flag, rather than a combination of flags, still passes the test, but any flag that sets more than one bit is filtered out (a flag with value 0, if you define one, also passes):

def distinct_flags(enm):
    return [x for x in enm if (x.value & (x.value - 1)) == 0]

which when used gets the following results (slightly prettier since I ran it in IPython):

>>> distinct_flags(DataSubset)
[<DataSubset.TRAIN: 1>,
 <DataSubset.TEST: 2>,
 <DataSubset.VALIDATION: 4>,
 <DataSubset.EXCLUDED: 8>]

You'd just build your get_subsets function around that functionality or merge both bits of functionality (filtering to single flags and to ones included in the subset provided) into the if condition in your existing code.
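As a minimal sketch of the first option (not part of the original answer), `get_subsets` can call `distinct_flags` on the enum class of whatever member it receives. The `x.value and` guard is an added assumption that skips a zero-valued member if one is ever defined. As far as I know, plain iteration over a `Flag` on Python 3.11 and later already skips named combinations, in which case the bit test is a harmless no-op.

```python
from enum import Flag, auto


class DataSubset(Flag):
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
    EXCLUDED = auto()

    TRAIN_TEST = TRAIN | TEST
    ALL_INCLUDED = TRAIN_TEST | VALIDATION
    ALL = ALL_INCLUDED | EXCLUDED


def distinct_flags(enm):
    # Keep members whose value is a nonzero power of 2 (exactly one bit set).
    # The "x.value and" guard is an assumption added here to skip a 0-valued flag.
    return [x for x in enm if x.value and (x.value & (x.value - 1)) == 0]


def get_subsets(subset):
    # Flag supports "in": a member is in `subset` when all of its bits are set there.
    return [sub for sub in distinct_flags(type(subset)) if sub in subset]
```

Calling `get_subsets(DataSubset.TRAIN_TEST)` should then return the `TRAIN` and `TEST` members in definition order.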

ShadowRanger

I was able to achieve this by creating a new metaclass, combining this answer on detecting powers of 2 with this answer on changing iteration behavior:

from enum import EnumMeta, Flag, auto

class DistinctFlag(EnumMeta):
    def __iter__(cls):
        # Yield only members whose value is a nonzero power of 2 (a single bit set).
        for x in super().__iter__():
            if (x.value & (x.value - 1)) == 0 and x.value != 0:
                yield x

                
class DataSubset(Flag, metaclass=DistinctFlag):
    """Enum to describe distinct subsets of a modeling dataset"""
    BURN_IN = auto()
    TRAIN = auto()
    TEST = auto()
    HOLDOUT = auto()
    EXCLUDED = auto()
    
    TRAIN_TEST = TRAIN | TEST
    OBS = BURN_IN | TRAIN_TEST
    ALL_INCLUDED = OBS | HOLDOUT
    ALL = ALL_INCLUDED | EXCLUDED
    

So then:

>>> list(DataSubset)
[<DataSubset.BURN_IN: 1>,
 <DataSubset.TRAIN: 2>,
 <DataSubset.TEST: 4>,
 <DataSubset.HOLDOUT: 8>,
 <DataSubset.EXCLUDED: 16>]
bwk
  • Note: You don't need the `and x.value != 0` part of your test for the case where all flags are created with `auto()` or a combination of said flags, as `auto()` for flags starts from `1`; no single flag would have a value of `0` (this isn't just implementation detail, it's a logical requirement; otherwise, every flag would be implicitly unioned with the flag with value `0`). – ShadowRanger Oct 28 '20 at 19:02
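To illustrate when the zero check does matter, here is a small sketch with a hypothetical `Permission` flag that explicitly defines a zero-valued member; none of these names come from the question. It reads `__members__` rather than iterating the class, on the assumption that iteration behavior for zero-valued members may differ between Python versions.

```python
from enum import Flag


class Permission(Flag):
    # Hypothetical enum for illustration only; explicit values, including a 0 member.
    NONE = 0
    READ = 1
    WRITE = 2
    READ_WRITE = READ | WRITE


def single_bit_members(enm, guard_zero=True):
    # __members__ lists every named member, including 0-valued and combined ones.
    result = []
    for member in enm.__members__.values():
        is_power_of_two_or_zero = (member.value & (member.value - 1)) == 0
        if is_power_of_two_or_zero and (member.value != 0 or not guard_zero):
            result.append(member)
    return result
```

With `guard_zero=False` the bare power-of-2 test lets `Permission.NONE` through, because `0 & -1` is `0`; with the guard, only `READ` and `WRITE` remain.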